mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-11 19:29:55 +00:00
Compare commits
15 Commits
refactor/p
...
v0.70.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7eba5dafd8 | ||
|
|
28fe26637b | ||
|
|
407e9d304b | ||
|
|
e5474e199f | ||
|
|
db44848e2d | ||
|
|
9417ce3b3a | ||
|
|
8fc4265995 | ||
|
|
9c50819f20 | ||
|
|
6f0eff3ba0 | ||
|
|
f8745723fc | ||
|
|
154b81645a | ||
|
|
34167c8a16 | ||
|
|
d6f08e4840 | ||
|
|
f732b01a05 | ||
|
|
c07c726ea7 |
158
.github/workflows/release.yml
vendored
158
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.2"
|
||||
SIGN_PIPE_VER: "v0.1.4"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -115,6 +115,12 @@ jobs:
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest-m
|
||||
outputs:
|
||||
release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
|
||||
linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
|
||||
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
|
||||
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
|
||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
@@ -213,10 +219,13 @@ jobs:
|
||||
if: always()
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: Tag and push images (amd64 only)
|
||||
id: tag_and_push_images
|
||||
if: |
|
||||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
||||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
resolve_tags() {
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||
echo "pr-${{ github.event.pull_request.number }}"
|
||||
@@ -225,6 +234,17 @@ jobs:
|
||||
fi
|
||||
}
|
||||
|
||||
ghcr_package_url() {
|
||||
local image="$1" package encoded_package
|
||||
package="${image#ghcr.io/}"
|
||||
package="${package#*/}"
|
||||
package="${package%%:*}"
|
||||
encoded_package="${package//\//%2F}"
|
||||
echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
|
||||
}
|
||||
|
||||
image_refs=()
|
||||
|
||||
tag_and_push() {
|
||||
local src="$1" img_name tag dst
|
||||
img_name="${src%%:*}"
|
||||
@@ -233,35 +253,56 @@ jobs:
|
||||
echo "Tagging ${src} -> ${dst}"
|
||||
docker tag "$src" "$dst"
|
||||
docker push "$dst"
|
||||
image_refs+=("$dst")
|
||||
done
|
||||
}
|
||||
|
||||
export -f tag_and_push resolve_tags
|
||||
cat > /tmp/goreleaser-artifacts.json <<'JSON'
|
||||
${{ steps.goreleaser.outputs.artifacts }}
|
||||
JSON
|
||||
|
||||
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||
grep '^ghcr.io/' | while read -r SRC; do
|
||||
tag_and_push "$SRC"
|
||||
done
|
||||
mapfile -t src_images < <(
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
|
||||
)
|
||||
|
||||
for src in "${src_images[@]}"; do
|
||||
tag_and_push "$src"
|
||||
done
|
||||
|
||||
{
|
||||
echo "images_markdown<<EOF"
|
||||
if [[ ${#image_refs[@]} -eq 0 ]]; then
|
||||
echo "_No GHCR images were pushed._"
|
||||
else
|
||||
printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
|
||||
printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
|
||||
done
|
||||
fi
|
||||
echo "EOF"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 7
|
||||
- name: upload linux packages
|
||||
id: upload_linux_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 7
|
||||
- name: upload windows packages
|
||||
id: upload_windows_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 7
|
||||
- name: upload macos packages
|
||||
id: upload_macos_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: macos-packages
|
||||
@@ -270,6 +311,8 @@ jobs:
|
||||
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
@@ -360,6 +403,7 @@ jobs:
|
||||
if: always()
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-ui
|
||||
@@ -368,6 +412,8 @@ jobs:
|
||||
|
||||
release_ui_darwin:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
@@ -402,12 +448,110 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui_darwin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
|
||||
comment_release_artifacts:
|
||||
name: Comment release artifacts
|
||||
runs-on: ubuntu-latest
|
||||
needs: [release, release_ui, release_ui_darwin]
|
||||
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Create or update PR comment
|
||||
uses: actions/github-script@v7
|
||||
env:
|
||||
RELEASE_RESULT: ${{ needs.release.result }}
|
||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
||||
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
|
||||
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
|
||||
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
|
||||
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
|
||||
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
|
||||
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
|
||||
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
|
||||
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const marker = '<!-- netbird-release-artifacts -->';
|
||||
const { owner, repo } = context.repo;
|
||||
const issue_number = context.payload.pull_request.number;
|
||||
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
|
||||
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
|
||||
|
||||
const artifactCell = (url, result) => {
|
||||
if (url) return `[Download](${url})`;
|
||||
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
|
||||
};
|
||||
|
||||
const artifacts = [
|
||||
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
||||
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
||||
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
||||
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
||||
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
|
||||
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
|
||||
];
|
||||
|
||||
const artifactRows = artifacts
|
||||
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
|
||||
.join('\n');
|
||||
|
||||
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
|
||||
|
||||
const body = [
|
||||
marker,
|
||||
'## Release artifacts',
|
||||
'',
|
||||
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
|
||||
'',
|
||||
'| Artifact | Link |',
|
||||
'| --- | --- |',
|
||||
artifactRows,
|
||||
'',
|
||||
'### GHCR images (amd64)',
|
||||
ghcrImages,
|
||||
'',
|
||||
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
|
||||
].join('\n');
|
||||
|
||||
const comments = await github.paginate(github.rest.issues.listComments, {
|
||||
owner,
|
||||
repo,
|
||||
issue_number,
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
const previous = comments.find(comment =>
|
||||
comment.user?.type === 'Bot' && comment.body?.includes(marker)
|
||||
);
|
||||
|
||||
if (previous) {
|
||||
await github.rest.issues.updateComment({
|
||||
owner,
|
||||
repo,
|
||||
comment_id: previous.id,
|
||||
body,
|
||||
});
|
||||
core.info(`Updated release artifacts comment ${previous.id}`);
|
||||
} else {
|
||||
const { data } = await github.rest.issues.createComment({
|
||||
owner,
|
||||
repo,
|
||||
issue_number,
|
||||
body,
|
||||
});
|
||||
core.info(`Created release artifacts comment ${data.id}`);
|
||||
}
|
||||
|
||||
trigger_signer:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [release, release_ui, release_ui_darwin]
|
||||
|
||||
@@ -13,8 +13,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
@@ -31,6 +29,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -98,7 +97,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersmanager := peers.NewManager(store)
|
||||
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||
|
||||
@@ -201,7 +201,16 @@ Pop $0
|
||||
|
||||
Function .onInit
|
||||
StrCpy $INSTDIR "${INSTALL_DIR}"
|
||||
|
||||
; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live
|
||||
; in the 32-bit view. Fall back to it so upgrades still find them.
|
||||
SetRegView 64
|
||||
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
||||
${If} $R0 == ""
|
||||
SetRegView 32
|
||||
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
||||
SetRegView 64
|
||||
${EndIf}
|
||||
${If} $R0 != ""
|
||||
# if silent install jump to uninstall step
|
||||
IfSilent uninstall
|
||||
@@ -214,6 +223,10 @@ ${If} $R0 != ""
|
||||
|
||||
${EndIf}
|
||||
FunctionEnd
|
||||
|
||||
Function un.onInit
|
||||
SetRegView 64
|
||||
FunctionEnd
|
||||
######################################################################
|
||||
Section -MainProgram
|
||||
${INSTALL_TYPE}
|
||||
@@ -228,6 +241,7 @@ Section -MainProgram
|
||||
!else
|
||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||
!endif
|
||||
File "..\\client\\ui\\assets\\netbird.png"
|
||||
SectionEnd
|
||||
######################################################################
|
||||
|
||||
@@ -247,9 +261,11 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
; Create autostart registry entry based on checkbox
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
|
||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
@@ -283,6 +299,8 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart registry entry
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
@@ -321,6 +339,7 @@ DetailPrint "Removing registry keys..."
|
||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}"
|
||||
|
||||
DetailPrint "Removing application directory from PATH..."
|
||||
EnVar::SetHKLM
|
||||
|
||||
@@ -333,6 +333,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.statusRecorder.MarkSignalConnected()
|
||||
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
||||
relayURLs = override
|
||||
}
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
|
||||
@@ -944,7 +944,12 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
|
||||
e.relayManager.UpdateServerURLs(update.Urls)
|
||||
urls := update.Urls
|
||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
||||
urls = override
|
||||
}
|
||||
e.relayManager.UpdateServerURLs(urls)
|
||||
|
||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
@@ -58,6 +57,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -1632,7 +1632,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
|
||||
@@ -7,7 +7,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
|
||||
)
|
||||
|
||||
func IsForceRelayed() bool {
|
||||
@@ -16,3 +17,28 @@ func IsForceRelayed() bool {
|
||||
}
|
||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||
}
|
||||
|
||||
// OverrideRelayURLs returns the relay server URL list set in
|
||||
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
|
||||
// the override is active. When the env var is unset, the boolean is false
|
||||
// and the caller should keep the list received from the management server.
|
||||
// Intended for lab/debug scenarios where a peer must pin to a specific home
|
||||
// relay regardless of what management offers.
|
||||
func OverrideRelayURLs() ([]string, bool) {
|
||||
raw := os.Getenv(EnvKeyNBHomeRelayServers)
|
||||
if raw == "" {
|
||||
return nil, false
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
urls := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
urls = append(urls, p)
|
||||
}
|
||||
}
|
||||
if len(urls) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return urls, true
|
||||
}
|
||||
|
||||
@@ -2,217 +2,358 @@
|
||||
|
||||
package sleep
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||
#include <IOKit/IOMessage.h>
|
||||
#include <CoreFoundation/CoreFoundation.h>
|
||||
|
||||
extern void sleepCallbackBridge();
|
||||
extern void poweredOnCallbackBridge();
|
||||
extern void suspendedCallbackBridge();
|
||||
extern void resumedCallbackBridge();
|
||||
|
||||
|
||||
// C global variables for IOKit state
|
||||
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||
static io_object_t g_notifierObject = 0;
|
||||
static io_object_t g_generalInterestNotifier = 0;
|
||||
static io_connect_t g_rootPort = 0;
|
||||
static CFRunLoopRef g_runLoop = NULL;
|
||||
|
||||
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||
switch (messageType) {
|
||||
case kIOMessageSystemWillSleep:
|
||||
sleepCallbackBridge();
|
||||
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||
break;
|
||||
case kIOMessageSystemHasPoweredOn:
|
||||
poweredOnCallbackBridge();
|
||||
break;
|
||||
case kIOMessageServiceIsSuspended:
|
||||
suspendedCallbackBridge();
|
||||
break;
|
||||
case kIOMessageServiceIsResumed:
|
||||
resumedCallbackBridge();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void registerNotifications() {
|
||||
g_rootPort = IORegisterForSystemPower(
|
||||
NULL,
|
||||
&g_notifyPortRef,
|
||||
(IOServiceInterestCallback)sleepCallback,
|
||||
&g_notifierObject
|
||||
);
|
||||
|
||||
if (g_rootPort == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||
kCFRunLoopCommonModes);
|
||||
|
||||
g_runLoop = CFRunLoopGetCurrent();
|
||||
CFRunLoopRun();
|
||||
}
|
||||
|
||||
static void unregisterNotifications() {
|
||||
CFRunLoopRemoveSource(g_runLoop,
|
||||
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||
kCFRunLoopCommonModes);
|
||||
|
||||
IODeregisterForSystemPower(&g_notifierObject);
|
||||
IOServiceClose(g_rootPort);
|
||||
IONotificationPortDestroy(g_notifyPortRef);
|
||||
CFRunLoopStop(g_runLoop);
|
||||
|
||||
g_notifyPortRef = NULL;
|
||||
g_notifierObject = 0;
|
||||
g_rootPort = 0;
|
||||
g_runLoop = NULL;
|
||||
}
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
serviceRegistry = make(map[*Detector]struct{})
|
||||
serviceRegistryMu sync.Mutex
|
||||
// IOKit message types from IOKit/IOMessage.h.
|
||||
const (
|
||||
kIOMessageCanSystemSleep uintptr = 0xe0000270
|
||||
kIOMessageSystemWillSleep uintptr = 0xe0000280
|
||||
kIOMessageSystemHasPoweredOn uintptr = 0xe0000300
|
||||
)
|
||||
|
||||
//export sleepCallbackBridge
|
||||
func sleepCallbackBridge() {
|
||||
log.Info("sleepCallbackBridge event triggered")
|
||||
var (
|
||||
ioKit iokitFuncs
|
||||
cf cfFuncs
|
||||
cfCommonModes uintptr
|
||||
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
libInitOnce sync.Once
|
||||
libInitErr error
|
||||
|
||||
for svc := range serviceRegistry {
|
||||
svc.triggerCallback(EventTypeSleep)
|
||||
}
|
||||
// callbackThunk is the single C-callable trampoline registered with IOKit.
|
||||
callbackThunk uintptr
|
||||
|
||||
serviceRegistry = make(map[*Detector]struct{})
|
||||
serviceRegistryMu sync.Mutex
|
||||
session *runLoopSession
|
||||
|
||||
// lifecycleMu serializes Register/Deregister so a new registration can't
|
||||
// start a second runloop while a previous teardown is still pending.
|
||||
lifecycleMu sync.Mutex
|
||||
)
|
||||
|
||||
// iokitFuncs holds IOKit symbols resolved once at init.
|
||||
type iokitFuncs struct {
|
||||
IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr
|
||||
IODeregisterForSystemPower func(notifier *uintptr) int32
|
||||
IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32
|
||||
IOServiceClose func(connect uintptr) int32
|
||||
IONotificationPortGetRunLoopSource func(port uintptr) uintptr
|
||||
IONotificationPortDestroy func(port uintptr)
|
||||
}
|
||||
|
||||
//export resumedCallbackBridge
|
||||
func resumedCallbackBridge() {
|
||||
log.Info("resumedCallbackBridge event triggered")
|
||||
// cfFuncs holds CoreFoundation symbols resolved once at init.
|
||||
type cfFuncs struct {
|
||||
CFRunLoopGetCurrent func() uintptr
|
||||
CFRunLoopRun func()
|
||||
CFRunLoopStop func(rl uintptr)
|
||||
CFRunLoopAddSource func(rl, source, mode uintptr)
|
||||
CFRunLoopRemoveSource func(rl, source, mode uintptr)
|
||||
}
|
||||
|
||||
//export suspendedCallbackBridge
|
||||
func suspendedCallbackBridge() {
|
||||
log.Info("suspendedCallbackBridge event triggered")
|
||||
// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil
|
||||
// session means no runloop is active and the next Register must start one.
|
||||
type runLoopSession struct {
|
||||
rl uintptr
|
||||
port uintptr
|
||||
notifier uintptr
|
||||
rp uintptr
|
||||
}
|
||||
|
||||
//export poweredOnCallbackBridge
|
||||
func poweredOnCallbackBridge() {
|
||||
log.Info("poweredOnCallbackBridge event triggered")
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
for svc := range serviceRegistry {
|
||||
svc.triggerCallback(EventTypeWakeUp)
|
||||
}
|
||||
// detectorSnapshot pins a detector's callback and done channel so dispatch
|
||||
// runs with values valid at snapshot time, even if a concurrent
|
||||
// Deregister/Register rewrites the detector's fields.
|
||||
type detectorSnapshot struct {
|
||||
detector *Detector
|
||||
callback func(event EventType)
|
||||
done <-chan struct{}
|
||||
}
|
||||
|
||||
// Detector delivers sleep and wake events to a registered callback.
|
||||
type Detector struct {
|
||||
callback func(event EventType)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewDetector() (*Detector, error) {
|
||||
return &Detector{}, nil
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// Register installs callback for power events. The first registration starts
|
||||
// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit
|
||||
// registration succeeds or fails; subsequent registrations just add to the
|
||||
// dispatch set.
|
||||
func (d *Detector) Register(callback func(event EventType)) error {
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
lifecycleMu.Lock()
|
||||
defer lifecycleMu.Unlock()
|
||||
|
||||
serviceRegistryMu.Lock()
|
||||
if _, exists := serviceRegistry[d]; exists {
|
||||
serviceRegistryMu.Unlock()
|
||||
return fmt.Errorf("detector service already registered")
|
||||
}
|
||||
|
||||
d.callback = callback
|
||||
d.done = make(chan struct{})
|
||||
serviceRegistry[d] = struct{}{}
|
||||
needSetup := session == nil
|
||||
serviceRegistryMu.Unlock()
|
||||
|
||||
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||
|
||||
if len(serviceRegistry) > 0 {
|
||||
serviceRegistry[d] = struct{}{}
|
||||
if !needSetup {
|
||||
return nil
|
||||
}
|
||||
|
||||
serviceRegistry[d] = struct{}{}
|
||||
|
||||
// CFRunLoop must run on a single fixed OS thread
|
||||
go func() {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
C.registerNotifications()
|
||||
}()
|
||||
errCh := make(chan error, 1)
|
||||
go runRunLoop(errCh)
|
||||
if err := <-errCh; err != nil {
|
||||
serviceRegistryMu.Lock()
|
||||
delete(serviceRegistry, d)
|
||||
close(d.done)
|
||||
d.done = nil
|
||||
serviceRegistryMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("sleep detection service started on macOS")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||
// and the runloop is stopped and cleaned up.
|
||||
// Deregister removes the detector. When the last detector leaves, IOKit
|
||||
// notifications are torn down and the runloop is stopped.
|
||||
func (d *Detector) Deregister() error {
|
||||
lifecycleMu.Lock()
|
||||
defer lifecycleMu.Unlock()
|
||||
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
_, exists := serviceRegistry[d]
|
||||
if !exists {
|
||||
if _, exists := serviceRegistry[d]; !exists {
|
||||
serviceRegistryMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// cancel and remove this detector
|
||||
d.cancel()
|
||||
close(d.done)
|
||||
delete(serviceRegistry, d)
|
||||
|
||||
// If other Detectors still exist, leave IOKit running
|
||||
if len(serviceRegistry) > 0 {
|
||||
serviceRegistryMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
sess := session
|
||||
serviceRegistryMu.Unlock()
|
||||
|
||||
log.Info("sleep detection service stopping (deregister)")
|
||||
|
||||
// Deregister IOKit notifications, stop runloop, and free resources
|
||||
C.unregisterNotifications()
|
||||
if sess == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if sess.rl != 0 && sess.port != 0 {
|
||||
source := ioKit.IONotificationPortGetRunLoopSource(sess.port)
|
||||
cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes)
|
||||
}
|
||||
if sess.notifier != 0 {
|
||||
n := sess.notifier
|
||||
ioKit.IODeregisterForSystemPower(&n)
|
||||
}
|
||||
|
||||
// Clear session only after IODeregisterForSystemPower returns so any
|
||||
// in-flight powerCallback can still look up session.rp to ack sleep.
|
||||
serviceRegistryMu.Lock()
|
||||
session = nil
|
||||
serviceRegistryMu.Unlock()
|
||||
|
||||
if sess.rp != 0 {
|
||||
ioKit.IOServiceClose(sess.rp)
|
||||
}
|
||||
if sess.port != 0 {
|
||||
ioKit.IONotificationPortDestroy(sess.port)
|
||||
}
|
||||
if sess.rl != 0 {
|
||||
cf.CFRunLoopStop(sess.rl)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Detector) triggerCallback(event EventType) {
|
||||
doneChan := make(chan struct{})
|
||||
func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) {
|
||||
if cb == nil || done == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
timeout := time.NewTimer(500 * time.Millisecond)
|
||||
defer timeout.Stop()
|
||||
|
||||
cb := d.callback
|
||||
go func(callback func(event EventType)) {
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("panic in sleep callback: %v", r)
|
||||
}
|
||||
}()
|
||||
log.Info("sleep detection event fired")
|
||||
callback(event)
|
||||
close(doneChan)
|
||||
}(cb)
|
||||
cb(event)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-doneChan:
|
||||
case <-d.ctx.Done():
|
||||
case <-done:
|
||||
case <-timeout.C:
|
||||
log.Warnf("sleep callback timed out")
|
||||
log.Warn("sleep callback timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector.
|
||||
func NewDetector() (*Detector, error) {
|
||||
if err := initLibs(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Detector{}, nil
|
||||
}
|
||||
|
||||
func initLibs() error {
|
||||
libInitOnce.Do(func() {
|
||||
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libInitErr = fmt.Errorf("dlopen IOKit: %w", err)
|
||||
return
|
||||
}
|
||||
cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower")
|
||||
purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower")
|
||||
purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange")
|
||||
purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose")
|
||||
purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource")
|
||||
purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy")
|
||||
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent")
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun")
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop")
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource")
|
||||
purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource")
|
||||
|
||||
modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes")
|
||||
if err != nil {
|
||||
libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err)
|
||||
return
|
||||
}
|
||||
// Launder the uintptr-to-pointer conversion through a Go variable so
|
||||
// go vet's unsafeptr analyzer doesn't flag a system-library global.
|
||||
cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr))
|
||||
|
||||
// NewCallback slots are a finite, non-reclaimable resource, so register
|
||||
// a single thunk that dispatches to the current Detector set.
|
||||
callbackThunk = purego.NewCallback(powerCallback)
|
||||
})
|
||||
return libInitErr
|
||||
}
|
||||
|
||||
// powerCallback is the IOServiceInterestCallback trampoline, invoked on the
|
||||
// runloop thread. A Go panic crossing the purego boundary has undefined
|
||||
// behavior, so contain it here.
|
||||
func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("panic in sleep powerCallback: %v", r)
|
||||
}
|
||||
}()
|
||||
switch messageType {
|
||||
case kIOMessageCanSystemSleep:
|
||||
// Not acknowledging forces a 30s IOKit timeout before idle sleep.
|
||||
allowPowerChange(messageArgument)
|
||||
case kIOMessageSystemWillSleep:
|
||||
dispatchEvent(EventTypeSleep)
|
||||
allowPowerChange(messageArgument)
|
||||
case kIOMessageSystemHasPoweredOn:
|
||||
dispatchEvent(EventTypeWakeUp)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func allowPowerChange(messageArgument uintptr) {
|
||||
serviceRegistryMu.Lock()
|
||||
var port uintptr
|
||||
if session != nil {
|
||||
port = session.rp
|
||||
}
|
||||
serviceRegistryMu.Unlock()
|
||||
if port != 0 {
|
||||
ioKit.IOAllowPowerChange(port, messageArgument)
|
||||
}
|
||||
}
|
||||
|
||||
func dispatchEvent(event EventType) {
|
||||
serviceRegistryMu.Lock()
|
||||
snaps := make([]detectorSnapshot, 0, len(serviceRegistry))
|
||||
for d := range serviceRegistry {
|
||||
snaps = append(snaps, detectorSnapshot{
|
||||
detector: d,
|
||||
callback: d.callback,
|
||||
done: d.done,
|
||||
})
|
||||
}
|
||||
serviceRegistryMu.Unlock()
|
||||
|
||||
for _, s := range snaps {
|
||||
s.detector.triggerCallback(event, s.callback, s.done)
|
||||
}
|
||||
}
|
||||
|
||||
// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup
|
||||
// result is reported on errCh so Register can surface failures synchronously.
|
||||
func runRunLoop(errCh chan<- error) {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
sess, err := setupSession()
|
||||
if err == nil {
|
||||
serviceRegistryMu.Lock()
|
||||
session = sess
|
||||
serviceRegistryMu.Unlock()
|
||||
}
|
||||
errCh <- err
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("panic in sleep runloop: %v", r)
|
||||
}
|
||||
}()
|
||||
cf.CFRunLoopRun()
|
||||
}
|
||||
|
||||
// setupSession performs the IOKit registration on the current thread. Panics
|
||||
// are converted to errors so runRunLoop never leaves errCh unsent.
|
||||
func setupSession() (s *runLoopSession, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic during runloop setup: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
var portRef, notifier uintptr
|
||||
rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, ¬ifier)
|
||||
if rp == 0 {
|
||||
return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
|
||||
}
|
||||
|
||||
rl := cf.CFRunLoopGetCurrent()
|
||||
source := ioKit.IONotificationPortGetRunLoopSource(portRef)
|
||||
cf.CFRunLoopAddSource(rl, source, cfCommonModes)
|
||||
|
||||
return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
|
||||
}
|
||||
|
||||
@@ -18,10 +18,17 @@
|
||||
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
|
||||
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
|
||||
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
|
||||
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
||||
</Shortcut>
|
||||
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
|
||||
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
|
||||
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
||||
</Shortcut>
|
||||
</File>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
||||
<File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
|
||||
<?if $(var.ArchSuffix) = "amd64" ?>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
||||
<?endif ?>
|
||||
@@ -46,8 +53,19 @@
|
||||
</Directory>
|
||||
</StandardDirectory>
|
||||
|
||||
<!-- Per-user component: HKCU keypath (auto GUID via "*"), separate from
|
||||
the per-machine NetbirdFiles component to satisfy ICE57. -->
|
||||
<StandardDirectory Id="ProgramMenuFolder">
|
||||
<Component Id="NetbirdAumidRegistry" Guid="*">
|
||||
<RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
|
||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
||||
</RegistryKey>
|
||||
</Component>
|
||||
</StandardDirectory>
|
||||
|
||||
<ComponentGroup Id="NetbirdFilesComponent">
|
||||
<ComponentRef Id="NetbirdFiles" />
|
||||
<ComponentRef Id="NetbirdAumidRegistry" />
|
||||
</ComponentGroup>
|
||||
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -104,8 +104,6 @@ service DaemonService {
|
||||
// StopCPUProfile stops CPU profiling in the daemon
|
||||
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
|
||||
|
||||
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||
|
||||
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
||||
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
@@ -114,20 +112,6 @@ service DaemonService {
|
||||
|
||||
|
||||
|
||||
message OSLifecycleRequest {
|
||||
// avoid collision with loglevel enum
|
||||
enum CycleType {
|
||||
UNKNOWN = 0;
|
||||
SLEEP = 1;
|
||||
WAKEUP = 2;
|
||||
}
|
||||
|
||||
CycleType type = 1;
|
||||
}
|
||||
|
||||
message OSLifecycleResponse {}
|
||||
|
||||
|
||||
message LoginRequest {
|
||||
// setupKey netbird setup key.
|
||||
string setupKey = 1;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -120,6 +120,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
}
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
s.startSleepDetector()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -15,8 +15,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
@@ -40,6 +38,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -306,7 +305,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersManager := peers.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
@@ -2,13 +2,18 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/sleep"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const envDisableSleepDetector = "NB_DISABLE_SLEEP_DETECTOR"
|
||||
|
||||
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
||||
type serverAgent struct {
|
||||
s *Server
|
||||
@@ -28,19 +33,61 @@ func (a *serverAgent) Status() (internal.StatusType, error) {
|
||||
return internal.CtxGetState(a.s.rootCtx).Status()
|
||||
}
|
||||
|
||||
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||
switch req.GetType() {
|
||||
case proto.OSLifecycleRequest_WAKEUP:
|
||||
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
case proto.OSLifecycleRequest_SLEEP:
|
||||
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
default:
|
||||
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||
// startSleepDetector starts the OS sleep/wake detector and forwards events to
|
||||
// the sleep handler. On platforms without a supported detector the attempt
|
||||
// logs a warning and returns. Setting NB_DISABLE_SLEEP_DETECTOR=true skips
|
||||
// registration entirely.
|
||||
func (s *Server) startSleepDetector() {
|
||||
if sleepDetectorDisabled() {
|
||||
log.Info("sleep detection disabled via " + envDisableSleepDetector)
|
||||
return
|
||||
}
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
|
||||
svc, err := sleep.New()
|
||||
if err != nil {
|
||||
log.Warnf("failed to initialize sleep detection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = svc.Register(func(event sleep.EventType) {
|
||||
switch event {
|
||||
case sleep.EventTypeSleep:
|
||||
log.Info("handling sleep event")
|
||||
if err := s.sleepHandler.HandleSleep(s.rootCtx); err != nil {
|
||||
log.Errorf("failed to handle sleep event: %v", err)
|
||||
}
|
||||
case sleep.EventTypeWakeUp:
|
||||
log.Info("handling wakeup event")
|
||||
if err := s.sleepHandler.HandleWakeUp(s.rootCtx); err != nil {
|
||||
log.Errorf("failed to handle wakeup event: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed to register sleep detector: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("sleep detection service initialized")
|
||||
|
||||
go func() {
|
||||
<-s.rootCtx.Done()
|
||||
log.Info("stopping sleep event listener")
|
||||
if err := svc.Deregister(); err != nil {
|
||||
log.Errorf("failed to deregister sleep detector: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func sleepDetectorDisabled() bool {
|
||||
val := os.Getenv(envDisableSleepDetector)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s=%q: %v", envDisableSleepDetector, val, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
@@ -38,10 +38,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/sleep"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||
"github.com/netbirdio/netbird/client/ui/event"
|
||||
"github.com/netbirdio/netbird/client/ui/notifier"
|
||||
"github.com/netbirdio/netbird/client/ui/process"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -260,6 +260,7 @@ type serviceClient struct {
|
||||
|
||||
// application with main windows.
|
||||
app fyne.App
|
||||
notifier notifier.Notifier
|
||||
wSettings fyne.Window
|
||||
showAdvancedSettings bool
|
||||
sendNotification bool
|
||||
@@ -364,6 +365,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
cancel: cancel,
|
||||
addr: args.addr,
|
||||
app: args.app,
|
||||
notifier: notifier.New(args.app),
|
||||
logFile: args.logFile,
|
||||
sendNotification: false,
|
||||
|
||||
@@ -892,7 +894,7 @@ func (s *serviceClient) updateStatus() error {
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
if s.connected {
|
||||
s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost"))
|
||||
s.notifier.Send("Error", "Connection to service lost")
|
||||
}
|
||||
s.setDisconnectedStatus()
|
||||
return err
|
||||
@@ -1109,7 +1111,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
}
|
||||
}()
|
||||
|
||||
s.eventManager = event.NewManager(s.app, s.addr)
|
||||
s.eventManager = event.NewManager(s.notifier, s.addr)
|
||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||
if event.Category == proto.SystemEvent_SYSTEM {
|
||||
@@ -1146,9 +1148,6 @@ func (s *serviceClient) onTrayReady() {
|
||||
|
||||
go s.eventManager.Start(s.ctx)
|
||||
go s.eventHandler.listen(s.ctx)
|
||||
|
||||
// Start sleep detection listener
|
||||
go s.startSleepListener()
|
||||
}
|
||||
|
||||
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
||||
@@ -1209,62 +1208,6 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
|
||||
return s.conn, nil
|
||||
}
|
||||
|
||||
// startSleepListener initializes the sleep detection service and listens for sleep events
|
||||
func (s *serviceClient) startSleepListener() {
|
||||
sleepService, err := sleep.New()
|
||||
if err != nil {
|
||||
log.Warnf("%v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := sleepService.Register(s.handleSleepEvents); err != nil {
|
||||
log.Errorf("failed to start sleep detection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("sleep detection service initialized")
|
||||
|
||||
// Cleanup on context cancellation
|
||||
go func() {
|
||||
<-s.ctx.Done()
|
||||
log.Info("stopping sleep event listener")
|
||||
if err := sleepService.Deregister(); err != nil {
|
||||
log.Errorf("failed to deregister sleep detection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleSleepEvents sends a sleep notification to the daemon via gRPC
|
||||
func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
|
||||
conn, err := s.getSrvClient(0)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get daemon client for sleep notification: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.OSLifecycleRequest{}
|
||||
|
||||
switch event {
|
||||
case sleep.EventTypeWakeUp:
|
||||
log.Infof("handle wakeup event: %v", event)
|
||||
req.Type = proto.OSLifecycleRequest_WAKEUP
|
||||
case sleep.EventTypeSleep:
|
||||
log.Infof("handle sleep event: %v", event)
|
||||
req.Type = proto.OSLifecycleRequest_SLEEP
|
||||
default:
|
||||
log.Infof("unknown event: %v", event)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.NotifyOSLifecycle(s.ctx, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("successfully notified daemon about os lifecycle")
|
||||
}
|
||||
|
||||
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
||||
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
||||
if s.mSettings != nil {
|
||||
@@ -1548,7 +1491,7 @@ func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) {
|
||||
|
||||
if enforced && s.lastNotifiedVersion != newVersion {
|
||||
s.lastNotifiedVersion = newVersion
|
||||
s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install"))
|
||||
s.notifier.Send("Update available", "A new version "+newVersion+" is ready to install")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
@@ -18,11 +17,17 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||
)
|
||||
|
||||
// Notifier sends desktop notifications. Defined here so the event package
|
||||
// does not depend on fyne or the platform-specific notifier implementation.
|
||||
type Notifier interface {
|
||||
Send(title, body string)
|
||||
}
|
||||
|
||||
type Handler func(*proto.SystemEvent)
|
||||
|
||||
type Manager struct {
|
||||
app fyne.App
|
||||
addr string
|
||||
notifier Notifier
|
||||
addr string
|
||||
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
@@ -31,10 +36,10 @@ type Manager struct {
|
||||
handlers []Handler
|
||||
}
|
||||
|
||||
func NewManager(app fyne.App, addr string) *Manager {
|
||||
func NewManager(notifier Notifier, addr string) *Manager {
|
||||
return &Manager{
|
||||
app: app,
|
||||
addr: addr,
|
||||
notifier: notifier,
|
||||
addr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,7 +119,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
|
||||
if id != "" {
|
||||
body += fmt.Sprintf(" ID: %s", id)
|
||||
}
|
||||
e.app.SendNotification(fyne.NewNotification(title, body))
|
||||
e.notifier.Send(title, body)
|
||||
}
|
||||
|
||||
for _, handler := range handlers {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/systray"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -87,7 +86,7 @@ func (h *eventHandler) handleConnectClick() {
|
||||
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
||||
log.Debugf("connect operation cancelled by user")
|
||||
} else {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
|
||||
h.client.notifier.Send("Error", "Failed to connect")
|
||||
log.Errorf("connect failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -112,7 +111,7 @@ func (h *eventHandler) handleDisconnectClick() {
|
||||
if err := h.client.menuDownClick(); err != nil {
|
||||
st, ok := status.FromError(err)
|
||||
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
|
||||
h.client.notifier.Send("Error", "Failed to disconnect")
|
||||
log.Errorf("disconnect failed: %v", err)
|
||||
} else {
|
||||
log.Debugf("disconnect cancelled or already disconnecting")
|
||||
@@ -130,7 +129,7 @@ func (h *eventHandler) handleAllowSSHClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update SSH settings"))
|
||||
h.client.notifier.Send("Error", "Failed to update SSH settings")
|
||||
}
|
||||
|
||||
}
|
||||
@@ -140,7 +139,7 @@ func (h *eventHandler) handleAutoConnectClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update auto-connect settings"))
|
||||
h.client.notifier.Send("Error", "Failed to update auto-connect settings")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,7 +148,7 @@ func (h *eventHandler) handleRosenpassClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update Rosenpass settings"))
|
||||
h.client.notifier.Send("Error", "Failed to update Rosenpass settings")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +157,7 @@ func (h *eventHandler) handleLazyConnectionClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update lazy connection settings"))
|
||||
h.client.notifier.Send("Error", "Failed to update lazy connection settings")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,7 +166,7 @@ func (h *eventHandler) handleBlockInboundClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update block inbound settings"))
|
||||
h.client.notifier.Send("Error", "Failed to update block inbound settings")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,7 +175,7 @@ func (h *eventHandler) handleNotificationsClick() {
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update notifications settings"))
|
||||
h.client.notifier.Send("Error", "Failed to update notifications settings")
|
||||
} else if h.client.eventManager != nil {
|
||||
h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked())
|
||||
}
|
||||
|
||||
27
client/ui/notifier/notifier.go
Normal file
27
client/ui/notifier/notifier.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// Package notifier sends desktop notifications. On Windows it uses the WinRT
|
||||
// COM API directly via go-toast/v2 to avoid the PowerShell window flash that
|
||||
// fyne's default implementation produces. On other platforms it delegates to
|
||||
// fyne.
|
||||
package notifier
|
||||
|
||||
import "fyne.io/fyne/v2"
|
||||
|
||||
// Notifier sends desktop notifications.
|
||||
type Notifier interface {
|
||||
Send(title, body string)
|
||||
}
|
||||
|
||||
// New returns a platform-specific Notifier. The fyne app is used as the
|
||||
// fallback notifier on platforms where no native implementation is wired up,
|
||||
// and on Windows when the COM path fails to initialize.
|
||||
func New(app fyne.App) Notifier {
|
||||
return newNotifier(app)
|
||||
}
|
||||
|
||||
type fyneNotifier struct {
|
||||
app fyne.App
|
||||
}
|
||||
|
||||
func (f *fyneNotifier) Send(title, body string) {
|
||||
f.app.SendNotification(fyne.NewNotification(title, body))
|
||||
}
|
||||
9
client/ui/notifier/notifier_other.go
Normal file
9
client/ui/notifier/notifier_other.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !windows
|
||||
|
||||
package notifier
|
||||
|
||||
import "fyne.io/fyne/v2"
|
||||
|
||||
func newNotifier(app fyne.App) Notifier {
|
||||
return &fyneNotifier{app: app}
|
||||
}
|
||||
88
client/ui/notifier/notifier_windows.go
Normal file
88
client/ui/notifier/notifier_windows.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
toast "git.sr.ht/~jackmordaunt/go-toast/v2"
|
||||
"git.sr.ht/~jackmordaunt/go-toast/v2/wintoast"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// appID is the AppUserModelID shown in the Windows Action Center. It
|
||||
// must match the System.AppUserModel.ID property set on the Start Menu
|
||||
// shortcut by the MSI (see client/netbird.wxs); otherwise Windows
|
||||
// groups toasts under a separate, unbranded entry.
|
||||
appID = "NetBird"
|
||||
|
||||
// appGUID identifies the COM activation callback class. Generated once
|
||||
// for NetBird; do not change without coordinating an installer bump,
|
||||
// since old registry entries pointing at the previous GUID would orphan.
|
||||
appGUID = "{0E1B4DE7-E148-432B-9814-544F941826EC}"
|
||||
)
|
||||
|
||||
type comNotifier struct {
|
||||
fallback *fyneNotifier
|
||||
ready bool
|
||||
iconPath string
|
||||
}
|
||||
|
||||
var (
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
)
|
||||
|
||||
func newNotifier(app fyne.App) Notifier {
|
||||
n := &comNotifier{
|
||||
fallback: &fyneNotifier{app: app},
|
||||
iconPath: resolveIcon(),
|
||||
}
|
||||
initOnce.Do(func() {
|
||||
initErr = wintoast.SetAppData(wintoast.AppData{
|
||||
AppID: appID,
|
||||
GUID: appGUID,
|
||||
IconPath: n.iconPath,
|
||||
})
|
||||
})
|
||||
if initErr != nil {
|
||||
log.Warnf("toast: register app data failed, falling back to fyne notifications: %v", initErr)
|
||||
return n.fallback
|
||||
}
|
||||
n.ready = true
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *comNotifier) Send(title, body string) {
|
||||
if !n.ready {
|
||||
n.fallback.Send(title, body)
|
||||
return
|
||||
}
|
||||
notification := toast.Notification{
|
||||
AppID: appID,
|
||||
Title: title,
|
||||
Body: body,
|
||||
Icon: n.iconPath,
|
||||
}
|
||||
if err := notification.Push(); err != nil {
|
||||
log.Warnf("toast: push failed, using fyne fallback: %v", err)
|
||||
n.fallback.Send(title, body)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveIcon returns an absolute path to the toast icon, or an empty string
|
||||
// when no icon can be located. Windows requires a PNG/JPG for the
|
||||
// AppUserModelId IconUri registry value; .ico is silently ignored.
|
||||
func resolveIcon() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
candidate := filepath.Join(filepath.Dir(exe), "netbird.png")
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -548,7 +548,7 @@ func (p *profileMenu) refresh() {
|
||||
if err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
// show notification dialog
|
||||
p.app.SendNotification(fyne.NewNotification("Error", "Failed to switch profile"))
|
||||
p.serviceClient.notifier.Send("Error", "Failed to switch profile")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -628,9 +628,9 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
if err := p.eventHandler.logout(p.ctx); err != nil {
|
||||
log.Errorf("logout failed: %v", err)
|
||||
p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
|
||||
p.serviceClient.notifier.Send("Error", "Failed to deregister")
|
||||
} else {
|
||||
p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
|
||||
p.serviceClient.notifier.Send("Success", "Deregistered successfully")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
5
go.mod
5
go.mod
@@ -30,6 +30,7 @@ require (
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.7.0
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.38.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
||||
@@ -46,6 +47,7 @@ require (
|
||||
github.com/crowdsecurity/go-cs-bouncer v0.0.21
|
||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||
github.com/dexidp/dex/api/v2 v2.4.0
|
||||
github.com/ebitengine/purego v0.8.4
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||
@@ -71,7 +73,7 @@ require (
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/oapi-codegen/runtime v1.1.2
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
@@ -178,7 +180,6 @@ require (
|
||||
github.com/docker/docker v28.0.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fredbi/uri v1.1.1 // indirect
|
||||
github.com/fyne-io/gl-js v0.2.0 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -15,6 +15,8 @@ fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
|
||||
fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4=
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 h1:N3IGoHHp9pb6mj1cbXbuaSXV/UMKwmbKLf53nQmtqMA=
|
||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3/go.mod h1:QtOLZGz8olr4qH2vWK0QH0w0O4T9fEIjMuWpKUsH7nc=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
|
||||
@@ -453,8 +455,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34 h1:g74mB64wnjCagzE1spKgPfTI/ont1SdSL3uX5bOecgM=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34/go.mod h1:lCOq5d1i19AQjEEW2d7aNK0Nn0KC0MKyfMz/PLwVBFg=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
|
||||
@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
|
||||
// are stored with types that Dex can open.
|
||||
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
|
||||
switch connType {
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
||||
return "oidc", applyOIDCDefaults(connType, config)
|
||||
default:
|
||||
return connType, config
|
||||
@@ -218,6 +218,8 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
|
||||
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
|
||||
case "okta", "pocketid":
|
||||
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "adfs":
|
||||
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
||||
}
|
||||
|
||||
return augmented
|
||||
|
||||
@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
|
||||
var err error
|
||||
|
||||
switch cfg.Type {
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
||||
dexType = "oidc"
|
||||
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
|
||||
case "google":
|
||||
@@ -220,6 +220,8 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "adfs":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
||||
}
|
||||
return encodeConnectorConfig(oidcConfig)
|
||||
}
|
||||
@@ -283,7 +285,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
|
||||
// inferOIDCProviderType infers the specific OIDC provider from connector ID
|
||||
func inferOIDCProviderType(connectorID string) string {
|
||||
connectorIDLower := strings.ToLower(connectorID)
|
||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
|
||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} {
|
||||
if strings.Contains(connectorIDLower, provider) {
|
||||
return provider
|
||||
}
|
||||
|
||||
@@ -231,7 +231,20 @@ get_upstream_host() {
|
||||
|
||||
wait_management_proxy() {
|
||||
local proxy_container="${1:-traefik}"
|
||||
local use_docker_logs=false
|
||||
set +e
|
||||
|
||||
if [[ "$proxy_container" == "detect-traefik" ]]; then
|
||||
proxy_container=$(docker ps --format "{{.ID}}\t{{.Image}}\t{{.Ports}}" \
|
||||
| awk -F'\t' '$2 ~ /traefik/ && $3 ~ /:(80|443)->/ {print $1; exit}')
|
||||
|
||||
if [[ -z "$proxy_container" ]]; then
|
||||
echo "Warning: could not auto-detect Traefik container, log output will be skipped on timeout." > /dev/stderr
|
||||
else
|
||||
use_docker_logs=true
|
||||
fi
|
||||
fi
|
||||
|
||||
echo -n "Waiting for NetBird server to become ready"
|
||||
counter=1
|
||||
while true; do
|
||||
@@ -242,7 +255,13 @@ wait_management_proxy() {
|
||||
if [[ $counter -eq 60 ]]; then
|
||||
echo ""
|
||||
echo "Taking too long. Checking logs..."
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
|
||||
if [[ -n "$proxy_container" ]]; then
|
||||
if [[ "$use_docker_logs" == "true" ]]; then
|
||||
docker logs --tail=20 "$proxy_container"
|
||||
else
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
|
||||
fi
|
||||
fi
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server
|
||||
fi
|
||||
echo -n " ."
|
||||
@@ -518,7 +537,7 @@ start_services_and_show_instructions() {
|
||||
$DOCKER_COMPOSE_COMMAND up -d
|
||||
|
||||
sleep 3
|
||||
wait_management_direct
|
||||
wait_management_proxy detect-traefik
|
||||
|
||||
echo -e "$MSG_DONE"
|
||||
print_post_setup_instructions
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -16,11 +15,9 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -58,13 +55,6 @@ type Controller struct {
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
holder *types.Holder
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
|
||||
compactedNetworkMap bool
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||
}
|
||||
|
||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
compactedNetworkMap := true
|
||||
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
|
||||
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
|
||||
if err != nil && len(compactedEnv) > 0 {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
|
||||
}
|
||||
if err == nil && !parsedCompactedNmap {
|
||||
log.WithContext(ctx).Info("disabling compacted mode")
|
||||
compactedNetworkMap = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
expIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
return &Controller{
|
||||
repo: newRepository(store),
|
||||
metrics: nMetrics,
|
||||
@@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
|
||||
holder: types.NewHolder(),
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
|
||||
compactedNetworkMap: compactedNetworkMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int {
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
@@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
@@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountID):
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
@@ -318,10 +258,6 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return fmt.Errorf("recalculate network map cache: %v", err)
|
||||
}
|
||||
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -371,16 +307,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountId):
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -451,17 +378,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return peer, emptyMap, nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
@@ -493,20 +412,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -518,108 +427,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||
}
|
||||
|
||||
func (c *Controller) getPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
Network: &types.Network{},
|
||||
}
|
||||
}
|
||||
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
c.enrichAccountFromHolder(account)
|
||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||
account := c.getAccountFromHolder(accountId)
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
account.UpdatePeerInNetworkMapCache(peer)
|
||||
}
|
||||
|
||||
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
account.RecalculateNetworkMapCache(validatedPeers)
|
||||
c.updateAccountInHolder(account)
|
||||
}
|
||||
|
||||
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||
return err
|
||||
}
|
||||
c.recalculateNetworkMapCache(account, validatedPeers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
||||
_, ok := c.expNewNetworkMapAIDs[accountId]
|
||||
return c.expNewNetworkMap || ok
|
||||
}
|
||||
|
||||
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
||||
a := c.holder.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
c.holder.AddAccount(account)
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache = a.NetworkMapCache
|
||||
if account.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
||||
return c.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
|
||||
a := c.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
}
|
||||
|
||||
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
||||
if settings == nil {
|
||||
@@ -756,16 +563,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peers by ids: %w", err)
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
||||
}
|
||||
|
||||
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
}
|
||||
@@ -775,14 +573,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
|
||||
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
|
||||
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
|
||||
}
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -817,19 +607,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||
continue
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
@@ -872,21 +649,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||
} else {
|
||||
account.InjectProxyPolicies(ctx)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
account.InjectProxyPolicies(ctx)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
|
||||
@@ -12,9 +12,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||
|
||||
DnsForwarderPort = nbdns.ForwarderServerPort
|
||||
OldForwarderPort = nbdns.ForwarderClientPort
|
||||
DnsForwarderPortMinVersion = "v0.59.0"
|
||||
|
||||
@@ -16,6 +16,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -35,15 +38,17 @@ type Manager interface {
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
accountManager account.Manager
|
||||
|
||||
networkMapController network_map.Controller
|
||||
}
|
||||
|
||||
func NewManager(store store.Store) Manager {
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,10 +65,28 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
|
||||
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
}
|
||||
|
||||
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
}
|
||||
|
||||
|
||||
@@ -5,11 +5,8 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -18,15 +15,21 @@ type handler struct {
|
||||
manager accesslogs.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var filter accesslogs.AccessLogFilter
|
||||
filter.ParseFromRequest(r)
|
||||
|
||||
|
||||
@@ -9,19 +9,25 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
geo geolocation.Geolocation
|
||||
cleanupCancel context.CancelFunc
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
geo geolocation.Geolocation
|
||||
cleanupCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, geo geolocation.Geolocation) accesslogs.Manager {
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
geo: geo,
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
geo: geo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +63,14 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
|
||||
|
||||
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
|
||||
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, 0, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, 0, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
|
||||
}
|
||||
|
||||
@@ -6,11 +6,8 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -20,15 +17,15 @@ type handler struct {
|
||||
manager Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Create, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS") // TODO: this should be a POST
|
||||
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
|
||||
@@ -59,7 +56,13 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||
return resp
|
||||
}
|
||||
|
||||
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -74,7 +77,13 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteJSONObject(r.Context(), w, ret)
|
||||
}
|
||||
|
||||
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiReverseProxiesDomainsJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -90,7 +99,13 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, use
|
||||
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
|
||||
}
|
||||
|
||||
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
domainID := mux.Vars(r)["domainId"]
|
||||
if domainID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||
@@ -105,7 +120,13 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, use
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
domainID := mux.Vars(r)["domainId"]
|
||||
if domainID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||
|
||||
@@ -11,7 +11,11 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type store interface {
|
||||
@@ -33,22 +37,32 @@ type proxyManager interface {
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyManager proxyManager
|
||||
accountManager account.Manager
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyManager proxyManager
|
||||
permissionsManager permissions.Manager
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store, proxyMgr proxyManager, accountManager account.Manager) Manager {
|
||||
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
|
||||
return Manager{
|
||||
store: store,
|
||||
proxyManager: proxyMgr,
|
||||
validator: domain.Validator{Resolver: net.DefaultResolver},
|
||||
accountManager: accountManager,
|
||||
store: store,
|
||||
proxyManager: proxyMgr,
|
||||
validator: domain.Validator{Resolver: net.DefaultResolver},
|
||||
permissionsManager: permissionsManager,
|
||||
accountManager: accountManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
domains, err := m.store.ListCustomDomains(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list custom domains: %w", err)
|
||||
@@ -104,6 +118,14 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
}
|
||||
|
||||
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
@@ -137,6 +159,14 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
}
|
||||
|
||||
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get domain from store: %w", err)
|
||||
@@ -153,6 +183,21 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
|
||||
}
|
||||
|
||||
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"domainID": domainID,
|
||||
}).WithError(err).Error("validate domain")
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
log.WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"domainID": domainID,
|
||||
}).WithError(err).Error("validate domain")
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"domainID": domainID,
|
||||
|
||||
@@ -23,7 +23,7 @@ type Proxy struct {
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Capabilities Capabilities `gorm:"embedded"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -6,14 +6,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -32,19 +30,25 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
|
||||
}
|
||||
|
||||
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
|
||||
domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager)
|
||||
domainmanager.RegisterEndpoints(domainRouter, domainManager)
|
||||
|
||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager)
|
||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
||||
|
||||
router.HandleFunc("/reverse-proxies/clusters", permissionsManager.WithPermission(modules.Services, operations.Read, h.getClusters)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -59,7 +63,13 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAut
|
||||
util.WriteJSONObject(r.Context(), w, apiServices)
|
||||
}
|
||||
|
||||
func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.ServiceRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -67,13 +77,12 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth
|
||||
}
|
||||
|
||||
service := new(rpservice.Service)
|
||||
var err error
|
||||
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.Validate(); err != nil {
|
||||
if err = service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -87,7 +96,13 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
serviceID := mux.Vars(r)["serviceId"]
|
||||
if serviceID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||
@@ -103,7 +118,13 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
serviceID := mux.Vars(r)["serviceId"]
|
||||
if serviceID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||
@@ -118,13 +139,12 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth
|
||||
|
||||
service := new(rpservice.Service)
|
||||
service.ID = serviceID
|
||||
var err error
|
||||
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.Validate(); err != nil {
|
||||
if err = service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -138,7 +158,13 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
serviceID := mux.Vars(r)["serviceId"]
|
||||
if serviceID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||
@@ -153,7 +179,13 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
@@ -85,17 +86,18 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
|
||||
mgr := &Manager{
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
proxyController: mockCtrl,
|
||||
capabilities: mockCaps,
|
||||
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
permissionsManager: permissions.NewManager(testStore),
|
||||
proxyController: mockCtrl,
|
||||
capabilities: mockCaps,
|
||||
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
|
||||
|
||||
@@ -21,6 +21,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -79,22 +82,24 @@ type CapabilityProvider interface {
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
proxyController proxy.Controller
|
||||
capabilities CapabilityProvider
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeReaper *exposeReaper
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyController proxy.Controller
|
||||
capabilities CapabilityProvider
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeReaper *exposeReaper
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
|
||||
mgr := &Manager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
proxyController: proxyController,
|
||||
capabilities: capabilities,
|
||||
clusterDeriver: clusterDeriver,
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyController: proxyController,
|
||||
capabilities: capabilities,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
return mgr
|
||||
@@ -107,10 +112,26 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetActiveProxyClusters(ctx)
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
@@ -164,6 +185,14 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
||||
}
|
||||
|
||||
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
@@ -177,6 +206,14 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
|
||||
}
|
||||
|
||||
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -187,7 +224,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
|
||||
|
||||
err := m.replaceHostByLookup(ctx, accountID, s)
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
@@ -454,6 +491,14 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
@@ -740,8 +785,16 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes.
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var s *service.Service
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
@@ -772,8 +825,16 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var services []*service.Service
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
@@ -1058,7 +1119,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
|
||||
}
|
||||
groupIDs := make([]string, 0, len(groupNames))
|
||||
for _, groupName := range groupNames {
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -697,10 +700,12 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
|
||||
require.NoError(t, err)
|
||||
|
||||
permsMgr := permissions.NewManager(testStore)
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
@@ -713,9 +718,10 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
require.NoError(t, err)
|
||||
|
||||
mgr := &Manager{
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
proxyController: proxyController,
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
permissionsManager: permsMgr,
|
||||
proxyController: proxyController,
|
||||
clusterDeriver: &testClusterDeriver{
|
||||
domains: []string{"test.netbird.io"},
|
||||
},
|
||||
@@ -1124,6 +1130,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPerms := permissions.NewMockManager(ctrl)
|
||||
mockAcct := account.NewMockManager(ctrl)
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
@@ -1134,9 +1141,10 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
mgr := &Manager{
|
||||
store: sqlStore,
|
||||
accountManager: mockAcct,
|
||||
proxyController: proxyController,
|
||||
store: sqlStore,
|
||||
permissionsManager: mockPerms,
|
||||
accountManager: mockAcct,
|
||||
proxyController: proxyController,
|
||||
}
|
||||
|
||||
service := &rpservice.Service{
|
||||
@@ -1159,6 +1167,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
|
||||
|
||||
mockPerms.EXPECT().
|
||||
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
|
||||
Return(true, nil)
|
||||
mockAcct.EXPECT().
|
||||
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
|
||||
mockAcct.EXPECT().
|
||||
|
||||
@@ -6,11 +6,8 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -20,19 +17,25 @@ type handler struct {
|
||||
manager zones.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -47,7 +50,13 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
util.WriteJSONObject(r.Context(), w, apiZones)
|
||||
}
|
||||
|
||||
func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiDnsZonesJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -57,7 +66,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
zone := new(zones.Zone)
|
||||
zone.FromAPIRequest(&req)
|
||||
|
||||
if err := zone.Validate(); err != nil {
|
||||
if err = zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -71,7 +80,13 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -87,7 +102,13 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth
|
||||
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -95,7 +116,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
@@ -104,7 +125,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
zone.FromAPIRequest(&req)
|
||||
zone.ID = zoneID
|
||||
|
||||
if err := zone.Validate(); err != nil {
|
||||
if err = zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -118,14 +139,20 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -7,34 +7,62 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
dnsDomain string
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
dnsDomain string
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager, dnsDomain string) zones.Manager {
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
dnsDomain: dnsDomain,
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
|
||||
var err error
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -74,6 +102,14 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string,
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get zone: %w", err)
|
||||
@@ -114,6 +150,14 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
|
||||
@@ -13,6 +13,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -26,7 +29,7 @@ const (
|
||||
testDNSDomain = "netbird.selfhosted"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *gomock.Controller, func()) {
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -46,17 +49,23 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccoun
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := NewManager(testStore, mockAccountManager, testDNSDomain).(*managerImpl)
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
dnsDomain: testDNSDomain,
|
||||
}
|
||||
|
||||
return manager, testStore, mockAccountManager, ctrl, cleanup
|
||||
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -68,6 +77,10 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
err = testStore.CreateZone(ctx, zone2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
@@ -75,13 +88,43 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
assert.Equal(t, zone2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -89,6 +132,10 @@ func TestManagerImpl_GetZone(t *testing.T) {
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, zone.ID, result.ID)
|
||||
@@ -96,13 +143,29 @@ func TestManagerImpl_GetZone(t *testing.T) {
|
||||
assert.Equal(t, zone.Domain, result.Domain)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, _, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -114,6 +177,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
@@ -132,8 +199,31 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("invalid group", func(t *testing.T) {
|
||||
manager, _, _, ctrl, cleanup := setupTest(t)
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -143,13 +233,17 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{"invalid-group"},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("duplicate domain", func(t *testing.T) {
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -165,6 +259,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -175,7 +273,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("peer DNS domain conflict", func(t *testing.T) {
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -193,6 +291,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -203,7 +305,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("default DNS domain conflict", func(t *testing.T) {
|
||||
manager, _, _, ctrl, cleanup := setupTest(t)
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -215,6 +317,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -229,7 +335,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -246,6 +352,10 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
@@ -265,7 +375,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("domain change not allowed", func(t *testing.T) {
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -282,6 +392,10 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -291,8 +405,31 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: testZoneID,
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, ctrl, cleanup := setupTest(t)
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -302,6 +439,10 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -312,7 +453,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success with records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -328,6 +469,10 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCallCount := 0
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCallCount++
|
||||
@@ -348,7 +493,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("success without records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -356,6 +501,10 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
@@ -373,11 +522,31 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, ctrl, cleanup := setupTest(t)
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
@@ -6,11 +6,8 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -20,19 +17,25 @@ type handler struct {
|
||||
manager records.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -53,7 +56,13 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteJSONObject(r.Context(), w, apiRecords)
|
||||
}
|
||||
|
||||
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -69,7 +78,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth
|
||||
record := new(records.Record)
|
||||
record.FromAPIRequest(&req)
|
||||
|
||||
if err := record.Validate(); err != nil {
|
||||
if err = record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -83,7 +92,13 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -105,7 +120,13 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *au
|
||||
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -119,7 +140,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
@@ -128,7 +149,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth
|
||||
record.FromAPIRequest(&req)
|
||||
record.ID = recordID
|
||||
|
||||
if err := record.Validate(); err != nil {
|
||||
if err = record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -142,7 +163,13 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -155,7 +182,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,36 +9,64 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager) records.Manager {
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
|
||||
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
@@ -73,11 +101,18 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
var record *records.Record
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
@@ -125,11 +160,18 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var record *records.Record
|
||||
var zone *zones.Zone
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
|
||||
@@ -12,8 +12,12 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,7 +27,7 @@ const (
|
||||
testGroupID = "test-group-id"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *gomock.Controller, func()) {
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -47,17 +51,22 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_serv
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := NewManager(testStore, mockAccountManager).(*managerImpl)
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
}
|
||||
|
||||
return manager, testStore, zone, mockAccountManager, ctrl, cleanup
|
||||
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -69,6 +78,10 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
@@ -76,13 +89,43 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
assert.Equal(t, record2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -90,6 +133,10 @@ func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record.ID, result.ID)
|
||||
@@ -99,13 +146,29 @@ func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
assert.Equal(t, record.TTL, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success - A record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -116,6 +179,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
@@ -135,7 +202,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("success - AAAA record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -146,6 +213,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
@@ -160,7 +231,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("success - CNAME record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -171,6 +242,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
@@ -184,8 +259,32 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record name not in zone", func(t *testing.T) {
|
||||
manager, _, zone, _, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -196,6 +295,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -203,7 +306,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("duplicate record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -218,6 +321,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -225,7 +332,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -240,6 +347,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -251,7 +362,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -267,6 +378,10 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
TTL: 600, // Changed TTL
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
@@ -285,7 +400,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("update only TTL - no validation", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -301,6 +416,10 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
TTL: 600, // Only TTL changed
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
// Event should be stored
|
||||
}
|
||||
@@ -311,8 +430,33 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
assert.Equal(t, 600, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: testRecordID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -324,13 +468,17 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("update creates duplicate", func(t *testing.T) {
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -350,6 +498,10 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -361,7 +513,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -369,6 +521,10 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
@@ -386,11 +542,31 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, ctrl, cleanup := setupTest(t)
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
@@ -233,7 +233,7 @@ func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
|
||||
|
||||
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||
return Create(s, func() accesslogs.Manager {
|
||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.GeoLocationManager())
|
||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||
accessLogManager.StartPeriodicCleanup(
|
||||
context.Background(),
|
||||
s.Config.ReverseProxy.AccessLogRetentionDays,
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
@@ -28,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
)
|
||||
@@ -82,13 +82,13 @@ func (s *BaseServer) SettingsManager() settings.Manager {
|
||||
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
|
||||
}
|
||||
|
||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, idpConfig)
|
||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) PeersManager() peers.Manager {
|
||||
return Create(s, func() peers.Manager {
|
||||
manager := peers.NewManager(s.Store())
|
||||
manager := peers.NewManager(s.Store(), s.PermissionsManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
manager.SetNetworkMapController(s.NetworkMapController())
|
||||
manager.SetIntegratedPeerValidator(s.IntegratedValidator())
|
||||
@@ -161,43 +161,43 @@ func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||
|
||||
func (s *BaseServer) GroupsManager() groups.Manager {
|
||||
return Create(s, func() groups.Manager {
|
||||
return groups.NewManager(s.Store(), s.AccountManager())
|
||||
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||
return Create(s, func() resources.Manager {
|
||||
return resources.NewManager(s.Store(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RoutesManager() routers.Manager {
|
||||
return Create(s, func() routers.Manager {
|
||||
return routers.NewManager(s.Store(), s.AccountManager())
|
||||
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) NetworksManager() networks.Manager {
|
||||
return Create(s, func() networks.Manager {
|
||||
return networks.NewManager(s.Store(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ZonesManager() zones.Manager {
|
||||
return Create(s, func() zones.Manager {
|
||||
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.DNSDomain())
|
||||
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RecordsManager() records.Manager {
|
||||
return Create(s, func() records.Manager {
|
||||
return recordsManager.NewManager(s.Store(), s.AccountManager())
|
||||
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ServiceManager() service.Manager {
|
||||
return Create(s, func() service.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
|
||||
|
||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return Create(s, func() *manager.Manager {
|
||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.AccountManager())
|
||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
@@ -40,6 +39,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -280,14 +282,22 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
|
||||
// User that performs the update has to belong to the account.
|
||||
// Returns an updated Settings
|
||||
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var oldSettings *types.Settings
|
||||
var updateAccountPeers bool
|
||||
var groupChangesAffectPeers bool
|
||||
var reloadReverseProxy bool
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var groupsUpdated bool
|
||||
var err error
|
||||
|
||||
oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
@@ -715,6 +725,15 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return err
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
||||
}
|
||||
|
||||
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||
@@ -957,10 +976,6 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, fmt.Errorf("user %s does not belong to account %s", userID, accountID)
|
||||
}
|
||||
|
||||
key := user.IntegrationReference.CacheKey(accountID, userID)
|
||||
ud, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
if err != nil {
|
||||
@@ -1272,16 +1287,41 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
|
||||
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountMeta returns the account metadata associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
||||
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||
log.Errorf("failed to get account onboarding for account %s: %v", accountID, err)
|
||||
@@ -1298,6 +1338,15 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
|
||||
@@ -1356,8 +1405,9 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper
|
||||
// User account association is already validated above by GetUserByUserID
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if !user.IsServiceUser && userAuth.Invited {
|
||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||
@@ -1799,6 +1849,13 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -2140,6 +2197,14 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate user permissions: %w", err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update peer IP transaction: %w", err)
|
||||
|
||||
@@ -60,7 +60,7 @@ type Manager interface {
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -75,7 +75,7 @@ type Manager interface {
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
|
||||
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||
|
||||
@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
|
||||
}
|
||||
|
||||
// GetGroupByName mocks base method.
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
|
||||
ret0, _ := ret[0].(*types.Group)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupByName indicates an expected call of GetGroupByName.
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
|
||||
}
|
||||
|
||||
// GetIdentityProvider mocks base method.
|
||||
@@ -946,18 +946,18 @@ func (mr *MockManagerMockRecorder) GetPeerNetwork(ctx, peerID interface{}) *gomo
|
||||
}
|
||||
|
||||
// GetPeers mocks base method.
|
||||
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*peer.Peer, error) {
|
||||
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter, all)
|
||||
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter)
|
||||
ret0, _ := ret[0].([]*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeers indicates an expected call of GetPeers.
|
||||
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter, all interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter, all)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter)
|
||||
}
|
||||
|
||||
// GetPolicy mocks base method.
|
||||
|
||||
@@ -22,10 +22,8 @@ import (
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
@@ -51,6 +49,7 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -409,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -1172,11 +1171,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
@@ -1232,11 +1226,6 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
@@ -1275,11 +1264,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
@@ -1333,11 +1317,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
@@ -1398,11 +1377,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
@@ -1634,75 +1608,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
}
|
||||
|
||||
func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
||||
},
|
||||
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
||||
Routes: map[route.ID]*route.Route{
|
||||
"route-1": {
|
||||
ID: "route-1",
|
||||
Network: prefix,
|
||||
NetID: "network-1",
|
||||
Description: "network-1",
|
||||
Peer: "peer-1",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
"route-2": {
|
||||
ID: "route-2",
|
||||
Network: prefix2,
|
||||
NetID: "network-2",
|
||||
Description: "network-2",
|
||||
Peer: "peer-2",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
"route-3": {
|
||||
ID: "route-3",
|
||||
Network: prefix,
|
||||
NetID: "network-1",
|
||||
Description: "network-1",
|
||||
Peer: "peer-2",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
for _, r := range routes {
|
||||
routeIDs[r.ID] = struct{}{}
|
||||
}
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
|
||||
func TestAccount_Copy(t *testing.T) {
|
||||
account := &types.Account{
|
||||
Id: "account1",
|
||||
@@ -1825,9 +1730,7 @@ func TestAccount_Copy(t *testing.T) {
|
||||
AccountID: "account1",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
account.InitOnce()
|
||||
err := hasNilField(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -3148,7 +3051,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
proxyManager := proxy.NewMockManager(ctrl)
|
||||
proxyManager.EXPECT().
|
||||
@@ -3165,7 +3068,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -3176,7 +3079,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, proxyManager, nil))
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil))
|
||||
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
@@ -3254,6 +3157,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
return manager, updateManager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
|
||||
// wrappers wait for an expected update message. Sized for slow CI runners
|
||||
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
|
||||
// seconds. Only runs down on failure; passing tests return immediately
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
@@ -3272,7 +3182,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
@@ -20,6 +22,14 @@ const (
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -29,11 +39,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -78,6 +79,16 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
if len(dnsSettings.DisabledManagementGroups) != 1 {
|
||||
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
|
||||
}
|
||||
|
||||
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
|
||||
if err == nil {
|
||||
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
|
||||
}
|
||||
|
||||
s, ok := status.FromError(err)
|
||||
if !ok && s.Type() != status.PermissionDenied {
|
||||
t.Errorf("returned error should be Permission Denied, got err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveDNSSettings(t *testing.T) {
|
||||
@@ -212,7 +223,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
// return empty extra settings for expected calls to UpdateAccountPeers
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -223,7 +234,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
|
||||
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
}
|
||||
@@ -447,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -467,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -507,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -9,8 +9,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func isEnabled() bool {
|
||||
@@ -20,6 +23,14 @@ func isEnabled() bool {
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
@@ -30,24 +32,13 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper
|
||||
// This method is called from authenticated/authorized handlers, so we just validate
|
||||
// that the user exists and is part of the account
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return status.NewUserNotFoundError(userID)
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsBlocked() {
|
||||
return status.NewUserBlockedError()
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -70,17 +61,27 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
}
|
||||
|
||||
// CreateGroup object of the peers
|
||||
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -124,11 +125,19 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
|
||||
// UpdateGroup object of the peers
|
||||
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -187,24 +196,33 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -225,7 +243,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -247,14 +264,21 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -287,7 +311,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -393,6 +416,14 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*types.Group
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -619,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -637,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -655,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -688,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -729,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -756,17 +757,18 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving a group linked to network router should update account peers and send peer update
|
||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||
groupsManager := groups.NewManager(manager.Store, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, manager)
|
||||
networksManager := networks.NewManager(manager.Store, resourcesManager, routersManager, manager)
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
|
||||
network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{
|
||||
ID: "network_test",
|
||||
@@ -802,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -6,6 +6,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -22,21 +25,31 @@ type Manager interface {
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
type mockManager struct {
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager) Manager {
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
accountManager: accountManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting account groups: %w", err)
|
||||
@@ -60,6 +73,14 @@ func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID str
|
||||
}
|
||||
|
||||
func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error adding resource to group: %w", err)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -32,8 +31,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||
|
||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
@@ -123,25 +124,25 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router, permissionsManager)
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||
users.AddEndpoints(accountManager, router, permissionsManager)
|
||||
users.AddInvitesEndpoints(accountManager, router, permissionsManager)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
users.AddInvitesEndpoints(accountManager, router)
|
||||
users.AddPublicInvitesEndpoints(accountManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router, permissionsManager)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router, permissionsManager)
|
||||
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router, permissionsManager)
|
||||
setup_keys.AddEndpoints(accountManager, router)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router)
|
||||
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
|
||||
policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router)
|
||||
groups.AddEndpoints(accountManager, router, permissionsManager)
|
||||
routes.AddEndpoints(accountManager, router, permissionsManager)
|
||||
dns.AddEndpoints(accountManager, router, permissionsManager)
|
||||
events.AddEndpoints(accountManager, router, permissionsManager)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, permissionsManager, router)
|
||||
zonesManager.RegisterEndpoints(router, zManager, permissionsManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager, permissionsManager)
|
||||
idp.AddEndpoints(accountManager, router, permissionsManager)
|
||||
groups.AddEndpoints(accountManager, router)
|
||||
routes.AddEndpoints(accountManager, router)
|
||||
dns.AddEndpoints(accountManager, router)
|
||||
events.AddEndpoints(accountManager, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||
zonesManager.RegisterEndpoints(router, zManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router, permissionsManager)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
}
|
||||
|
||||
@@ -12,13 +12,10 @@ import (
|
||||
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -43,11 +40,11 @@ type handler struct {
|
||||
settingsManager settings.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
|
||||
accountsHandler := newHandler(accountManager, settingsManager)
|
||||
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Update, accountsHandler.updateAccount)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Delete, accountsHandler.deleteAccount)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", permissionsManager.WithPermission(modules.Accounts, operations.Read, accountsHandler.getAllAccounts)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new handler HTTP handler
|
||||
@@ -102,7 +99,7 @@ func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID st
|
||||
}
|
||||
|
||||
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
|
||||
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "", true)
|
||||
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "")
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "get peer count: %v", err)
|
||||
}
|
||||
@@ -139,26 +136,34 @@ func calculateRequiredAddresses(peerCount int) int64 {
|
||||
}
|
||||
|
||||
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.settingsManager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(userAuth.AccountId, settings, meta, onboarding)
|
||||
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
@@ -228,15 +233,24 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
|
||||
}
|
||||
|
||||
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
accountID := mux.Vars(r)["accountId"]
|
||||
if accountID != userAuth.AccountId {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
|
||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
_, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
accountID := vars["accountId"]
|
||||
if len(accountID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiAccountsAccountIdJSONRequestBody
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -253,7 +267,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
|
||||
return
|
||||
}
|
||||
if err := h.validateNetworkRange(r.Context(), accountID, userAuth.UserId, prefix); err != nil {
|
||||
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -268,19 +282,19 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth
|
||||
}
|
||||
}
|
||||
|
||||
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userAuth.UserId, onboarding)
|
||||
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userAuth.UserId, settings)
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userAuth.UserId)
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -292,14 +306,21 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth
|
||||
}
|
||||
|
||||
// deleteAccount is a HTTP DELETE handler to delete an account
|
||||
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
accountID := mux.Vars(r)["accountId"]
|
||||
if accountID != userAuth.AccountId {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
|
||||
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
vars := mux.Vars(r)
|
||||
targetAccountID := vars["accountId"]
|
||||
if len(targetAccountID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -291,8 +290,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/accounts", permissions.WrapHandler(handler.getAllAccounts)).Methods("GET")
|
||||
router.HandleFunc("/api/accounts/{accountId}", permissions.WrapHandler(handler.updateAccount)).Methods("PUT")
|
||||
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
|
||||
router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -5,13 +5,11 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -21,15 +19,15 @@ type dnsSettingsHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
addDNSSettingEndpoint(accountManager, router, permissionsManager)
|
||||
addDNSNameserversEndpoint(accountManager, router, permissionsManager)
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
addDNSSettingEndpoint(accountManager, router)
|
||||
addDNSNameserversEndpoint(accountManager, router)
|
||||
}
|
||||
|
||||
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
|
||||
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Read, dnsSettingsHandler.getDNSSettings)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Update, dnsSettingsHandler.updateDNSSettings)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
|
||||
}
|
||||
|
||||
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
|
||||
@@ -38,8 +36,17 @@ func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler {
|
||||
}
|
||||
|
||||
// getDNSSettings returns the DNS settings for the account
|
||||
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -53,9 +60,17 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
// updateDNSSettings handles update to DNS settings of an account
|
||||
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.PutApiDnsSettingsJSONRequestBody
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -65,7 +80,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re
|
||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId, updateDNSSettings)
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -116,8 +115,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.getDNSSettings)).Methods("GET")
|
||||
router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.updateDNSSettings)).Methods("PUT")
|
||||
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
|
||||
router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -6,13 +6,11 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -23,13 +21,13 @@ type nameserversHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
nameserversHandler := newNameserversHandler(accountManager)
|
||||
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getAllNameservers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Create, nameserversHandler.createNameserverGroup)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Update, nameserversHandler.updateNameserverGroup)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getNameserverGroup)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Delete, nameserversHandler.deleteNameserverGroup)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newNameserversHandler returns a new instance of nameserversHandler handler
|
||||
@@ -38,8 +36,17 @@ func newNameserversHandler(accountManager account.Manager) *nameserversHandler {
|
||||
}
|
||||
|
||||
// getAllNameservers returns the list of nameserver groups for the account
|
||||
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -54,9 +61,17 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
|
||||
}
|
||||
|
||||
// createNameserverGroup handles nameserver group creation request
|
||||
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.PostApiDnsNameserversJSONRequestBody
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -68,7 +83,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), userAuth.AccountId, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userAuth.UserId, req.SearchDomainsEnabled)
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -80,7 +95,15 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// updateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
@@ -88,7 +111,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -112,7 +135,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, updatedNSGroup)
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -124,14 +147,22 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// deleteNameserverGroup handles nameserver group deletion request
|
||||
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteNameServerGroup(r.Context(), userAuth.AccountId, nsGroupID, userAuth.UserId)
|
||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -141,14 +172,22 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// getNameserverGroup handles a nameserver group Get request identified by ID
|
||||
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, nsGroupID)
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -202,10 +201,10 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.getNameserverGroup)).Methods("GET")
|
||||
router.HandleFunc("/api/dns/nameservers", permissions.WrapHandler(p.createNameserverGroup)).Methods("POST")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.deleteNameserverGroup)).Methods("DELETE")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.updateNameserverGroup)).Methods("PUT")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
|
||||
router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -5,13 +5,11 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -21,10 +19,10 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
eventsHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/events", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events/audit", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events/audit", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new events handler
|
||||
@@ -33,8 +31,17 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllEvents list of the given account
|
||||
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -197,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/events/", permissions.WrapHandler(handler.getAllEvents)).Methods("GET")
|
||||
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -7,13 +7,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,45 +19,46 @@ import (
|
||||
|
||||
// handler is a handler that returns groups of the account
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
groupsHandler := newHandler(accountManager, permissionsManager)
|
||||
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getGroup)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Delete, groupsHandler.deleteGroup)).Methods("DELETE", "OPTIONS")
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
groupsHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new groups handler
|
||||
func newHandler(accountManager account.Manager, permissionsManager permissions.Manager) *handler {
|
||||
func newHandler(accountManager account.Manager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
accountManager: accountManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) canReadPeers(r *http.Request, userAuth *auth.UserAuth) bool {
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read)
|
||||
return err == nil && allowed
|
||||
}
|
||||
|
||||
// getAllGroups list for the account
|
||||
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
// Check if filtering by name
|
||||
groupName := r.URL.Query().Get("name")
|
||||
if groupName != "" {
|
||||
// Get single group by name
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, userAuth.AccountId)
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -72,13 +71,13 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth
|
||||
}
|
||||
|
||||
// Get all groups
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -93,7 +92,15 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth
|
||||
}
|
||||
|
||||
// updateGroup handles update to a group identified by a given ID
|
||||
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
groupID, ok := vars["groupId"]
|
||||
if !ok {
|
||||
@@ -105,13 +112,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
return
|
||||
}
|
||||
|
||||
existingGroup, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
|
||||
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", userAuth.AccountId)
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -159,13 +166,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
IntegrationReference: existingGroup.IntegrationReference,
|
||||
}
|
||||
|
||||
if err := h.accountManager.UpdateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, userAuth.AccountId, err)
|
||||
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -175,9 +182,17 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
}
|
||||
|
||||
// createGroup handles group creation request
|
||||
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.PostApiGroupsJSONRequestBody
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -211,13 +226,13 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
Issued: types.GroupIssuedAPI,
|
||||
}
|
||||
|
||||
err = h.accountManager.CreateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group)
|
||||
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -227,14 +242,22 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
}
|
||||
|
||||
// deleteGroup handles group deletion request
|
||||
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteGroup(r.Context(), userAuth.AccountId, userAuth.UserId, groupID)
|
||||
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
|
||||
if err != nil {
|
||||
wrappedErr, ok := err.(interface{ Unwrap() []error })
|
||||
if ok && len(wrappedErr.Unwrap()) > 0 {
|
||||
@@ -250,26 +273,34 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
}
|
||||
|
||||
// getGroup returns a group
|
||||
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
|
||||
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
|
||||
|
||||
}
|
||||
|
||||
func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group {
|
||||
|
||||
@@ -13,14 +13,10 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
@@ -37,18 +33,8 @@ var TestPeers = map[string]*nbpeer.Peer{
|
||||
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
||||
}
|
||||
|
||||
func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler {
|
||||
t.Helper()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
permissionsManagerMock.EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
return &handler{
|
||||
permissionsManager: permissionsManagerMock,
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error {
|
||||
if !strings.HasPrefix(group.ID, "id-") {
|
||||
@@ -85,14 +71,14 @@ func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler {
|
||||
|
||||
return groups, nil
|
||||
},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
|
||||
if groupName == "All" {
|
||||
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "unknown group name")
|
||||
},
|
||||
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
|
||||
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
return maps.Values(TestPeers), nil
|
||||
},
|
||||
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
||||
@@ -142,7 +128,7 @@ func TestGetGroup(t *testing.T) {
|
||||
Name: "Group",
|
||||
}
|
||||
|
||||
p := initGroupTestData(t, group)
|
||||
p := initGroupTestData(group)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -155,7 +141,7 @@ func TestGetGroup(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.getGroup)).Methods("GET")
|
||||
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -268,7 +254,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initGroupTestData(t)
|
||||
p := initGroupTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -281,8 +267,8 @@ func TestWriteGroup(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups", permissions.WrapHandler(p.createGroup)).Methods("POST")
|
||||
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.updateGroup)).Methods("PUT")
|
||||
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
|
||||
router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -346,7 +332,7 @@ func TestGetAllGroups(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initGroupTestData(t)
|
||||
p := initGroupTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -359,7 +345,7 @@ func TestGetAllGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups", permissions.WrapHandler(p.getAllGroups)).Methods("GET")
|
||||
router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -428,7 +414,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initGroupTestData(t)
|
||||
p := initGroupTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -440,7 +426,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
AccountId: "test_id",
|
||||
})
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.deleteGroup)).Methods("DELETE")
|
||||
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -6,12 +6,9 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -23,13 +20,13 @@ type handler struct {
|
||||
}
|
||||
|
||||
// AddEndpoints registers identity provider endpoints
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := newHandler(accountManager)
|
||||
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getAllIdentityProviders)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Create, h.createIdentityProvider)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getIdentityProvider)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Update, h.updateIdentityProvider)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Delete, h.deleteIdentityProvider)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newHandler(accountManager account.Manager) *handler {
|
||||
@@ -39,8 +36,16 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllIdentityProviders returns all identity providers for the account
|
||||
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
providers, err := h.accountManager.GetIdentityProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -55,7 +60,15 @@ func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// getIdentityProvider returns a specific identity provider
|
||||
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
@@ -63,7 +76,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, us
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.accountManager.GetIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId)
|
||||
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -73,7 +86,15 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, us
|
||||
}
|
||||
|
||||
// createIdentityProvider creates a new identity provider
|
||||
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.IdentityProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -82,7 +103,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request,
|
||||
|
||||
idp := fromAPIRequest(&req)
|
||||
|
||||
created, err := h.accountManager.CreateIdentityProvider(r.Context(), userAuth.AccountId, userAuth.UserId, idp)
|
||||
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -92,7 +113,15 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
// updateIdentityProvider updates an existing identity provider
|
||||
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
@@ -108,7 +137,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request,
|
||||
|
||||
idp := fromAPIRequest(&req)
|
||||
|
||||
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId, idp)
|
||||
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -118,7 +147,15 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
// deleteIdentityProvider deletes an identity provider
|
||||
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
@@ -126,7 +163,7 @@ func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.DeleteIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId); err != nil {
|
||||
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -121,7 +120,7 @@ func TestGetAllIdentityProviders(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.getAllIdentityProviders)).Methods("GET")
|
||||
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -181,7 +180,7 @@ func TestGetIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.getIdentityProvider)).Methods("GET")
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -243,7 +242,7 @@ func TestCreateIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.createIdentityProvider)).Methods("POST")
|
||||
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -329,7 +328,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.updateIdentityProvider)).Methods("PUT")
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -389,7 +388,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.deleteIdentityProvider)).Methods("DELETE")
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -7,11 +7,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -33,12 +29,12 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
}
|
||||
|
||||
// AddVersionEndpoint registers the authenticated version endpoint.
|
||||
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
instanceManager: instanceManager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/instance/version", permissionsManager.WithPermission(modules.Settings, operations.Read, h.getVersionInfo)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// getInstanceStatus returns the instance status including whether setup is required.
|
||||
@@ -81,7 +77,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getVersionInfo returns version information for NetBird components.
|
||||
// This endpoint requires authentication.
|
||||
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
|
||||
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
|
||||
|
||||
@@ -10,17 +10,12 @@ import (
|
||||
"net/mail"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -300,15 +295,8 @@ func TestSetup_ManagerError(t *testing.T) {
|
||||
|
||||
func TestGetVersionInfo_Success(t *testing.T) {
|
||||
manager := &mockInstanceManager{}
|
||||
ctrl := gomock.NewController(t)
|
||||
permissionsManager := permissions.NewMockManager(ctrl)
|
||||
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(w, r, &auth.UserAuth{})
|
||||
}
|
||||
}).AnyTimes()
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router, permissionsManager)
|
||||
AddVersionEndpoint(manager, router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -335,15 +323,8 @@ func TestGetVersionInfo_Error(t *testing.T) {
|
||||
return nil, errors.New("failed to fetch versions")
|
||||
},
|
||||
}
|
||||
ctrl := gomock.NewController(t)
|
||||
permissionsManager := permissions.NewMockManager(ctrl)
|
||||
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(w, r, &auth.UserAuth{})
|
||||
}
|
||||
}).AnyTimes()
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router, permissionsManager)
|
||||
AddVersionEndpoint(manager, router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -9,10 +9,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
@@ -20,7 +18,6 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -36,16 +33,16 @@ type handler struct {
|
||||
groupsManager groups.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
addRouterEndpoints(routerManager, permissionsManager, router)
|
||||
addResourceEndpoints(resourceManager, groupsManager, permissionsManager, router)
|
||||
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) {
|
||||
addRouterEndpoints(routerManager, router)
|
||||
addResourceEndpoints(resourceManager, groupsManager, router)
|
||||
|
||||
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
|
||||
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getAllNetworks)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Create, networksHandler.createNetwork)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getNetwork)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Update, networksHandler.updateNetwork)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, networksHandler.deleteNetwork)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler {
|
||||
@@ -58,32 +55,40 @@ func newHandler(networksManager networks.Manager, resourceManager resources.Mana
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networks, err := h.networksManager.GetAllNetworks(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -92,9 +97,16 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAut
|
||||
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account))
|
||||
}
|
||||
|
||||
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.NetworkRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -103,14 +115,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth
|
||||
network := &types.Network{}
|
||||
network.FromAPIRequest(&req)
|
||||
|
||||
network.AccountID = userAuth.AccountId
|
||||
network, err = h.networksManager.CreateNetwork(r.Context(), userAuth.UserId, network)
|
||||
network.AccountID = accountID
|
||||
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -121,7 +133,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs))
|
||||
}
|
||||
|
||||
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@@ -129,19 +148,19 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
return
|
||||
}
|
||||
|
||||
network, err := h.networksManager.GetNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -152,7 +171,14 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
|
||||
}
|
||||
|
||||
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@@ -161,7 +187,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth
|
||||
}
|
||||
|
||||
var req api.NetworkRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -171,20 +197,20 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth
|
||||
network.FromAPIRequest(&req)
|
||||
|
||||
network.ID = networkID
|
||||
network.AccountID = userAuth.AccountId
|
||||
network, err = h.networksManager.UpdateNetwork(r.Context(), userAuth.UserId, network)
|
||||
network.AccountID = accountID
|
||||
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -195,7 +221,14 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
|
||||
}
|
||||
|
||||
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@@ -203,7 +236,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth
|
||||
return
|
||||
}
|
||||
|
||||
err := h.networksManager.DeleteNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -6,13 +6,10 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -22,14 +19,14 @@ type resourceHandler struct {
|
||||
groupsManager groups.Manager
|
||||
}
|
||||
|
||||
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) {
|
||||
resourceHandler := newResourceHandler(resourcesManager, groupsManager)
|
||||
router.HandleFunc("/networks/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInAccount)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInNetwork)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Create, resourceHandler.createResource)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getResource)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Update, resourceHandler.updateResource)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, resourceHandler.deleteResource)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
|
||||
@@ -39,15 +36,22 @@ func newResourceHandler(resourceManager resources.Manager, groupsManager groups.
|
||||
}
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -62,14 +66,22 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resourcesResponse)
|
||||
}
|
||||
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -85,9 +97,17 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
|
||||
util.WriteJSONObject(r.Context(), w, resourcesResponse)
|
||||
}
|
||||
|
||||
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.NetworkResourceRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -97,14 +117,14 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request,
|
||||
resource.FromAPIRequest(&req)
|
||||
|
||||
resource.NetworkID = mux.Vars(r)["networkId"]
|
||||
resource.AccountID = userAuth.AccountId
|
||||
resource, err = h.resourceManager.CreateResource(r.Context(), userAuth.UserId, resource)
|
||||
resource.AccountID = accountID
|
||||
resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -115,16 +135,23 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request,
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
resource, err := h.resourceManager.GetResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
|
||||
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -135,9 +162,16 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, us
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
|
||||
}
|
||||
|
||||
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
var req api.NetworkResourceRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -148,14 +182,14 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request,
|
||||
|
||||
resource.ID = mux.Vars(r)["resourceId"]
|
||||
resource.NetworkID = mux.Vars(r)["networkId"]
|
||||
resource.AccountID = userAuth.AccountId
|
||||
resource, err = h.resourceManager.UpdateResource(r.Context(), userAuth.UserId, resource)
|
||||
resource.AccountID = accountID
|
||||
resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -166,10 +200,17 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request,
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
|
||||
}
|
||||
|
||||
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
err := h.resourceManager.DeleteResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
|
||||
err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -6,12 +6,9 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -20,14 +17,14 @@ type routersHandler struct {
|
||||
routersManager routers.Manager
|
||||
}
|
||||
|
||||
func addRouterEndpoints(routersManager routers.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
|
||||
routersHandler := newRoutersHandler(routersManager)
|
||||
router.HandleFunc("/networks/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getAllRouters)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getNetworkRouters)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Create, routersHandler.createRouter)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getRouter)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Update, routersHandler.updateRouter)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, routersHandler.deleteRouter)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newRoutersHandler(routersManager routers.Manager) *routersHandler {
|
||||
@@ -36,8 +33,16 @@ func newRoutersHandler(routersManager routers.Manager) *routersHandler {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -53,9 +58,17 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, u
|
||||
util.WriteJSONObject(r.Context(), w, routersResponse)
|
||||
}
|
||||
|
||||
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -69,10 +82,18 @@ func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Reques
|
||||
util.WriteJSONObject(r.Context(), w, routersResponse)
|
||||
}
|
||||
|
||||
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
var req api.NetworkRouterRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -82,7 +103,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, us
|
||||
router.FromAPIRequest(&req)
|
||||
|
||||
router.NetworkID = networkID
|
||||
router.AccountID = userAuth.AccountId
|
||||
router.AccountID = accountID
|
||||
router.Enabled = true
|
||||
|
||||
if err := router.Validate(); err != nil {
|
||||
@@ -90,7 +111,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, us
|
||||
return
|
||||
}
|
||||
|
||||
router, err = h.routersManager.CreateRouter(r.Context(), userAuth.UserId, router)
|
||||
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -99,10 +120,18 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, us
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
router, err := h.routersManager.GetRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
|
||||
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -111,9 +140,17 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userA
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.NetworkRouterRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -124,14 +161,14 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, us
|
||||
|
||||
router.NetworkID = mux.Vars(r)["networkId"]
|
||||
router.ID = mux.Vars(r)["routerId"]
|
||||
router.AccountID = userAuth.AccountId
|
||||
router.AccountID = accountID
|
||||
|
||||
if err := router.Validate(); err != nil {
|
||||
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
router, err = h.routersManager.UpdateRouter(r.Context(), userAuth.UserId, router)
|
||||
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -140,10 +177,17 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, us
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
err := h.routersManager.DeleteRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
|
||||
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package peers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -11,15 +12,15 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -34,15 +35,14 @@ type Handler struct {
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
|
||||
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
|
||||
router.HandleFunc("/peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAllPeers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetPeer)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Update, peersHandler.UpdatePeer)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Delete, peersHandler.DeletePeer)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/accessible-peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAccessiblePeers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/temporary-access", permissionsManager.WithPermission(modules.Peers, operations.Create, peersHandler.CreateTemporaryAccess)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.ListJobs)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Create, peersHandler.CreateJob)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.GetJob)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
@@ -54,7 +54,14 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
|
||||
@@ -66,30 +73,37 @@ func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *au
|
||||
|
||||
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
if err := h.accountManager.CreatePeerJob(r.Context(), userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
|
||||
jobs, err := h.accountManager.GetAllPeerJobs(r.Context(), userAuth.AccountId, userAuth.UserId, peerID)
|
||||
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -97,88 +111,79 @@ func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *aut
|
||||
for _, job := range jobs {
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
respBody = append(respBody, resp)
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
util.WriteJSONObject(ctx, w, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
jobID := vars["jobId"]
|
||||
|
||||
job, err := h.accountManager.GetPeerJobByID(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, jobID)
|
||||
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
}
|
||||
|
||||
// GetPeer handles GET request for a single peer
|
||||
func (h *Handler) GetPeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
|
||||
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if peer.ProxyMeta.Embedded {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
|
||||
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
grps, _ := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peerID)
|
||||
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
||||
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
reason := invalidPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
// UpdatePeer handles PUT request to update a peer
|
||||
func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -187,10 +192,11 @@ func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
}
|
||||
|
||||
update := &nbpeer.Peer{
|
||||
ID: peerID,
|
||||
SSHEnabled: req.SshEnabled,
|
||||
Name: req.Name,
|
||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||
ID: peerID,
|
||||
SSHEnabled: req.SshEnabled,
|
||||
Name: req.Name,
|
||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||
|
||||
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
|
||||
}
|
||||
|
||||
@@ -204,41 +210,41 @@ func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
if req.Ip != nil {
|
||||
addr, err := netip.ParseAddr(*req.Ip)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
|
||||
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.accountManager.UpdatePeerIP(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, addr); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(r.Context(), userAuth.AccountId, userAuth.UserId, update)
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peer.ID)
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
|
||||
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -248,8 +254,25 @@ func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
// DeletePeer handles DELETE request to delete a peer
|
||||
func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(ctx, w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@@ -257,34 +280,48 @@ func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeletePeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
switch r.Method {
|
||||
case http.MethodDelete:
|
||||
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||
return
|
||||
case http.MethodGet:
|
||||
h.getPeer(r.Context(), accountID, peerID, userID, w)
|
||||
return
|
||||
case http.MethodPut:
|
||||
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
|
||||
return
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// GetAllPeers returns a list of all peers associated with a provided account
|
||||
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nameFilter := r.URL.Query().Get("name")
|
||||
ipFilter := r.URL.Query().Get("ip")
|
||||
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, nameFilter, ipFilter, true)
|
||||
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
nameFilter := r.URL.Query().Get("name")
|
||||
ipFilter := r.URL.Query().Get("ip")
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
grps, _ := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
@@ -295,7 +332,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -319,7 +356,15 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM
|
||||
}
|
||||
|
||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@@ -327,22 +372,25 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, use
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is an admin/service user through their role
|
||||
isAdmin := user.Role == types.UserRoleAdmin || user.Role == types.UserRoleOwner
|
||||
|
||||
if !isAdmin && !user.IsServiceUser && !userAuth.IsChild {
|
||||
if !allowed && !userAuth.IsChild {
|
||||
if account.Settings.RegularUsersViewBlocked {
|
||||
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||
return
|
||||
@@ -360,7 +408,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, use
|
||||
}
|
||||
}
|
||||
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -369,12 +417,18 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, use
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@@ -383,7 +437,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
var req api.PeerTemporaryAccessRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
|
||||
@@ -19,11 +19,11 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -174,7 +174,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
},
|
||||
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
|
||||
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
return peers, nil
|
||||
},
|
||||
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
|
||||
@@ -307,9 +307,9 @@ func TestGetPeers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/", permissions.WrapHandler(p.GetAllPeers)).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.GetPeer)).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
|
||||
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -498,7 +498,7 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/{peerId}/accessible-peers", permissions.WrapHandler(p.GetAccessiblePeers)).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -582,7 +582,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
|
||||
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
||||
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
|
||||
@@ -14,12 +14,12 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -121,7 +121,7 @@ func TestGetCitiesByCountry(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/locations/countries/{country}/cities", permissions.WrapHandler(geolocationHandler.getCitiesByCountry)).Methods("GET")
|
||||
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -214,7 +214,7 @@ func TestGetAllCountries(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/locations/countries", permissions.WrapHandler(geolocationHandler.getAllCountries)).Methods("GET")
|
||||
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -30,8 +30,8 @@ type geolocationsHandler struct {
|
||||
|
||||
func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager)
|
||||
router.HandleFunc("/locations/countries", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getAllCountries)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries/{country}/cities", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getCitiesByCountry)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newGeolocationsHandlerHandler creates a new Geolocations handler
|
||||
@@ -44,7 +44,12 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
|
||||
}
|
||||
|
||||
// getAllCountries retrieves a list of all countries
|
||||
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if l.geolocationManager == nil {
|
||||
// TODO: update error message to include geo db self hosted doc link when ready
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
@@ -65,7 +70,12 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
|
||||
// getCitiesByCountry retrieves a list of cities based on the given country code
|
||||
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
countryCode := vars["country"]
|
||||
if !countryCodeRegex.MatchString(countryCode) {
|
||||
@@ -92,6 +102,27 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
|
||||
util.WriteJSONObject(r.Context(), w, cities)
|
||||
}
|
||||
|
||||
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toCountryResponse(country geolocation.Country) api.Country {
|
||||
return api.Country{
|
||||
CountryName: country.CountryName,
|
||||
|
||||
@@ -7,13 +7,10 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -24,13 +21,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
|
||||
policiesHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getAllPolicies)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Create, policiesHandler.createPolicy)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Update, policiesHandler.updatePolicy)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getPolicy)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, policiesHandler.deletePolicy)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new policies handler
|
||||
@@ -41,14 +38,22 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllPolicies list for the account
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
listPolicies, err := h.accountManager.ListPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -68,7 +73,15 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAut
|
||||
}
|
||||
|
||||
// updatePolicy handles update to a policy identified by a given ID
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@@ -76,18 +89,26 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth
|
||||
return
|
||||
}
|
||||
|
||||
_, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
|
||||
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, policyID, false)
|
||||
h.savePolicy(w, r, accountID, userID, policyID, false)
|
||||
}
|
||||
|
||||
// createPolicy handles policy creation request
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, "", true)
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
h.savePolicy(w, r, accountID, userID, "", true)
|
||||
}
|
||||
|
||||
// savePolicy handles policy creation and update
|
||||
@@ -282,7 +303,14 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
}
|
||||
|
||||
// deletePolicy handles policy deletion request
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@@ -290,7 +318,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.DeletePolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil {
|
||||
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -299,7 +327,15 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth
|
||||
}
|
||||
|
||||
// getPolicy handles a group Get request identified by ID
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@@ -307,13 +343,13 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *au
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -112,7 +111,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.getPolicy)).Methods("GET")
|
||||
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -276,8 +275,8 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/policies", permissions.WrapHandler(p.createPolicy)).Methods("POST")
|
||||
router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.updatePolicy)).Methods("PUT")
|
||||
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
|
||||
router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -6,13 +6,10 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -24,13 +21,13 @@ type postureChecksHandler struct {
|
||||
geolocationManager geolocation.Geolocation
|
||||
}
|
||||
|
||||
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
|
||||
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
|
||||
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getAllPostureChecks)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Create, postureCheckHandler.createPostureCheck)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Update, postureCheckHandler.updatePostureCheck)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getPostureCheck)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, postureCheckHandler.deletePostureCheck)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newPostureChecksHandler creates a new PostureChecks handler
|
||||
@@ -42,8 +39,15 @@ func newPostureChecksHandler(accountManager account.Manager, geolocationManager
|
||||
}
|
||||
|
||||
// getAllPostureChecks list for the account
|
||||
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -58,30 +62,15 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// updatePostureCheck handles update to a posture check identified by a given ID
|
||||
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
|
||||
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, postureChecksID, false)
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
// createPostureCheck handles posture check creation request
|
||||
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, "", true)
|
||||
}
|
||||
|
||||
// getPostureCheck handles a posture check Get request identified by ID
|
||||
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@@ -89,7 +78,45 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
|
||||
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
p.savePostureChecks(w, r, accountID, userID, postureChecksID, false)
|
||||
}
|
||||
|
||||
// createPostureCheck handles posture check creation request
|
||||
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
p.savePostureChecks(w, r, accountID, userID, "", true)
|
||||
}
|
||||
|
||||
// getPostureCheck handles a posture check Get request identified by ID
|
||||
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -99,7 +126,14 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
}
|
||||
|
||||
// deletePostureCheck handles posture check deletion request
|
||||
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@@ -107,7 +141,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.accountManager.DeletePostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId); err != nil {
|
||||
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
@@ -184,7 +183,7 @@ func TestGetPostureCheck(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(p.getPostureCheck)).Methods("GET")
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -842,8 +841,8 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/posture-checks", permissions.WrapHandler(defaultHandler.createPostureCheck)).Methods("POST")
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(defaultHandler.updatePostureCheck)).Methods("PUT")
|
||||
router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST")
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -8,12 +8,9 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -29,13 +26,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
routesHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getAllRoutes)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Create, routesHandler.createRoute)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Update, routesHandler.updateRoute)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getRoute)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Delete, routesHandler.deleteRoute)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler returns a new instance of routes handler
|
||||
@@ -46,8 +43,16 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllRoutes returns the list of routes for the account
|
||||
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -66,9 +71,17 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth
|
||||
}
|
||||
|
||||
// createRoute handles route creation request
|
||||
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.PostApiRoutesJSONRequestBody
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -121,8 +134,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
skipAutoApply = false
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), userAuth.AccountId, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userAuth.UserId, req.KeepRoute, skipAutoApply)
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply)
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -172,7 +185,14 @@ func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *
|
||||
}
|
||||
|
||||
// updateRoute handles update to a route identified by a given ID
|
||||
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
routeID := vars["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
@@ -180,7 +200,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
return
|
||||
}
|
||||
|
||||
_, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
|
||||
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -251,7 +271,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
newRoute.AccessControlGroups = *req.AccessControlGroups
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveRoute(r.Context(), userAuth.AccountId, userAuth.UserId, newRoute)
|
||||
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -267,14 +287,21 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
}
|
||||
|
||||
// deleteRoute handles route deletion request
|
||||
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
|
||||
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -284,14 +311,22 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
}
|
||||
|
||||
// getRoute handles a route Get request identified by ID
|
||||
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
@@ -502,10 +501,10 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.getRoute)).Methods("GET")
|
||||
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.deleteRoute)).Methods("DELETE")
|
||||
router.HandleFunc("/api/routes", permissions.WrapHandler(p.createRoute)).Methods("POST")
|
||||
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.updateRoute)).Methods("PUT")
|
||||
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
|
||||
router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE")
|
||||
router.HandleFunc("/api/routes", p.createRoute).Methods("POST")
|
||||
router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -8,12 +8,9 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -24,13 +21,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
keysHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getAllSetupKeys)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Create, keysHandler.createSetupKey)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getSetupKey)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Update, keysHandler.updateSetupKey)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Delete, keysHandler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new setup key handler
|
||||
@@ -41,9 +38,16 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// createSetupKey is a POST requests that creates a new SetupKey
|
||||
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
req := &api.PostApiSetupKeysJSONRequestBody{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -81,8 +85,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAut
|
||||
allowExtraDNSLabels = *req.AllowExtraDnsLabels
|
||||
}
|
||||
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), userAuth.AccountId, req.Name, types.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, userAuth.UserId, ephemeral, allowExtraDNSLabels)
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, userID, ephemeral, allowExtraDNSLabels)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -96,7 +100,14 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAut
|
||||
}
|
||||
|
||||
// getSetupKey is a GET request to get a SetupKey by ID
|
||||
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@@ -104,7 +115,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -114,7 +125,14 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
}
|
||||
|
||||
// updateSetupKey is a PUT request to update server.SetupKey
|
||||
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@@ -123,7 +141,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAut
|
||||
}
|
||||
|
||||
req := &api.PutApiSetupKeysKeyIdJSONRequestBody{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -139,7 +157,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAut
|
||||
newKey.Revoked = req.Revoked
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), userAuth.AccountId, newKey, userAuth.UserId)
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -148,8 +166,15 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAut
|
||||
}
|
||||
|
||||
// getAllSetupKeys is a GET request that returns a list of SetupKey
|
||||
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -163,7 +188,14 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAu
|
||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||
}
|
||||
|
||||
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@@ -171,7 +203,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAut
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
|
||||
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -172,11 +171,11 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.getAllSetupKeys)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.createSetupKey)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.getSetupKey)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.updateSetupKey)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -9,13 +9,10 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -58,14 +55,14 @@ type invitesHandler struct {
|
||||
}
|
||||
|
||||
// AddInvitesEndpoints registers invite-related endpoints
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Authenticated endpoints (require admin)
|
||||
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
|
||||
@@ -82,7 +79,14 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route
|
||||
}
|
||||
|
||||
// listInvites handles GET /api/users/invites
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -98,7 +102,14 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, use
|
||||
}
|
||||
|
||||
// createInvite handles POST /api/users/invites
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteCreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -180,12 +191,18 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
@@ -221,7 +238,14 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// deleteInvite handles DELETE /api/users/invites/{inviteId}
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
@@ -229,7 +253,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, us
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -110,11 +110,7 @@ func TestListInvites(t *testing.T) {
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.listInvites(rr, req, userAuth)
|
||||
handler.listInvites(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -239,11 +235,7 @@ func TestCreateInvite(t *testing.T) {
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.createInvite(rr, req, userAuth)
|
||||
handler.createInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -581,11 +573,7 @@ func TestRegenerateInvite(t *testing.T) {
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.regenerateInvite(rr, req, userAuth)
|
||||
handler.regenerateInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -663,11 +651,7 @@ func TestDeleteInvite(t *testing.T) {
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.deleteInvite(rr, req, userAuth)
|
||||
handler.deleteInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
})
|
||||
|
||||
@@ -6,12 +6,9 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -22,12 +19,12 @@ type patHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
tokenHandler := newPATsHandler(accountManager)
|
||||
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newPATsHandler creates a new patHandler HTTP handler
|
||||
@@ -38,15 +35,22 @@ func newPATsHandler(accountManager account.Manager) *patHandler {
|
||||
}
|
||||
|
||||
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
if len(userID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -61,7 +65,14 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAu
|
||||
}
|
||||
|
||||
// getToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -75,7 +86,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -85,7 +96,14 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
}
|
||||
|
||||
// createToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -94,13 +112,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAut
|
||||
}
|
||||
|
||||
var req api.PostApiUsersUserIdTokensJSONRequestBody
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn)
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -110,7 +128,14 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAut
|
||||
}
|
||||
|
||||
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -124,7 +149,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAut
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
|
||||
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -182,10 +181,10 @@ func TestTokenHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.getAllTokens)).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.getToken)).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.createToken)).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.deleteToken)).Methods("DELETE")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -8,16 +8,14 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
// handler is a handler that returns users of the account
|
||||
@@ -25,18 +23,18 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
userHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers, userHandler.getOwnUser)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser, userHandler.getCurrentUserFallback)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router, permissionsManager)
|
||||
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
}
|
||||
|
||||
// newHandler creates a new UsersHandler HTTP handler
|
||||
@@ -47,12 +45,19 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// updateUser is a PUT requests to update User data
|
||||
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -66,11 +71,6 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
return
|
||||
}
|
||||
|
||||
if existingUser.AccountID != userAuth.AccountId {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "user not found"), w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PutApiUsersUserIdJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -89,7 +89,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
|
||||
Id: targetUserID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
@@ -102,16 +102,23 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
}
|
||||
|
||||
// deleteUser is a DELETE request to delete a user
|
||||
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -119,7 +126,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -129,14 +136,21 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
}
|
||||
|
||||
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -157,7 +171,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: req.Role,
|
||||
@@ -169,18 +183,25 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
}
|
||||
|
||||
// getAllUsers returns a list of users of the account this user belongs to.
|
||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -194,7 +215,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(d, userAuth.UserId))
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -205,7 +226,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
return
|
||||
}
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, userAuth.UserId))
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,12 +235,19 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
|
||||
// inviteUser resend invitations to users who haven't activated their accounts,
|
||||
// prior to the expiration period.
|
||||
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -227,7 +255,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -236,13 +264,19 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
|
||||
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -322,7 +356,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
}
|
||||
|
||||
// approveUser is a POST request to approve a user that is pending approval
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -335,6 +369,11 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -346,7 +385,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
}
|
||||
|
||||
// rejectUser is a DELETE request to reject a user that is pending approval
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -359,7 +398,12 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *a
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -377,7 +421,7 @@ type passwordChangeRequest struct {
|
||||
// changePassword is a PUT request to change user's password.
|
||||
// Only available when embedded IDP is enabled.
|
||||
// Users can only change their own password.
|
||||
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -390,13 +434,19 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAut
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req passwordChangeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
|
||||
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -404,39 +454,3 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAut
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func (h *handler) getCurrentUserFallback(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.PermissionDenied {
|
||||
return false
|
||||
}
|
||||
|
||||
user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
|
||||
if userErr != nil {
|
||||
util.WriteError(r.Context(), userErr, w)
|
||||
return true
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toUserWithPermissionsResponse(user, userAuth.UserId))
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *handler) getOwnUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.PermissionDenied {
|
||||
return false
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("service_user") != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
|
||||
if userErr != nil {
|
||||
util.WriteError(r.Context(), userErr, w)
|
||||
return true
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, []*api.User{toUserResponse(user.UserInfo, userAuth.UserId)})
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -15,11 +15,10 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
roles2 "github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
@@ -39,7 +38,6 @@ var usersTestAccount = &types.Account{
|
||||
Users: map[string]*types.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
AccountID: existingAccountID,
|
||||
Role: "admin",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
@@ -47,7 +45,6 @@ var usersTestAccount = &types.Account{
|
||||
},
|
||||
regularUserID: {
|
||||
Id: regularUserID,
|
||||
AccountID: existingAccountID,
|
||||
Role: "user",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
@@ -55,7 +52,6 @@ var usersTestAccount = &types.Account{
|
||||
},
|
||||
serviceUserID: {
|
||||
Id: serviceUserID,
|
||||
AccountID: existingAccountID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
AutoGroups: []string{"group_1"},
|
||||
@@ -63,7 +59,6 @@ var usersTestAccount = &types.Account{
|
||||
},
|
||||
nonDeletableServiceUserID: {
|
||||
Id: nonDeletableServiceUserID,
|
||||
AccountID: existingAccountID,
|
||||
Role: "admin",
|
||||
IsServiceUser: true,
|
||||
NonDeletable: true,
|
||||
@@ -156,7 +151,7 @@ func initUsersTestData() *handler {
|
||||
NonDeletable: false,
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles2.Owner),
|
||||
Permissions: mergeRolePermissions(roles.Owner),
|
||||
}, nil
|
||||
case "regular-user":
|
||||
return &users.UserInfoWithPermissions{
|
||||
@@ -170,7 +165,7 @@ func initUsersTestData() *handler {
|
||||
NonDeletable: false,
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles2.User),
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
}, nil
|
||||
|
||||
case "admin-user":
|
||||
@@ -186,7 +181,7 @@ func initUsersTestData() *handler {
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles2.Admin),
|
||||
Permissions: mergeRolePermissions(roles.Admin),
|
||||
}, nil
|
||||
case "restricted-user":
|
||||
return &users.UserInfoWithPermissions{
|
||||
@@ -201,7 +196,7 @@ func initUsersTestData() *handler {
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles2.User),
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
Restricted: true,
|
||||
}, nil
|
||||
}
|
||||
@@ -237,11 +232,7 @@ func TestGetUsers(t *testing.T) {
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.getAllUsers(recorder, req, userAuth)
|
||||
userHandler.getAllUsers(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -352,7 +343,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}", permissions.WrapHandler(userHandler.updateUser)).Methods("PUT")
|
||||
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -448,11 +439,7 @@ func TestCreateUser(t *testing.T) {
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.createUser(rr, req, userAuth)
|
||||
userHandler.createUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -503,11 +490,7 @@ func TestInviteUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.inviteUser(rr, req, userAuth)
|
||||
userHandler.inviteUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -566,11 +549,7 @@ func TestDeleteUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.deleteUser(rr, req, userAuth)
|
||||
userHandler.deleteUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -629,7 +608,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Owner)),
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -648,7 +627,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -667,7 +646,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Admin)),
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -687,7 +666,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
IsRestricted: true,
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -703,11 +682,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: tc.requestAuth.UserId,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.getCurrentUser(rr, req, userAuth)
|
||||
userHandler.getCurrentUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -727,8 +702,8 @@ func ptr[T any, PT *T](x T) PT {
|
||||
return &x
|
||||
}
|
||||
|
||||
func mergeRolePermissions(role roles2.RolePermissions) roles2.Permissions {
|
||||
permissions := roles2.Permissions{}
|
||||
func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
|
||||
permissions := roles.Permissions{}
|
||||
|
||||
for k := range modules.All {
|
||||
if rolePermissions, ok := role.Permissions[k]; ok {
|
||||
@@ -741,7 +716,7 @@ func mergeRolePermissions(role roles2.RolePermissions) roles2.Permissions {
|
||||
return permissions
|
||||
}
|
||||
|
||||
func stringifyPermissionsKeys(permissions roles2.Permissions) map[string]map[string]bool {
|
||||
func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool {
|
||||
modules := make(map[string]map[string]bool)
|
||||
for module, operations := range permissions {
|
||||
modules[string(module)] = make(map[string]bool)
|
||||
@@ -804,7 +779,7 @@ func TestApproveUserEndpoint(t *testing.T) {
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/approve", permissions.WrapHandler(handler.approveUser)).Methods("POST")
|
||||
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
|
||||
|
||||
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
||||
require.NoError(t, err)
|
||||
@@ -862,7 +837,7 @@ func TestRejectUserEndpoint(t *testing.T) {
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/reject", permissions.WrapHandler(handler.rejectUser)).Methods("DELETE")
|
||||
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
|
||||
|
||||
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
|
||||
require.NoError(t, err)
|
||||
@@ -953,7 +928,7 @@ func TestChangePasswordEndpoint(t *testing.T) {
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/password", permissions.WrapHandler(handler.changePassword)).Methods("PUT")
|
||||
router.HandleFunc("/users/{userId}/password", handler.changePassword).Methods("PUT")
|
||||
|
||||
reqPath := "/users/" + tc.targetUserID + "/password"
|
||||
req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody))
|
||||
@@ -992,7 +967,7 @@ func TestChangePasswordEndpoint_WrongMethod(t *testing.T) {
|
||||
req = nbcontext.SetUserAuthInRequest(req, userAuth)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.changePassword(rr, req, &userAuth)
|
||||
handler.changePassword(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func Test_Accounts_GetAll(t *testing.T) {
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, true},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
@@ -233,71 +233,6 @@ func Test_Accounts_Update(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Accounts_Update_CrossAccountAttack(t *testing.T) {
|
||||
t.Run("Other user attempts to update testAccount via URL", func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(&api.AccountRequest{
|
||||
Settings: api.AccountSettings{
|
||||
PeerLoginExpirationEnabled: false,
|
||||
PeerLoginExpiration: 86400,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account update must be rejected")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Accounts_Delete(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, false},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, false},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Delete account", func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Accounts_Delete_CrossAccountAttack(t *testing.T) {
|
||||
t.Run("Other user attempts to delete testAccount via URL", func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
|
||||
|
||||
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account delete must be rejected")
|
||||
})
|
||||
}
|
||||
|
||||
func stringPointer(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
@@ -1,445 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_Records_GetAll(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all records", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones/testZoneId/records", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.DNSRecord{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(got))
|
||||
assert.Equal(t, "sub.example.com", got[0].Name)
|
||||
assert.Equal(t, api.DNSRecordTypeA, got[0].Type)
|
||||
assert.Equal(t, "1.2.3.4", got[0].Content)
|
||||
assert.Equal(t, 300, got[0].Ttl)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Records_GetById(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
recordId string
|
||||
expectedStatus int
|
||||
expectRecord bool
|
||||
}{
|
||||
{
|
||||
name: "Get existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "testRecordId",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectRecord: true,
|
||||
},
|
||||
{
|
||||
name: "Get non-existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "nonExistingRecordId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectRecord: false,
|
||||
},
|
||||
{
|
||||
name: "Get record from non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
recordId: "testRecordId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectRecord: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
|
||||
|
||||
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
|
||||
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expectRecord {
|
||||
got := &api.DNSRecord{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.Equal(t, "testRecordId", got.Id)
|
||||
assert.Equal(t, "sub.example.com", got.Name)
|
||||
assert.Equal(t, api.DNSRecordTypeA, got.Type)
|
||||
assert.Equal(t, "1.2.3.4", got.Content)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Records_Create(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
requestBody *api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, record *api.DNSRecord)
|
||||
}{
|
||||
{
|
||||
name: "Create A record",
|
||||
zoneId: "testZoneId",
|
||||
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "new.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "5.6.7.8",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, record.Id)
|
||||
assert.Equal(t, "new.example.com", record.Name)
|
||||
assert.Equal(t, api.DNSRecordTypeA, record.Type)
|
||||
assert.Equal(t, "5.6.7.8", record.Content)
|
||||
assert.Equal(t, 600, record.Ttl)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create CNAME record",
|
||||
zoneId: "testZoneId",
|
||||
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "alias.example.com",
|
||||
Type: api.DNSRecordTypeCNAME,
|
||||
Content: "target.example.com",
|
||||
Ttl: 300,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, record.Id)
|
||||
assert.Equal(t, "alias.example.com", record.Name)
|
||||
assert.Equal(t, api.DNSRecordTypeCNAME, record.Type)
|
||||
assert.Equal(t, "target.example.com", record.Content)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create record with invalid content for A type",
|
||||
zoneId: "testZoneId",
|
||||
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "bad.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "not-an-ip",
|
||||
Ttl: 300,
|
||||
},
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
},
|
||||
{
|
||||
name: "Create record in non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "new.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "5.6.7.8",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
path := strings.Replace("/api/dns/zones/{zoneId}/records", "{zoneId}", tc.zoneId, 1)
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.DNSRecord{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the created record directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbRecord := testing_tools.VerifyRecordInDB(t, db, got.Id)
|
||||
assert.Equal(t, got.Name, dbRecord.Name)
|
||||
assert.Equal(t, got.Content, dbRecord.Content)
|
||||
assert.Equal(t, got.Ttl, dbRecord.TTL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Records_Update(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
recordId string
|
||||
requestBody *api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, record *api.DNSRecord)
|
||||
}{
|
||||
{
|
||||
name: "Update record content and TTL",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "testRecordId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
|
||||
Name: "sub.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "10.20.30.40",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "sub.example.com", record.Name)
|
||||
assert.Equal(t, "10.20.30.40", record.Content)
|
||||
assert.Equal(t, 600, record.Ttl)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update non-existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "nonExistingRecordId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
|
||||
Name: "sub.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "10.20.30.40",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Update record in non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
recordId: "testRecordId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
|
||||
Name: "sub.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "10.20.30.40",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
|
||||
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.DNSRecord{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the updated record directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbRecord := testing_tools.VerifyRecordInDB(t, db, tc.recordId)
|
||||
assert.Equal(t, "10.20.30.40", dbRecord.Content)
|
||||
assert.Equal(t, 600, dbRecord.TTL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Records_Delete(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
recordId string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Delete existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "testRecordId",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete non-existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "nonExistingRecordId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Delete record from non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
recordId: "testRecordId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
|
||||
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
|
||||
// Verify deletion in DB for successful deletes by privileged users
|
||||
if tc.expectedStatus == http.StatusOK && user.expectResponse {
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
testing_tools.VerifyRecordNotInDB(t, db, tc.recordId)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,416 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_Zones_GetAll(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all zones", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.Zone{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(got))
|
||||
assert.Equal(t, "Test Zone", got[0].Name)
|
||||
assert.Equal(t, "example.com", got[0].Domain)
|
||||
assert.Equal(t, true, got[0].Enabled)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Zones_GetById(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
expectedStatus int
|
||||
expectZone bool
|
||||
}{
|
||||
{
|
||||
name: "Get existing zone",
|
||||
zoneId: "testZoneId",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectZone: true,
|
||||
},
|
||||
{
|
||||
name: "Get non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectZone: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expectZone {
|
||||
got := &api.Zone{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.Equal(t, "testZoneId", got.Id)
|
||||
assert.Equal(t, "Test Zone", got.Name)
|
||||
assert.Equal(t, "example.com", got.Domain)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Zones_Create(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
enabled := true
|
||||
disabled := false
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
requestBody *api.PostApiDnsZonesJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, zone *api.Zone)
|
||||
}{
|
||||
{
|
||||
name: "Create zone with valid data",
|
||||
requestBody: &api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "New Zone",
|
||||
Domain: "newzone.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, zone *api.Zone) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, zone.Id)
|
||||
assert.Equal(t, "New Zone", zone.Name)
|
||||
assert.Equal(t, "newzone.com", zone.Domain)
|
||||
assert.Equal(t, true, zone.Enabled)
|
||||
assert.Equal(t, false, zone.EnableSearchDomain)
|
||||
assert.Equal(t, 1, len(zone.DistributionGroups))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create zone with search domain enabled",
|
||||
requestBody: &api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "Search Zone",
|
||||
Domain: "search.example.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, zone *api.Zone) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, zone.Id)
|
||||
assert.Equal(t, "Search Zone", zone.Name)
|
||||
assert.Equal(t, true, zone.EnableSearchDomain)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create disabled zone",
|
||||
requestBody: &api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "Disabled Zone",
|
||||
Domain: "disabled.example.com",
|
||||
Enabled: &disabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, zone *api.Zone) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, zone.Id)
|
||||
assert.Equal(t, false, zone.Enabled)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create zone with empty distribution groups",
|
||||
requestBody: &api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "No Groups Zone",
|
||||
Domain: "nogroups.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{},
|
||||
},
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/zones", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.Zone{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the created zone directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbZone := testing_tools.VerifyZoneInDB(t, db, got.Id)
|
||||
assert.Equal(t, got.Name, dbZone.Name)
|
||||
assert.Equal(t, got.Domain, dbZone.Domain)
|
||||
assert.Equal(t, got.Enabled, dbZone.Enabled)
|
||||
assert.Equal(t, got.EnableSearchDomain, dbZone.EnableSearchDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Zones_Update(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
enabled := true
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
requestBody *api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, zone *api.Zone)
|
||||
}{
|
||||
{
|
||||
name: "Update zone name and settings",
|
||||
zoneId: "testZoneId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
|
||||
Name: "Updated Zone",
|
||||
Domain: "example.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, zone *api.Zone) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "Updated Zone", zone.Name)
|
||||
assert.Equal(t, "example.com", zone.Domain)
|
||||
assert.Equal(t, true, zone.EnableSearchDomain)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
|
||||
Name: "Whatever",
|
||||
Domain: "whatever.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.Zone{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the updated zone directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbZone := testing_tools.VerifyZoneInDB(t, db, tc.zoneId)
|
||||
assert.Equal(t, "Updated Zone", dbZone.Name)
|
||||
assert.Equal(t, "example.com", dbZone.Domain)
|
||||
assert.Equal(t, true, dbZone.Enabled)
|
||||
assert.Equal(t, true, dbZone.EnableSearchDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Zones_Delete(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Delete existing zone",
|
||||
zoneId: "testZoneId",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
|
||||
// Verify deletion in DB for successful deletes by privileged users
|
||||
if tc.expectedStatus == http.StatusOK && user.expectResponse {
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
testing_tools.VerifyZoneNotInDB(t, db, tc.zoneId)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -78,68 +78,6 @@ func Test_Events_GetAll(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Events_GetAll_Audit(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all audit events", func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, false)
|
||||
|
||||
// First, perform a mutation to generate an event (create a group as admin)
|
||||
groupBody, err := json.Marshal(&api.GroupRequest{Name: "auditTestGroup"})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal group request: %v", err)
|
||||
}
|
||||
createReq := testing_tools.BuildRequest(t, groupBody, http.MethodPost, "/api/groups", testing_tools.TestAdminId)
|
||||
createRecorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(createRecorder, createReq)
|
||||
assert.Equal(t, http.StatusOK, createRecorder.Code, "Failed to create group to generate event")
|
||||
|
||||
// Now query audit events
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/audit", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.Event{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.GreaterOrEqual(t, len(got), 1, "Expected at least one event after creating a group")
|
||||
|
||||
// Verify the group creation event exists
|
||||
found := false
|
||||
for _, event := range got {
|
||||
if event.ActivityCode == "group.add" {
|
||||
found = true
|
||||
assert.Equal(t, testing_tools.TestAdminId, event.InitiatorId)
|
||||
assert.Equal(t, "Group created", event.Activity)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected to find a group.add event")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Events_GetAll_Empty(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true)
|
||||
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_Geolocations_GetAllCountries(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all countries", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/locations/countries", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.Country{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, len(got))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Geolocations_GetCitiesByCountry(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get cities by country", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "US", 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.City{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, len(got))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Geolocations_GetCitiesByCountry_InvalidCode(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "INVALID", 1), testing_tools.TestAdminId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, true)
|
||||
}
|
||||
@@ -26,7 +26,7 @@ func Test_Groups_GetAll(t *testing.T) {
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, true},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
@@ -71,7 +71,7 @@ func Test_Groups_GetById(t *testing.T) {
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, true},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
@@ -216,6 +216,7 @@ func Test_Groups_Create(t *testing.T) {
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify group exists in DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id)
|
||||
assert.Equal(t, tc.requestBody.Name, dbGroup.Name)
|
||||
|
||||
@@ -1,295 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_IdentityProviders_GetAll(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all identity providers", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/identity-providers", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.IdentityProvider{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
// The embedded IdP manager is not initialized in the test environment,
|
||||
// so GetIdentityProviders returns an empty list.
|
||||
assert.Equal(t, 0, len(got))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IdentityProviders_GetById(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
idpId string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Get existing identity provider",
|
||||
idpId: "testIdpId",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Get non-existing identity provider",
|
||||
idpId: "nonExistingIdpId",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IdentityProviders_Create(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
requestBody *api.PostApiIdentityProvidersJSONRequestBody
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Create identity provider with valid data",
|
||||
requestBody: &api.PostApiIdentityProvidersJSONRequestBody{
|
||||
Type: api.IdentityProviderTypeGoogle,
|
||||
Name: "New IDP",
|
||||
ClientId: "newClientId",
|
||||
ClientSecret: "newClientSecret",
|
||||
},
|
||||
// Validation passes but the embedded IdP manager is not initialized,
|
||||
// so the operation returns an internal server error.
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Create identity provider with invalid issuer",
|
||||
requestBody: &api.PostApiIdentityProvidersJSONRequestBody{
|
||||
Type: api.IdentityProviderTypeOidc,
|
||||
Name: "Invalid IDP",
|
||||
Issuer: "not-a-url",
|
||||
ClientId: "clientId",
|
||||
ClientSecret: "clientSecret",
|
||||
},
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/identity-providers", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IdentityProviders_Update(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
idpId string
|
||||
requestBody *api.PutApiIdentityProvidersIdpIdJSONRequestBody
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Update existing identity provider",
|
||||
idpId: "testIdpId",
|
||||
requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{
|
||||
Type: api.IdentityProviderTypeGoogle,
|
||||
Name: "Updated IDP",
|
||||
ClientId: "updatedClientId",
|
||||
ClientSecret: "updatedClientSecret",
|
||||
},
|
||||
// Validation passes but the embedded IdP manager is not initialized,
|
||||
// so the operation returns an internal server error.
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Update non-existing identity provider",
|
||||
idpId: "nonExistingIdpId",
|
||||
requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{
|
||||
Type: api.IdentityProviderTypeGoogle,
|
||||
Name: "Updated IDP",
|
||||
ClientId: "updatedClientId",
|
||||
ClientSecret: "updatedClientSecret",
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IdentityProviders_Delete(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
idpId string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Delete existing identity provider",
|
||||
idpId: "testIdpId",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Delete non-existing identity provider",
|
||||
idpId: "nonExistingIdpId",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user