Compare commits

..

35 Commits

Author SHA1 Message Date
Ashley Mensah
ca4a31c0be fix(ci): improve system prompt for partial fixes and stale issues
- Add rules for partial fixes: PRs that only address some items in a
  multi-item request or explicitly say "partially addresses" do not
  count as resolution
- Clarify that maintainer-offered alternatives without reporter
  confirmation are not fixes
- Add stale issue guidance: 12+ months inactive with unanswered
  maintainer response should be MANUAL_REVIEW not KEEP_OPEN
2026-04-30 16:26:20 +02:00
Ashley Mensah
5bf17023fb feat(ci): sort by oldest updated and bump MAX_ISSUES to 100 2026-04-30 11:53:40 +02:00
Ashley Mensah
42e1f19689 fix(ci): remove unwritable fields, add Status single-select support
- Remove Linked PRs and Repository field writes (built-in, auto-populated)
- Add setSingleSelectField for Status field
- Set items to "Needs Review" when added to project board
- Clean up workflow env vars to only include writable fields
2026-04-30 11:34:14 +02:00
Ashley Mensah
8195aa43d5 fix(ci): skip linked_pull_requests field (not writable via API) 2026-04-29 18:36:36 +02:00
Ashley Mensah
6c4ce8df87 fix(ci): use number mutation for Confidence project field 2026-04-29 18:29:12 +02:00
Ashley Mensah
ca481b0ffc feat(ci): use PROJECT_PAT for project board, populate board in dry-run
- Add PROJECT_PAT secret for project board GraphQL calls, falling
  back to GITHUB_TOKEN if not set
- Dry-run now populates the project board with decisions (confidence,
  reason, evidence fields) while still skipping labels/comments/closing
- Extract addToProjectWithFields helper to reduce duplication
2026-04-29 17:45:46 +02:00
Ashley Mensah
6d118d9c99 feat(ci): trust LLM decisions and feed it PR merge status
- Remove pre_score override from enforcePolicy — policy now only gates
  AUTO_CLOSE, otherwise trusts the model's decision
- Pass pre_score evidence (hard signals, contradictions) to LLM as
  context instead of using it as a decision override
- Fetch linked PR merge status (MERGED/OPEN/CLOSED) in fetch step
  and include in LLM prompt so it can distinguish merged fixes from
  open proposals
2026-04-29 15:49:09 +02:00
Ashley Mensah
222d498bb6 fix(ci): distinguish workarounds from actual fixes in system prompt 2026-04-28 17:48:43 +02:00
Ashley Mensah
52cd104f1e chore(ci): switch back to gpt-4o-mini for higher quota 2026-04-28 17:44:05 +02:00
Ashley Mensah
92f666f652 fix(ci): cap retry-after and handle quota exhaustion gracefully 2026-04-28 17:42:43 +02:00
Ashley Mensah
4fc0cb7ec4 fix(ci): pace API requests to avoid rate limit thrashing 2026-04-28 17:30:18 +02:00
Ashley Mensah
695614834e fix(ci): fix policy logic and add message truncation
- enforcePolicy: respect KEEP_OPEN when model is confident and
  pre_score is low. Only promote to MANUAL_REVIEW when model suggests
  resolution or pre_score has hard signals
- Truncate user messages to 24k chars (issue body capped at 4k) to
  stay within GitHub Models 8000 token input limit
2026-04-28 16:38:36 +02:00
Ashley Mensah
d75fa6ad45 feat(ci): switch to gpt-4o and add rate limit retry
- Upgrade model from gpt-4o-mini to gpt-4o for better classification
- Add retry loop with backoff on 429 responses (up to 5 retries)
- Respect Retry-After header from GitHub Models API
2026-04-28 16:23:31 +02:00
Ashley Mensah
7ce7f322eb increase max number of dry run issues to 100 2026-04-28 16:13:38 +02:00
Ashley Mensah
22bcf70b6e fix(ci): add additionalProperties to nested schema objects 2026-04-28 16:11:57 +02:00
Ashley Mensah
fe8aa21245 feat(ci): wire up gpt-4o-mini for issue classification
- Replace stub callGitHubModel() with real GitHub Models API call
  using gpt-4o-mini with structured JSON output
- Build detailed user messages from issue body, comments, and timeline
- Add per-issue decision logging to classify step
- Upload candidates.json and decisions.json as workflow artifacts
2026-04-28 16:11:01 +02:00
Ashley Mensah
29f211e51c fix(ci): guard all decisions behind dry-run check
Move the dry-run check to the top of the loop so it applies to all
decision types, not just AUTO_CLOSE. In dry-run mode the workflow now
only logs what it would do without touching any issues.
2026-04-28 16:08:16 +02:00
Ashley Mensah
6df3580bd3 fix(ci): handle project API permission errors gracefully
GITHUB_TOKEN cannot access org-level Projects V2. Make addToProject
return null on failure instead of crashing, and skip setTextField
calls when project access is unavailable. A PAT with project scope
is needed for full project board integration.
2026-04-28 15:55:12 +02:00
Ashley Mensah
9ff735dd52 fix(ci): fix workflow by adding working-directory, fetch script, and removing npm ci
- Set defaults.run.working-directory to .github/issue-resolution so
  scripts resolve from the correct path
- Remove npm ci step (no npm dependencies needed)
- Add fetch-candidates.mjs to gather open issues with comments and
  timeline events via GitHub REST API
- Add minimal package.json with type: module
2026-04-28 15:53:17 +02:00
Ashley Mensah
7c43973bc9 fix typo in workflow 2026-04-28 15:48:35 +02:00
Ashley Mensah
01e53d07b9 fix typo in workflow 2026-04-28 15:46:57 +02:00
Ashley Mensah
797dce1631 dummy commit to test workflow 2026-04-28 15:43:37 +02:00
Ashley Mensah
2877fcbbf6 add push trigger to workflow 2026-04-28 15:42:50 +02:00
Ashley Mensah
5d8201fcd0 Merge branch 'main' into github-issue-resolver 2026-04-28 15:36:54 +02:00
Ashley Mensah
09595bd0c2 enable dry run, add project field id values 2026-04-28 15:36:21 +02:00
Zoltan Papp
8fc4265995 [relay] evict foreign client cache on disconnect (#6015)
* [relay] evict foreign client cache on disconnect

When a foreign relay's TCP connection drops, the manager's
onServerDisconnected handler only triggered reconnect logic for the
home server; the disconnected foreign entry stayed in the relayClients
cache. Subsequent OpenConn calls reused the closed client until the
60-second cleanup tick evicted it, breaking peer connectivity through
that relay for up to a minute.

Evict the foreign entry from the cache on disconnect so the next
OpenConn dials a fresh client.

Also:
- Make the reconnect backoff cap configurable via WithMaxBackoffInterval
  ManagerOption; the previous hard-coded 60s constant forced
  TestAutoReconnect to sleep ~61s. Test now polls Ready() and finishes
  in ~2s.
- Add NB_HOME_RELAY_SERVERS env var that overrides the relay URL list
  received from management, so a peer can be pinned to a specific home
  relay (used by the netbird-conn-lab Edge 4 reproducer).

* [client] treat empty NB_HOME_RELAY_SERVERS as unset

Returning (urls=[], ok=true) when the env var contained only separators or
whitespace caused callers to wipe the mgmt-provided relay list, leaving the
peer with no relays. Treat a parsed-empty result the same as an unset env.
2026-04-28 15:04:48 +02:00
Zoltan Papp
9c50819f20 Don't mark management disconnected on transient job stream errors (#6005)
The JOB stream is a separate channel from the SYNC stream. Server-side
EOF or other transient errors on the JOB stream do not indicate that
the management connection is unhealthy — the SYNC stream remains the
authoritative state signal.

Previously, a JOB stream EOF would call notifyDisconnected and the
client would emit OnConnecting to the UI. The backoff retry would
reconnect the JOB stream, but handleJobStream never calls notifyConnected
on success, so the UI was stuck on "Connecting" until the next SYNC
event or health check.

Keep notifyDisconnected for codes.PermissionDenied since IsLoginRequired
relies on managementError to detect expired auth.
2026-04-28 15:04:41 +02:00
Bethuel Mmbaga
6f0eff3ba0 [management] Handle single-string JWT group claim from IdPs (#6014) 2026-04-28 14:48:28 +03:00
Bethuel Mmbaga
f8745723fc [management] Add Microsoft AD FS support for embedded Dex identity providers (#6008) 2026-04-28 12:42:19 +03:00
Vlad
154b81645a [management] removed legacy network map code (#5565) 2026-04-27 16:02:54 +02:00
Maycon Santos
34167c8a16 [misc] Update release pipeline version (#5995) 2026-04-27 10:55:38 +02:00
Maycon Santos
d6f08e4840 [misc] Update sign pipeline version (#5981) 2026-04-24 13:13:27 +02:00
Zoltan Papp
f732b01a05 [management] unify peer-update test timeout via constant (#5952)
peerShouldReceiveUpdate waited 500ms for the expected update message,
and every outer wrapper across the management/server test suite paired
it with a 1s goroutine-drain timeout. Both were too tight for slower
CI runners (MySQL, FreeBSD, loaded sqlite), producing intermittent
"Timed out waiting for update message" failures in tests like
TestDNSAccountPeersUpdate, TestPeerAccountPeersUpdate, and
TestNameServerAccountPeersUpdate.

Introduce peerUpdateTimeout (5s) next to the helper and use it both in
the helper and in every outer wrapper so the two timeouts stay in sync.
Only runs down on failure; passing tests return as soon as the channel
delivers, so there is no slowdown on green runs.
2026-04-23 21:19:21 +02:00
Ashley Mensah
6b8e40f78d initial commit - workflow yaml, prompts and schemas 2026-04-23 18:48:38 +02:00
alsruf36
c07c726ea7 [proxy] Set session cookie path to root (#5915) 2026-04-23 18:20:54 +02:00
180 changed files with 5474 additions and 9677 deletions

5
.github/issue-resolution/package.json vendored Normal file
View File

@@ -0,0 +1,5 @@
{
"name": "issue-resolution",
"private": true,
"type": "module"
}

View File

@@ -0,0 +1,41 @@
You are a GitHub issue resolution classifier.
Your job is to decide whether an open GitHub issue is:
- AUTO_CLOSE
- MANUAL_REVIEW
- KEEP_OPEN
Rules:
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
- a merged linked PR that clearly resolves the issue, or
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
5. Do not invent evidence.
6. Output valid JSON only.
Maintainer-authoritative roles:
- MEMBER
- OWNER
- COLLABORATOR
Partial fixes and multi-item issues:
- A merged PR that only addresses SOME items in a multi-item request does NOT resolve the issue. If the issue lists 5 feature requests and a PR fixes 1, the issue is still open.
- If a PR description or comment says "partially addresses", "partial fix", or similar, the issue is NOT resolved. Classify as KEEP_OPEN.
- If a merged PR addresses the core ask but a later comment objects or reports a regression, classify as MANUAL_REVIEW (not resolved).
Workarounds vs. actual fixes:
- A WORKAROUND is when a user changes their own setup to avoid the problem (editing configs, using a different setting, manual SQL fixes, switching tools, scripts). Workarounds do NOT count as resolution — the underlying issue is still present in the product.
- An ACTUAL FIX is when a user reports the problem went away after upgrading to a specific version (e.g., "fixed after updating to v0.65.1") or after a specific PR was merged. This suggests the fix was shipped in the product itself.
- A maintainer pointing to an existing alternative feature is NOT the same as fixing the issue. If the reporter never confirmed the alternative works for them, classify as KEEP_OPEN.
- If only workarounds exist and no maintainer has confirmed a fix, classify as KEEP_OPEN.
- If a user reports an actual fix via a version upgrade but no maintainer confirmed it, classify as MANUAL_REVIEW (not AUTO_CLOSE).
Stale issues:
- An issue with no activity for over 12 months, where a maintainer offered an alternative or asked for more info and the reporter never responded, is a candidate for MANUAL_REVIEW — not necessarily KEEP_OPEN.
Important:
- Later comments outweigh earlier ones.
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.

View File

@@ -0,0 +1,80 @@
{
"type": "object",
"additionalProperties": false,
"required": [
"decision",
"reason_code",
"confidence",
"hard_signals",
"contradictions",
"summary",
"close_comment",
"manual_review_note"
],
"properties": {
"decision": {
"type": "string",
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
},
"reason_code": {
"type": "string",
"enum": [
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed",
"likely_fixed_but_unconfirmed",
"still_open",
"unclear"
]
},
"confidence": {
"type": "number",
"minimum": 0,
"maximum": 1
},
"hard_signals": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": false,
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"merged_pr",
"maintainer_comment",
"duplicate_reference",
"superseded_reference"
]
},
"url": { "type": "string" }
}
}
},
"contradictions": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": false,
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"reporter_still_broken",
"later_unresolved_comment",
"ambiguous_pr_link",
"other"
]
},
"url": { "type": "string" }
}
}
},
"summary": { "type": "string" },
"close_comment": { "type": "string" },
"manual_review_note": { "type": "string" }
}
}

View File

@@ -0,0 +1,213 @@
import fs from "node:fs/promises";
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
const ghHeaders = {
Authorization: `Bearer ${process.env.GH_TOKEN}`,
Accept: "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
};
// Use PROJECT_PAT for project board operations, fall back to GH_TOKEN
const projectHeaders = {
Authorization: `Bearer ${process.env.PROJECT_PAT || process.env.GH_TOKEN}`,
Accept: "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
};
async function rest(url, method = "GET", body) {
const res = await fetch(url, {
method,
headers: ghHeaders,
body: body ? JSON.stringify(body) : undefined
});
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
return res.status === 204 ? null : res.json();
}
async function graphql(query, variables) {
const res = await fetch("https://api.github.com/graphql", {
method: "POST",
headers: projectHeaders,
body: JSON.stringify({ query, variables })
});
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
const json = await res.json();
if (json.errors) throw new Error(JSON.stringify(json.errors));
return json.data;
}
async function addLabel(owner, repo, issueNumber, labels) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
"POST",
{ labels }
);
}
async function addComment(owner, repo, issueNumber, body) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
"POST",
{ body }
);
}
async function closeIssue(owner, repo, issueNumber) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
"PATCH",
{ state: "closed", state_reason: "completed" }
);
}
async function getIssueNodeId(owner, repo, issueNumber) {
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
return issue.node_id;
}
async function addToProject(issueNodeId) {
const mutation = `
mutation($projectId: ID!, $contentId: ID!) {
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
item { id }
}
}
`;
try {
const data = await graphql(mutation, {
projectId: process.env.PROJECT_ID,
contentId: issueNodeId
});
return data.addProjectV2ItemById.item.id;
} catch (err) {
console.warn(`[WARN] Could not add to project (needs PAT with project scope): ${err.message}`);
return null;
}
}
async function setTextField(itemId, fieldId, value) {
const mutation = `
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
updateProjectV2ItemFieldValue(input: {
projectId: $projectId,
itemId: $itemId,
fieldId: $fieldId,
value: { text: $value }
}) {
projectV2Item { id }
}
}
`;
return graphql(mutation, {
projectId: process.env.PROJECT_ID,
itemId,
fieldId,
value
});
}
async function setNumberField(itemId, fieldId, value) {
const mutation = `
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: Float!) {
updateProjectV2ItemFieldValue(input: {
projectId: $projectId,
itemId: $itemId,
fieldId: $fieldId,
value: { number: $value }
}) {
projectV2Item { id }
}
}
`;
return graphql(mutation, {
projectId: process.env.PROJECT_ID,
itemId,
fieldId,
value
});
}
async function setSingleSelectField(itemId, fieldId, optionId) {
const mutation = `
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $optionId: String!) {
updateProjectV2ItemFieldValue(input: {
projectId: $projectId,
itemId: $itemId,
fieldId: $fieldId,
value: { singleSelectOptionId: $optionId }
}) {
projectV2Item { id }
}
}
`;
return graphql(mutation, {
projectId: process.env.PROJECT_ID,
itemId,
fieldId,
optionId
});
}
async function addToProjectWithFields(owner, repo, d) {
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
const itemId = await addToProject(issueNodeId);
if (itemId) {
if (process.env.PROJECT_STATUS_FIELD_ID && process.env.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID) {
await setSingleSelectField(itemId, process.env.PROJECT_STATUS_FIELD_ID, process.env.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID);
}
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
await setNumberField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, d.model.confidence);
}
if (process.env.PROJECT_REASON_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
}
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
}
console.log(` → Added to project board (Status: Needs Review)`);
}
}
for (const d of decisions) {
const [owner, repo] = d.repository.split("/");
if (d.final_decision === "KEEP_OPEN") {
console.log(`#${d.issue_number} → KEEP_OPEN (confidence: ${d.model.confidence}, reason: ${d.model.reason_code})`);
continue;
}
if (dryRun) {
console.log(`[DRY RUN] #${d.issue_number}${d.final_decision} (confidence: ${d.model.confidence}, reason: ${d.model.reason_code})`);
// In dry-run: populate project board but don't touch issues
if (d.final_decision === "MANUAL_REVIEW" || d.final_decision === "AUTO_CLOSE") {
await addToProjectWithFields(owner, repo, d);
}
continue;
}
if (d.final_decision === "AUTO_CLOSE") {
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
await addComment(owner, repo, d.issue_number, d.model.close_comment);
await closeIssue(owner, repo, d.issue_number);
await addToProjectWithFields(owner, repo, d);
}
if (d.final_decision === "MANUAL_REVIEW") {
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
await addToProjectWithFields(owner, repo, d);
await addComment(
owner,
repo,
d.issue_number,
d.model.manual_review_note ||
"This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue."
);
}
}

View File

@@ -0,0 +1,259 @@
import fs from "node:fs/promises";
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
const systemPrompt = await fs.readFile("prompts/issue-resolution-system.txt", "utf8");
const outputSchema = JSON.parse(await fs.readFile("schemas/issue-resolution-output.json", "utf8"));
function isMaintainerRole(role) {
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
}
function preScore(candidate) {
let score = 0;
const hardSignals = [];
const contradictions = [];
for (const t of candidate.timeline) {
const sourceIssue = t.source?.issue;
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
hardSignals.push({
type: "merged_pr",
url: sourceIssue.html_url
});
score += 40; // provisional until PR merged state is verified
}
if (["referenced", "connected"].includes(t.event)) {
score += 10;
}
}
for (const c of candidate.comments) {
const body = c.body.toLowerCase();
if (
isMaintainerRole(c.author_association) &&
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
) {
score += 25;
hardSignals.push({
type: "maintainer_comment",
url: c.html_url
});
}
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
score -= 50;
contradictions.push({
type: "later_unresolved_comment",
url: c.html_url
});
}
}
return { score, hardSignals, contradictions };
}
// GitHub Models gpt-4o has an 8000 token input limit.
// Reserve ~2000 tokens for system prompt + response overhead.
// 1 token ~= 4 chars, so cap user message at ~24000 chars.
const MAX_USER_MESSAGE_CHARS = 24000;
function truncate(text, maxChars) {
if (text.length <= maxChars) return text;
return text.slice(0, maxChars) + "\n\n[... truncated due to length]";
}
function buildUserMessage(candidate, pre) {
const { issue, comments, timeline } = candidate;
const commentBlock = comments
.map((c) => `[${c.author_association}] ${c.user} (${c.created_at}):\n${c.body}`)
.join("\n---\n");
const timelineBlock = timeline
.filter((t) => ["cross-referenced", "referenced", "connected", "closed", "reopened"].includes(t.event))
.map((t) => {
let line = `${t.event} (${t.created_at})`;
if (t.source?.issue?.html_url) line += `${t.source.issue.html_url}`;
if (t.source?.issue?.pull_request?.html_url) line += ` (PR: ${t.source.issue.pull_request.html_url})`;
return line;
})
.join("\n");
const sections = [
`## Issue #${issue.number}: ${issue.title}`,
`URL: ${issue.html_url}`,
`Created: ${issue.created_at} | Updated: ${issue.updated_at}`,
`Labels: ${issue.labels.join(", ") || "none"}`,
"",
"### Body",
truncate(issue.body || "(empty)", 4000),
"",
"### Comments",
commentBlock || "(none)",
"",
"### Timeline events",
timelineBlock || "(none)",
];
if (candidate.linked_prs?.length) {
sections.push("");
sections.push("### Linked PRs (verified state)");
for (const pr of candidate.linked_prs) {
const status = pr.merged ? `MERGED (${pr.merged_at})` : pr.state.toUpperCase();
sections.push(`- PR #${pr.number}: ${pr.title}${status}${pr.url}`);
}
}
if (pre.hardSignals.length || pre.contradictions.length) {
sections.push("");
sections.push("### Automated evidence scan");
for (const s of pre.hardSignals) {
sections.push(`- SIGNAL: ${s.type}${s.url}`);
}
for (const c of pre.contradictions) {
sections.push(`- CONTRADICTION: ${c.type}${c.url}`);
}
}
return truncate(sections.join("\n"), MAX_USER_MESSAGE_CHARS);
}
const MODEL = "gpt-4o-mini";
const MAX_RETRIES = 5;
function sleep(ms) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
async function callGitHubModel(candidate, pre) {
const body = JSON.stringify({
model: MODEL,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: buildUserMessage(candidate, pre) },
],
response_format: {
type: "json_schema",
json_schema: {
name: "issue_resolution",
strict: true,
schema: outputSchema,
},
},
temperature: 0.1,
});
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
const res = await fetch("https://models.inference.ai.azure.com/chat/completions", {
method: "POST",
headers: {
Authorization: `Bearer ${process.env.GH_TOKEN}`,
"Content-Type": "application/json",
},
body,
});
if (res.status === 429) {
const retryAfter = Number(res.headers.get("retry-after")) || 30;
if (retryAfter > 120) {
console.warn(` [QUOTA EXHAUSTED] API wants ${retryAfter}s wait — skipping remaining issues.`);
return null;
}
console.warn(` [RATE LIMITED] Waiting ${retryAfter}s (attempt ${attempt + 1}/${MAX_RETRIES})...`);
await sleep(retryAfter * 1000);
continue;
}
if (!res.ok) {
const text = await res.text();
throw new Error(`GitHub Models ${res.status}: ${text}`);
}
const data = await res.json();
return JSON.parse(data.choices[0].message.content);
}
throw new Error(`GitHub Models: exceeded ${MAX_RETRIES} retries due to rate limiting`);
}
function enforcePolicy(modelOut, pre) {
const approvedReasons = new Set([
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed"
]);
const hasHardSignal =
(modelOut.hard_signals || []).some(s =>
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
) || pre.hardSignals.length > 0;
const hasContradiction =
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
// Only auto-close with very strict criteria
if (
modelOut.decision === "AUTO_CLOSE" &&
modelOut.confidence >= 0.97 &&
approvedReasons.has(modelOut.reason_code) &&
hasHardSignal &&
!hasContradiction
) {
return "AUTO_CLOSE";
}
// Downgrade AUTO_CLOSE that didn't pass the gate
if (modelOut.decision === "AUTO_CLOSE") {
return "MANUAL_REVIEW";
}
// Otherwise trust the model
return modelOut.decision;
}
console.log(`Classifying ${candidates.length} candidates with ${MODEL}...\n`);
// 15 req/min limit → 1 request every 4s. Use 4.5s for safety margin.
const PACE_MS = 4500;
let lastRequestTime = 0;
async function paced(fn) {
const elapsed = Date.now() - lastRequestTime;
if (elapsed < PACE_MS) await sleep(PACE_MS - elapsed);
lastRequestTime = Date.now();
return fn();
}
const decisions = [];
for (const candidate of candidates) {
const pre = preScore(candidate);
const modelOut = await paced(() => callGitHubModel(candidate, pre));
if (modelOut === null) {
console.warn(`\nQuota exhausted after ${decisions.length} issues. Writing partial results.`);
break;
}
const finalDecision = enforcePolicy(modelOut, pre);
decisions.push({
repository: candidate.repository,
issue_number: candidate.issue.number,
issue_url: candidate.issue.html_url,
title: candidate.issue.title,
pre_score: pre.score,
final_decision: finalDecision,
model: modelOut
});
console.log(
`#${candidate.issue.number} | pre_score: ${pre.score} | model: ${modelOut.decision} @ ${modelOut.confidence} | final: ${finalDecision} | ${modelOut.reason_code}`
);
}
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));
console.log(`\nWrote ${decisions.length} decisions to decisions.json`);

View File

@@ -0,0 +1,123 @@
import fs from "node:fs/promises";
const token = process.env.GH_TOKEN;
const repo = process.env.REPO; // "owner/repo"
const maxIssues = Number(process.env.MAX_ISSUES) || 100;
const headers = {
Authorization: `Bearer ${token}`,
Accept: "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
};
async function rest(url) {
const res = await fetch(url, { headers });
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
return res.json();
}
async function restSafe(url) {
const res = await fetch(url, { headers });
if (!res.ok) return null;
return res.json();
}
async function paginate(url, max) {
const items = [];
let page = 1;
while (items.length < max) {
const perPage = Math.min(100, max - items.length);
const sep = url.includes("?") ? "&" : "?";
const batch = await rest(`${url}${sep}per_page=${perPage}&page=${page}`);
if (!batch.length) break;
items.push(...batch);
page++;
}
return items.slice(0, max);
}
console.log(`Fetching up to ${maxIssues} open issues from ${repo}...`);
const issues = await paginate(
`https://api.github.com/repos/${repo}/issues?state=open&sort=updated&direction=asc`,
maxIssues
);
// Filter out pull requests (GitHub API returns PRs as issues too)
const realIssues = issues.filter((i) => !i.pull_request);
console.log(`Found ${realIssues.length} open issues (excluded PRs).`);
const candidates = [];
for (const issue of realIssues) {
const [comments, timeline] = await Promise.all([
rest(`https://api.github.com/repos/${repo}/issues/${issue.number}/comments?per_page=100`),
rest(`https://api.github.com/repos/${repo}/issues/${issue.number}/timeline?per_page=100`),
]);
candidates.push({
repository: repo,
issue: {
number: issue.number,
html_url: issue.html_url,
title: issue.title,
body: issue.body,
created_at: issue.created_at,
updated_at: issue.updated_at,
labels: issue.labels.map((l) => l.name),
},
comments: comments.map((c) => ({
body: c.body,
author_association: c.author_association,
html_url: c.html_url,
created_at: c.created_at,
user: c.user?.login,
})),
timeline: timeline.map((t) => ({
event: t.event,
created_at: t.created_at,
source: t.source
? {
issue: {
html_url: t.source.issue?.html_url,
pull_request: t.source.issue?.pull_request
? { html_url: t.source.issue.pull_request.html_url }
: undefined,
},
}
: undefined,
})),
linked_prs: [],
});
// Fetch merge status for cross-referenced PRs
const prUrls = new Set();
for (const t of timeline) {
const prHtml = t.source?.issue?.pull_request?.html_url;
if (t.event === "cross-referenced" && prHtml) {
prUrls.add(prHtml);
}
}
const candidate = candidates[candidates.length - 1];
for (const prHtml of prUrls) {
// Extract owner/repo and PR number from URL like https://github.com/owner/repo/pull/123
const match = prHtml.match(/github\.com\/([^/]+\/[^/]+)\/pull\/(\d+)/);
if (!match) continue;
const [, prRepo, prNum] = match;
const pr = await restSafe(`https://api.github.com/repos/${prRepo}/pulls/${prNum}`);
if (!pr) continue;
candidate.linked_prs.push({
number: pr.number,
title: pr.title,
url: prHtml,
state: pr.state,
merged: pr.merged || false,
merged_at: pr.merged_at,
});
}
console.log(` #${issue.number}${comments.length} comments, ${timeline.length} timeline events, ${candidate.linked_prs.length} linked PRs`);
}
await fs.writeFile("candidates.json", JSON.stringify(candidates, null, 2));
console.log(`Wrote ${candidates.length} candidates to candidates.json`);

View File

@@ -0,0 +1,63 @@
name: issue-resolution-triage
on:
push:
branches: [github-issue-resolver]
workflow_dispatch:
inputs:
dry_run:
description: "If true, do not close issues"
required: false
default: "true"
max_issues:
description: "How many issues to process"
required: false
default: "100"
schedule:
- cron: "17 2 * * *"
permissions:
contents: read
issues: write
pull-requests: read
models: read
# todo: remove hardcoded values
jobs:
triage:
runs-on: ubuntu-latest
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PROJECT_PAT: ${{ secrets.PROJECT_PAT }}
DRY_RUN: "true"
MAX_ISSUES: "100"
REPO: ${{ github.repository }}
PROJECT_ID: "PVT_kwDOBfz4Jc4BVeWR"
PROJECT_STATUS_FIELD_ID: "PVTSSF_lADOBfz4Jc4BVeWRzhQ56sU"
PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: "a55a2be9"
PROJECT_CONFIDENCE_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ57x4"
PROJECT_REASON_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ5-Lg"
PROJECT_EVIDENCE_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ5-Pw"
defaults:
run:
working-directory: .github/issue-resolution
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: "20"
- run: node scripts/fetch-candidates.mjs
- run: node scripts/classify-candidates.mjs
- run: node scripts/apply-decisions.mjs
- uses: actions/upload-artifact@v4
if: always()
with:
name: triage-results
path: |
.github/issue-resolution/candidates.json
.github/issue-resolution/decisions.json

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.2"
SIGN_PIPE_VER: "v0.1.4"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -13,8 +13,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -31,6 +29,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -98,7 +97,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store)
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager)

View File

@@ -333,6 +333,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp)
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
relayURLs = override
}
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)

View File

@@ -944,7 +944,12 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
return fmt.Errorf("update relay token: %w", err)
}
e.relayManager.UpdateServerURLs(update.Urls)
urls := update.Urls
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
urls = override
}
e.relayManager.UpdateServerURLs(urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.

View File

@@ -25,7 +25,6 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations"
@@ -58,6 +57,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -1632,7 +1632,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)

View File

@@ -7,7 +7,8 @@ import (
)
const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
)
func IsForceRelayed() bool {
@@ -16,3 +17,28 @@ func IsForceRelayed() bool {
}
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}
// OverrideRelayURLs returns the relay server URL list set in
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
// the override is active. When the env var is unset, the boolean is false
// and the caller should keep the list received from the management server.
// Intended for lab/debug scenarios where a peer must pin to a specific home
// relay regardless of what management offers.
func OverrideRelayURLs() ([]string, bool) {
raw := os.Getenv(EnvKeyNBHomeRelayServers)
if raw == "" {
return nil, false
}
parts := strings.Split(raw, ",")
urls := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
urls = append(urls, p)
}
}
if len(urls) == 0 {
return nil, false
}
return urls, true
}

View File

@@ -15,8 +15,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
@@ -40,6 +38,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -306,7 +305,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)

2
go.mod
View File

@@ -71,7 +71,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0

4
go.sum
View File

@@ -453,8 +453,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34 h1:g74mB64wnjCagzE1spKgPfTI/ont1SdSL3uX5bOecgM=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34/go.mod h1:lCOq5d1i19AQjEEW2d7aNK0Nn0KC0MKyfMz/PLwVBFg=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
// are stored with types that Dex can open.
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
switch connType {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
return "oidc", applyOIDCDefaults(connType, config)
default:
return connType, config
@@ -218,6 +218,8 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
case "okta", "pocketid":
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
}
return augmented

View File

@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
var err error
switch cfg.Type {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
dexType = "oidc"
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
case "google":
@@ -220,6 +220,8 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
}
return encodeConnectorConfig(oidcConfig)
}
@@ -283,7 +285,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
// inferOIDCProviderType infers the specific OIDC provider from connector ID
func inferOIDCProviderType(connectorID string) string {
connectorIDLower := strings.ToLower(connectorID)
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} {
if strings.Contains(connectorIDLower, provider) {
return provider
}

View File

@@ -7,7 +7,6 @@ import (
"os"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -16,11 +15,9 @@ import (
"golang.org/x/exp/maps"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -58,13 +55,6 @@ type Controller struct {
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
}
type bufferUpdate struct {
@@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
}
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
compactedNetworkMap := true
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
if err != nil && len(compactedEnv) > 0 {
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
}
if err == nil && !parsedCompactedNmap {
log.WithContext(ctx).Info("disabling compacted mode")
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{
repo: newRepository(store),
metrics: nMetrics,
@@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
}
}
@@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int {
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
globalStart := time.Now()
@@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
@@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -318,10 +258,6 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return fmt.Errorf("recalculate network map cache: %v", err)
}
return c.sendUpdateAccountPeers(ctx, accountID)
}
@@ -371,16 +307,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -451,17 +378,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, emptyMap, nil, 0, nil
}
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
@@ -493,20 +412,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, nil, 0, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -518,108 +427,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, networkMap, postureChecks, dnsFwdPort, nil
}
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (c *Controller) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
c.enrichAccountFromHolder(account)
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
}
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := c.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
c.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (c *Controller) experimentalNetworkMap(accountId string) bool {
_, ok := c.expNewNetworkMapAIDs[accountId]
return c.expNewNetworkMap || ok
}
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
a := c.holder.GetAccount(account.Id)
if a == nil {
c.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
c.holder.AddAccount(account)
}
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
a := c.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
@@ -756,16 +563,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
@@ -775,14 +573,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
@@ -817,19 +607,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
@@ -872,21 +649,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return nil, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {

View File

@@ -12,9 +12,6 @@ import (
)
const (
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
DnsForwarderPort = nbdns.ForwarderServerPort
OldForwarderPort = nbdns.ForwarderClientPort
DnsForwarderPortMinVersion = "v0.59.0"

View File

@@ -16,6 +16,9 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -35,15 +38,17 @@ type Manager interface {
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
accountManager account.Manager
networkMapController network_map.Controller
}
func NewManager(store store.Store) Manager {
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
return &managerImpl{
store: store,
store: store,
permissionsManager: permissionsManager,
}
}
@@ -60,10 +65,28 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
}
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}

View File

@@ -5,11 +5,8 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/shared/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -18,15 +15,21 @@ type handler struct {
manager accesslogs.Manager
}
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) {
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS")
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
}
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var filter accesslogs.AccessLogFilter
filter.ParseFromRequest(r)

View File

@@ -9,19 +9,25 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
geo geolocation.Geolocation
cleanupCancel context.CancelFunc
store store.Store
permissionsManager permissions.Manager
geo geolocation.Geolocation
cleanupCancel context.CancelFunc
}
func NewManager(store store.Store, geo geolocation.Geolocation) accesslogs.Manager {
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
return &managerImpl{
store: store,
geo: geo,
store: store,
permissionsManager: permissionsManager,
geo: geo,
}
}
@@ -57,6 +63,14 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, 0, status.NewPermissionValidationError(err)
}
if !ok {
return nil, 0, status.NewPermissionDeniedError()
}
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
}

View File

@@ -6,11 +6,8 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/shared/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -20,15 +17,15 @@ type handler struct {
manager Manager
}
func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) {
func RegisterEndpoints(router *mux.Router, manager Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Create, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS") // TODO: this should be a POST
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
}
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
@@ -59,7 +56,13 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
return resp
}
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -74,7 +77,13 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, ret)
}
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiReverseProxiesDomainsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -90,7 +99,13 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, use
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
}
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
@@ -105,7 +120,13 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, use
w.WriteHeader(http.StatusNoContent)
}
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)

View File

@@ -11,7 +11,11 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
type store interface {
@@ -33,22 +37,32 @@ type proxyManager interface {
}
type Manager struct {
store store
validator domain.Validator
proxyManager proxyManager
accountManager account.Manager
store store
validator domain.Validator
proxyManager proxyManager
permissionsManager permissions.Manager
accountManager account.Manager
}
func NewManager(store store, proxyMgr proxyManager, accountManager account.Manager) Manager {
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
return Manager{
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{Resolver: net.DefaultResolver},
accountManager: accountManager,
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{Resolver: net.DefaultResolver},
permissionsManager: permissionsManager,
accountManager: accountManager,
}
}
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list custom domains: %w", err)
@@ -104,6 +118,14 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
}
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
@@ -137,6 +159,14 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
}
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
if err != nil {
return fmt.Errorf("get domain from store: %w", err)
@@ -153,6 +183,21 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
}
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
return
}
if !ok {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
}
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,

View File

@@ -23,7 +23,7 @@ type Proxy struct {
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Capabilities Capabilities `gorm:"embedded"`
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -6,14 +6,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -32,19 +30,25 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
}
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager)
domainmanager.RegisterEndpoints(domainRouter, domainManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies/clusters", permissionsManager.WithPermission(modules.Services, operations.Read, h.getClusters)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -59,7 +63,13 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAut
util.WriteJSONObject(r.Context(), w, apiServices)
}
func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -67,13 +77,12 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth
}
service := new(rpservice.Service)
var err error
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err := service.Validate(); err != nil {
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -87,7 +96,13 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
}
func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -103,7 +118,13 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
}
func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -118,13 +139,12 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth
service := new(rpservice.Service)
service.ID = serviceID
var err error
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err := service.Validate(); err != nil {
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -138,7 +158,13 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
}
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -153,7 +179,13 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -85,17 +86,18 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
},
}
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
proxyController: mockCtrl,
capabilities: mockCaps,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
store: testStore,
accountManager: accountMgr,
permissionsManager: permissions.NewManager(testStore),
proxyController: mockCtrl,
capabilities: mockCaps,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
}
mgr.exposeReaper = &exposeReaper{manager: mgr}

View File

@@ -21,6 +21,9 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -79,22 +82,24 @@ type CapabilityProvider interface {
}
type Manager struct {
store store.Store
accountManager account.Manager
proxyController proxy.Controller
capabilities CapabilityProvider
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyController proxy.Controller
capabilities CapabilityProvider
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{
store: store,
accountManager: accountManager,
proxyController: proxyController,
capabilities: capabilities,
clusterDeriver: clusterDeriver,
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyController: proxyController,
capabilities: capabilities,
clusterDeriver: clusterDeriver,
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr
@@ -107,10 +112,26 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
@@ -164,6 +185,14 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
}
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
@@ -177,6 +206,14 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
}
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err
}
@@ -187,7 +224,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err := m.replaceHostByLookup(ctx, accountID, s)
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
@@ -454,6 +491,14 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
}
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := service.Auth.HashSecrets(); err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
@@ -740,8 +785,16 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes.
}
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var s *service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
@@ -772,8 +825,16 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
}
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var services []*service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
@@ -1058,7 +1119,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
}
groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
}

View File

@@ -23,6 +23,9 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
@@ -697,10 +700,12 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
require.NoError(t, err)
permsMgr := permissions.NewManager(testStore)
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
},
}
@@ -713,9 +718,10 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
require.NoError(t, err)
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
proxyController: proxyController,
store: testStore,
accountManager: accountMgr,
permissionsManager: permsMgr,
proxyController: proxyController,
clusterDeriver: &testClusterDeriver{
domains: []string{"test.netbird.io"},
},
@@ -1124,6 +1130,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockPerms := permissions.NewMockManager(ctrl)
mockAcct := account.NewMockManager(ctrl)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
@@ -1134,9 +1141,10 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err)
mgr := &Manager{
store: sqlStore,
accountManager: mockAcct,
proxyController: proxyController,
store: sqlStore,
permissionsManager: mockPerms,
accountManager: mockAcct,
proxyController: proxyController,
}
service := &rpservice.Service{
@@ -1159,6 +1167,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err)
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
mockPerms.EXPECT().
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
Return(true, nil)
mockAcct.EXPECT().
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
mockAcct.EXPECT().

View File

@@ -6,11 +6,8 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/shared/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -20,19 +17,25 @@ type handler struct {
manager zones.Manager
}
func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) {
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -47,7 +50,13 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *
util.WriteJSONObject(r.Context(), w, apiZones)
}
func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiDnsZonesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -57,7 +66,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *a
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
if err := zone.Validate(); err != nil {
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -71,7 +80,13 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
}
func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -87,7 +102,13 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
}
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -95,7 +116,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *a
}
var req api.PutApiDnsZonesZoneIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
@@ -104,7 +125,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *a
zone.FromAPIRequest(&req)
zone.ID = zoneID
if err := zone.Validate(); err != nil {
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -118,14 +139,20 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
}
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -7,34 +7,62 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
dnsDomain string
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
dnsDomain string
}
func NewManager(store store.Store, accountManager account.Manager, dnsDomain string) zones.Manager {
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
dnsDomain: dnsDomain,
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
dnsDomain: dnsDomain,
}
}
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
var err error
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
return nil, err
}
@@ -74,6 +102,14 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string,
}
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
if err != nil {
return nil, fmt.Errorf("failed to get zone: %w", err)
@@ -114,6 +150,14 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
}
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)

View File

@@ -13,6 +13,9 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
@@ -26,7 +29,7 @@ const (
testDNSDomain = "netbird.selfhosted"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *gomock.Controller, func()) {
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
@@ -46,17 +49,23 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccoun
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := NewManager(testStore, mockAccountManager, testDNSDomain).(*managerImpl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
dnsDomain: testDNSDomain,
}
return manager, testStore, mockAccountManager, ctrl, cleanup
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllZones(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t)
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -68,6 +77,10 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
err = testStore.CreateZone(ctx, zone2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err)
assert.Len(t, result, 2)
@@ -75,13 +88,43 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
assert.Equal(t, zone2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t)
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -89,6 +132,10 @@ func TestManagerImpl_GetZone(t *testing.T) {
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, zone.ID, result.ID)
@@ -96,13 +143,29 @@ func TestManagerImpl_GetZone(t *testing.T) {
assert.Equal(t, zone.Domain, result.Domain)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, _, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -114,6 +177,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
@@ -132,8 +199,31 @@ func TestManagerImpl_CreateZone(t *testing.T) {
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("invalid group", func(t *testing.T) {
manager, _, _, ctrl, cleanup := setupTest(t)
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -143,13 +233,17 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{"invalid-group"},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("duplicate domain", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t)
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -165,6 +259,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -175,7 +273,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
})
t.Run("peer DNS domain conflict", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t)
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -193,6 +291,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -203,7 +305,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
})
t.Run("default DNS domain conflict", func(t *testing.T) {
manager, _, _, ctrl, cleanup := setupTest(t)
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -215,6 +317,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -229,7 +335,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -246,6 +352,10 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
@@ -265,7 +375,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
})
t.Run("domain change not allowed", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t)
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -282,6 +392,10 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -291,8 +405,31 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: testZoneID,
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, ctrl, cleanup := setupTest(t)
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -302,6 +439,10 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -312,7 +453,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
ctx := context.Background()
t.Run("success with records", func(t *testing.T) {
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -328,6 +469,10 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCallCount++
@@ -348,7 +493,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
})
t.Run("success without records", func(t *testing.T) {
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -356,6 +501,10 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
@@ -373,11 +522,31 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
require.Error(t, err)
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, ctrl, cleanup := setupTest(t)
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err)
})

View File

@@ -6,11 +6,8 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/shared/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -20,19 +17,25 @@ type handler struct {
manager records.Manager
}
func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) {
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -53,7 +56,13 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, apiRecords)
}
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -69,7 +78,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth
record := new(records.Record)
record.FromAPIRequest(&req)
if err := record.Validate(); err != nil {
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -83,7 +92,13 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
}
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -105,7 +120,13 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *au
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
}
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -119,7 +140,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth
}
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
@@ -128,7 +149,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth
record.FromAPIRequest(&req)
record.ID = recordID
if err := record.Validate(); err != nil {
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -142,7 +163,13 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
}
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -155,7 +182,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth
return
}
if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -9,36 +9,64 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
}
func NewManager(store store.Store, accountManager account.Manager) records.Manager {
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
}
}
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
}
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
@@ -73,11 +101,18 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
}
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
var record *records.Record
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
@@ -125,11 +160,18 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
}
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var record *records.Record
var zone *zones.Zone
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)

View File

@@ -12,8 +12,12 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -23,7 +27,7 @@ const (
testGroupID = "test-group-id"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *gomock.Controller, func()) {
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
@@ -47,17 +51,22 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_serv
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := NewManager(testStore, mockAccountManager).(*managerImpl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
}
return manager, testStore, zone, mockAccountManager, ctrl, cleanup
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllRecords(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -69,6 +78,10 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Len(t, result, 2)
@@ -76,13 +89,43 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
assert.Equal(t, record2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -90,6 +133,10 @@ func TestManagerImpl_GetRecord(t *testing.T) {
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.Equal(t, record.ID, result.ID)
@@ -99,13 +146,29 @@ func TestManagerImpl_GetRecord(t *testing.T) {
assert.Equal(t, record.TTL, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success - A record", func(t *testing.T) {
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -116,6 +179,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
@@ -135,7 +202,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
})
t.Run("success - AAAA record", func(t *testing.T) {
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -146,6 +213,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
@@ -160,7 +231,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
})
t.Run("success - CNAME record", func(t *testing.T) {
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -171,6 +242,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
@@ -184,8 +259,32 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record name not in zone", func(t *testing.T) {
manager, _, zone, _, ctrl, cleanup := setupTest(t)
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -196,6 +295,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
@@ -203,7 +306,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
})
t.Run("duplicate record", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -218,6 +321,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
@@ -225,7 +332,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
})
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -240,6 +347,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
@@ -251,7 +362,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -267,6 +378,10 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, // Changed TTL
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
@@ -285,7 +400,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
})
t.Run("update only TTL - no validation", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -301,6 +416,10 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, // Only TTL changed
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored
}
@@ -311,8 +430,33 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
assert.Equal(t, 600, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: testRecordID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, ctrl, cleanup := setupTest(t)
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -324,13 +468,17 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("update creates duplicate", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -350,6 +498,10 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
@@ -361,7 +513,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -369,6 +521,10 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
@@ -386,11 +542,31 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
require.Error(t, err)
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, ctrl, cleanup := setupTest(t)
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err)
})

View File

@@ -233,7 +233,7 @@ func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.GeoLocationManager())
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
accessLogManager.StartPeriodicCleanup(
context.Background(),
s.Config.ReverseProxy.AccessLogRetentionDays,

View File

@@ -9,7 +9,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
@@ -28,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
)
@@ -82,13 +82,13 @@ func (s *BaseServer) SettingsManager() settings.Manager {
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
}
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, idpConfig)
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
})
}
func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager {
manager := peers.NewManager(s.Store())
manager := peers.NewManager(s.Store(), s.PermissionsManager())
s.AfterInit(func(s *BaseServer) {
manager.SetNetworkMapController(s.NetworkMapController())
manager.SetIntegratedPeerValidator(s.IntegratedValidator())
@@ -161,43 +161,43 @@ func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
func (s *BaseServer) GroupsManager() groups.Manager {
return Create(s, func() groups.Manager {
return groups.NewManager(s.Store(), s.AccountManager())
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
})
}
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
})
}
func (s *BaseServer) RoutesManager() routers.Manager {
return Create(s, func() routers.Manager {
return routers.NewManager(s.Store(), s.AccountManager())
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
})
}
func (s *BaseServer) NetworksManager() networks.Manager {
return Create(s, func() networks.Manager {
return networks.NewManager(s.Store(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
})
}
func (s *BaseServer) ZonesManager() zones.Manager {
return Create(s, func() zones.Manager {
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.DNSDomain())
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
})
}
func (s *BaseServer) RecordsManager() records.Manager {
return Create(s, func() records.Manager {
return recordsManager.NewManager(s.Store(), s.AccountManager())
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
})
}
func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
})
}
@@ -213,7 +213,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ProxyManager(), s.AccountManager())
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
return &m
})
}

View File

@@ -15,7 +15,6 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
@@ -40,6 +39,9 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -280,14 +282,22 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// User that performs the update has to belong to the account.
// Returns an updated Settings
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
var oldSettings *types.Settings
var updateAccountPeers bool
var groupChangesAffectPeers bool
var reloadReverseProxy bool
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var groupsUpdated bool
var err error
oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
@@ -715,6 +725,15 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return err
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
if err != nil {
return fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
if err != nil {
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
@@ -957,10 +976,6 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
return nil, err
}
if user.AccountID != accountID {
return nil, fmt.Errorf("user %s does not belong to account %s", userID, accountID)
}
key := user.IntegrationReference.CacheKey(accountID, userID)
ud, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
@@ -1272,16 +1287,41 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
// GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountMeta returns the account metadata associated with this account ID.
func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
}
// GetAccountOnboarding retrieves the onboarding information for a specific account.
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
log.Errorf("failed to get account onboarding for account %s: %v", accountID, err)
@@ -1298,6 +1338,15 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
}
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
@@ -1356,8 +1405,9 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil
}
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper
// User account association is already validated above by GetUserByUserID
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return "", "", err
}
if !user.IsServiceUser && userAuth.Invited {
err = am.redeemInvite(ctx, accountID, user.Id)
@@ -1799,6 +1849,13 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
}
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
}
@@ -2140,6 +2197,14 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
}
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
if err != nil {
return fmt.Errorf("update peer IP transaction: %w", err)

View File

@@ -60,7 +60,7 @@ type Manager interface {
GetUserByID(ctx context.Context, id string) (*types.User, error)
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -75,7 +75,7 @@ type Manager interface {
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error

View File

@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
}
// GetGroupByName mocks base method.
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
ret0, _ := ret[0].(*types.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
}
// GetIdentityProvider mocks base method.
@@ -946,18 +946,18 @@ func (mr *MockManagerMockRecorder) GetPeerNetwork(ctx, peerID interface{}) *gomo
}
// GetPeers mocks base method.
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*peer.Peer, error) {
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter, all)
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter)
ret0, _ := ret[0].([]*peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeers indicates an expected call of GetPeers.
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter, all interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter, all)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter)
}
// GetPolicy mocks base method.

View File

@@ -22,10 +22,8 @@ import (
"go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/shared/management/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
@@ -51,6 +49,7 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -409,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -1172,11 +1171,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
}
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
testAccountManager_NetworkUpdates_SaveGroup(t)
}
@@ -1232,11 +1226,6 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
@@ -1275,11 +1264,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_SavePolicy(t)
}
@@ -1333,11 +1317,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePeer(t)
}
@@ -1398,11 +1377,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
@@ -1634,75 +1608,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
}
func TestAccount_GetRoutesToSync(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
if err != nil {
t.Fatal(err)
}
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
},
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[route.ID]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-2": {
ID: "route-2",
Network: prefix2,
NetID: "network-2",
Description: "network-2",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-3": {
ID: "route-3",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
},
}
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
assert.Len(t, emptyRoutes, 0)
}
func TestAccount_Copy(t *testing.T) {
account := &types.Account{
Id: "account1",
@@ -1825,9 +1730,7 @@ func TestAccount_Copy(t *testing.T) {
AccountID: "account1",
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()
err := hasNilField(account)
if err != nil {
t.Fatal(err)
@@ -3148,7 +3051,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT().
@@ -3165,7 +3068,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, nil, err
@@ -3176,7 +3079,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, proxyManager, nil))
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil))
return manager, updateManager, nil
}
@@ -3254,6 +3157,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
return manager, updateManager, account, peer1, peer2, peer3
}
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
// wrappers wait for an expected update message. Sized for slow CI runners
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
// seconds. Only runs down on failure; passing tests return immediately
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
@@ -3272,7 +3182,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
if msg == nil {
t.Errorf("Received nil update message, expected valid message")
}
case <-time.After(500 * time.Millisecond):
case <-time.After(peerUpdateTimeout):
t.Error("Timed out waiting for update message")
}
}

View File

@@ -8,6 +8,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
@@ -20,6 +22,14 @@ const (
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
}
@@ -29,11 +39,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var eventsToStore []func()
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err
}

View File

@@ -14,11 +14,11 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -28,6 +28,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -78,6 +79,16 @@ func TestGetDNSSettings(t *testing.T) {
if len(dnsSettings.DisabledManagementGroups) != 1 {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
}
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
}
s, ok := status.FromError(err)
if !ok && s.Type() != status.PermissionDenied {
t.Errorf("returned error should be Permission Denied, got err: %s", err)
}
}
func TestSaveDNSSettings(t *testing.T) {
@@ -212,7 +223,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background()
@@ -223,7 +234,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
}
@@ -447,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -467,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -507,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -9,8 +9,11 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
func isEnabled() bool {
@@ -20,6 +23,14 @@ func isEnabled() bool {
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
if err != nil {
return nil, err

View File

@@ -12,6 +12,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
@@ -30,24 +32,13 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper
// This method is called from authenticated/authorized handlers, so we just validate
// that the user exists and is part of the account
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
if err != nil {
return err
}
if user == nil {
return status.NewUserNotFoundError(userID)
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsBlocked() {
return status.NewUserBlockedError()
if !allowed {
return status.NewPermissionDeniedError()
}
return nil
@@ -70,17 +61,27 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}
// CreateGroup object of the peers
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
@@ -124,11 +125,19 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
// UpdateGroup object of the peers
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
@@ -187,24 +196,33 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
return err
}
if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return err
}
@@ -225,7 +243,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
}
var err error
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
@@ -247,14 +264,21 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
@@ -287,7 +311,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
}
var err error
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
@@ -393,6 +416,14 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var allErrors error
var groupIDsToDelete []string
var deletedGroups []*types.Group

View File

@@ -26,6 +26,7 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -619,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -637,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -655,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -688,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -729,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -756,17 +757,18 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving a group linked to network router should update account peers and send peer update
t.Run("saving group linked to network router", func(t *testing.T) {
groupsManager := groups.NewManager(manager.Store, manager)
resourcesManager := resources.NewManager(manager.Store, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, manager)
networksManager := networks.NewManager(manager.Store, resourcesManager, routersManager, manager)
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{
ID: "network_test",
@@ -802,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -6,6 +6,9 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -22,21 +25,31 @@ type Manager interface {
}
type managerImpl struct {
store store.Store
accountManager account.Manager
store store.Store
permissionsManager permissions.Manager
accountManager account.Manager
}
type mockManager struct {
}
func NewManager(store store.Store, accountManager account.Manager) Manager {
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
store: store,
permissionsManager: permissionsManager,
accountManager: accountManager,
}
}
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
if err != nil {
return nil, err
}
if !ok {
return nil, err
}
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err)
@@ -60,6 +73,14 @@ func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID str
}
func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return err
}
if !ok {
return err
}
event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource)
if err != nil {
return fmt.Errorf("error adding resource to group: %w", err)

View File

@@ -10,7 +10,6 @@ import (
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/server/types"
@@ -32,8 +31,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth"
@@ -123,25 +124,25 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
accounts.AddEndpoints(accountManager, settingsManager, router, permissionsManager)
accounts.AddEndpoints(accountManager, settingsManager, router)
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
users.AddEndpoints(accountManager, router, permissionsManager)
users.AddInvitesEndpoints(accountManager, router, permissionsManager)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddPublicInvitesEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router, permissionsManager)
policies.AddEndpoints(accountManager, LocationManager, router, permissionsManager)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router, permissionsManager)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router)
groups.AddEndpoints(accountManager, router, permissionsManager)
routes.AddEndpoints(accountManager, router, permissionsManager)
dns.AddEndpoints(accountManager, router, permissionsManager)
events.AddEndpoints(accountManager, router, permissionsManager)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, permissionsManager, router)
zonesManager.RegisterEndpoints(router, zManager, permissionsManager)
recordsManager.RegisterEndpoints(router, rManager, permissionsManager)
idp.AddEndpoints(accountManager, router, permissionsManager)
groups.AddEndpoints(accountManager, router)
routes.AddEndpoints(accountManager, router)
dns.AddEndpoints(accountManager, router)
events.AddEndpoints(accountManager, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
zonesManager.RegisterEndpoints(router, zManager)
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router, permissionsManager)
instance.AddVersionEndpoint(instanceManager, router)
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
}

View File

@@ -12,13 +12,10 @@ import (
goversion "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -43,11 +40,11 @@ type handler struct {
settingsManager settings.Manager
}
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
accountsHandler := newHandler(accountManager, settingsManager)
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Update, accountsHandler.updateAccount)).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Delete, accountsHandler.deleteAccount)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", permissionsManager.WithPermission(modules.Accounts, operations.Read, accountsHandler.getAllAccounts)).Methods("GET", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
@@ -102,7 +99,7 @@ func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID st
}
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "", true)
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "")
if err != nil {
return status.Errorf(status.Internal, "get peer count: %v", err)
}
@@ -139,26 +136,34 @@ func calculateRequiredAddresses(peerCount int) int64 {
}
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
meta, err := h.accountManager.GetAccountMeta(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.settingsManager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
accountID, userID := userAuth.AccountId, userAuth.UserId
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), userAuth.AccountId, userAuth.UserId)
settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(userAuth.AccountId, settings, meta, onboarding)
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
@@ -228,15 +233,24 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
}
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
accountID := mux.Vars(r)["accountId"]
if accountID != userAuth.AccountId {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
_, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
accountID := vars["accountId"]
if len(accountID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
return
}
var req api.PutApiAccountsAccountIdJSONRequestBody
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -253,7 +267,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
return
}
if err := h.validateNetworkRange(r.Context(), accountID, userAuth.UserId, prefix); err != nil {
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -268,19 +282,19 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth
}
}
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userAuth.UserId, onboarding)
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userAuth.UserId, settings)
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userAuth.UserId)
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -292,14 +306,21 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth
}
// deleteAccount is a HTTP DELETE handler to delete an account
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
accountID := mux.Vars(r)["accountId"]
if accountID != userAuth.AccountId {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
err := h.accountManager.DeleteAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
return
}
err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings"
@@ -291,8 +290,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/accounts", permissions.WrapHandler(handler.getAllAccounts)).Methods("GET")
router.HandleFunc("/api/accounts/{accountId}", permissions.WrapHandler(handler.updateAccount)).Methods("PUT")
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -5,13 +5,11 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -21,15 +19,15 @@ type dnsSettingsHandler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
addDNSSettingEndpoint(accountManager, router, permissionsManager)
addDNSNameserversEndpoint(accountManager, router, permissionsManager)
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
addDNSSettingEndpoint(accountManager, router)
addDNSNameserversEndpoint(accountManager, router)
}
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) {
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Read, dnsSettingsHandler.getDNSSettings)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Update, dnsSettingsHandler.updateDNSSettings)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
}
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
@@ -38,8 +36,17 @@ func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler {
}
// getDNSSettings returns the DNS settings for the account
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -53,9 +60,17 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
}
// updateDNSSettings handles update to DNS settings of an account
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PutApiDnsSettingsJSONRequestBody
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -65,7 +80,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}
err = h.accountManager.SaveDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId, updateDNSSettings)
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -17,7 +17,6 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
@@ -116,8 +115,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.getDNSSettings)).Methods("GET")
router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.updateDNSSettings)).Methods("PUT")
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,13 +6,11 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/shared/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -23,13 +21,13 @@ type nameserversHandler struct {
accountManager account.Manager
}
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) {
nameserversHandler := newNameserversHandler(accountManager)
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getAllNameservers)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Create, nameserversHandler.createNameserverGroup)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Update, nameserversHandler.updateNameserverGroup)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getNameserverGroup)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Delete, nameserversHandler.deleteNameserverGroup)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS")
}
// newNameserversHandler returns a new instance of nameserversHandler handler
@@ -38,8 +36,17 @@ func newNameserversHandler(accountManager account.Manager) *nameserversHandler {
}
// getAllNameservers returns the list of nameserver groups for the account
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -54,9 +61,17 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
}
// createNameserverGroup handles nameserver group creation request
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PostApiDnsNameserversJSONRequestBody
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -68,7 +83,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), userAuth.AccountId, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userAuth.UserId, req.SearchDomainsEnabled)
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -80,7 +95,15 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
}
// updateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
@@ -88,7 +111,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
}
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -112,7 +135,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled,
}
err = h.accountManager.SaveNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, updatedNSGroup)
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -124,14 +147,22 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
}
// deleteNameserverGroup handles nameserver group deletion request
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
err := h.accountManager.DeleteNameServerGroup(r.Context(), userAuth.AccountId, nsGroupID, userAuth.UserId)
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -141,14 +172,22 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
}
// getNameserverGroup handles a nameserver group Get request identified by ID
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, nsGroupID)
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -18,7 +18,6 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
@@ -202,10 +201,10 @@ func TestNameserversHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.getNameserverGroup)).Methods("GET")
router.HandleFunc("/api/dns/nameservers", permissions.WrapHandler(p.createNameserverGroup)).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.deleteNameserverGroup)).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.updateNameserverGroup)).Methods("PUT")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -5,13 +5,11 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/shared/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -21,10 +19,10 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
eventsHandler := newHandler(accountManager)
router.HandleFunc("/events", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
router.HandleFunc("/events/audit", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
router.HandleFunc("/events/audit", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
}
// newHandler creates a new events handler
@@ -33,8 +31,17 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllEvents list of the given account
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
accountEvents, err := h.accountManager.GetEvents(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -13,7 +13,6 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
@@ -197,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/events/", permissions.WrapHandler(handler.getAllEvents)).Methods("GET")
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -7,13 +7,11 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,45 +19,46 @@ import (
// handler is a handler that returns groups of the account
type handler struct {
accountManager account.Manager
permissionsManager permissions.Manager
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
groupsHandler := newHandler(accountManager, permissionsManager)
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getGroup)).Methods("GET", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Delete, groupsHandler.deleteGroup)).Methods("DELETE", "OPTIONS")
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
groupsHandler := newHandler(accountManager)
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new groups handler
func newHandler(accountManager account.Manager, permissionsManager permissions.Manager) *handler {
func newHandler(accountManager account.Manager) *handler {
return &handler{
accountManager: accountManager,
permissionsManager: permissionsManager,
accountManager: accountManager,
}
}
func (h *handler) canReadPeers(r *http.Request, userAuth *auth.UserAuth) bool {
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read)
return err == nil && allowed
}
// getAllGroups list for the account
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
// Check if filtering by name
groupName := r.URL.Query().Get("name")
if groupName != "" {
// Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, userAuth.AccountId)
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -72,13 +71,13 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth
}
// Get all groups
groups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -93,7 +92,15 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth
}
// updateGroup handles update to a group identified by a given ID
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
groupID, ok := vars["groupId"]
if !ok {
@@ -105,13 +112,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *
return
}
existingGroup, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", userAuth.AccountId)
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -159,13 +166,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *
IntegrationReference: existingGroup.IntegrationReference,
}
if err := h.accountManager.UpdateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, userAuth.AccountId, err)
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -175,9 +182,17 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *
}
// createGroup handles group creation request
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PostApiGroupsJSONRequestBody
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -211,13 +226,13 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *
Issued: types.GroupIssuedAPI,
}
err = h.accountManager.CreateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group)
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -227,14 +242,22 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *
}
// deleteGroup handles group deletion request
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
err := h.accountManager.DeleteGroup(r.Context(), userAuth.AccountId, userAuth.UserId, groupID)
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
if err != nil {
wrappedErr, ok := err.(interface{ Unwrap() []error })
if ok && len(wrappedErr.Unwrap()) > 0 {
@@ -250,26 +273,34 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *
}
// getGroup returns a group
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
}
func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group {

View File

@@ -13,14 +13,10 @@ import (
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
@@ -37,18 +33,8 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
}
func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler {
t.Helper()
ctrl := gomock.NewController(t)
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.EXPECT().
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
Return(true, nil).
AnyTimes()
func initGroupTestData(initGroups ...*types.Group) *handler {
return &handler{
permissionsManager: permissionsManagerMock,
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error {
if !strings.HasPrefix(group.ID, "id-") {
@@ -85,14 +71,14 @@ func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler {
return groups, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
}
return nil, status.Errorf(status.NotFound, "unknown group name")
},
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil
},
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
@@ -142,7 +128,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
p := initGroupTestData(t, group)
p := initGroupTestData(group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -155,7 +141,7 @@ func TestGetGroup(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.getGroup)).Methods("GET")
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -268,7 +254,7 @@ func TestWriteGroup(t *testing.T) {
},
}
p := initGroupTestData(t)
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -281,8 +267,8 @@ func TestWriteGroup(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/groups", permissions.WrapHandler(p.createGroup)).Methods("POST")
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.updateGroup)).Methods("PUT")
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -346,7 +332,7 @@ func TestGetAllGroups(t *testing.T) {
},
}
p := initGroupTestData(t)
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -359,7 +345,7 @@ func TestGetAllGroups(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/groups", permissions.WrapHandler(p.getAllGroups)).Methods("GET")
router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -428,7 +414,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
p := initGroupTestData(t)
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -440,7 +426,7 @@ func TestDeleteGroup(t *testing.T) {
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.deleteGroup)).Methods("DELETE")
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,12 +6,9 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -23,13 +20,13 @@ type handler struct {
}
// AddEndpoints registers identity provider endpoints
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
h := newHandler(accountManager)
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getAllIdentityProviders)).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Create, h.createIdentityProvider)).Methods("POST", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getIdentityProvider)).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Update, h.updateIdentityProvider)).Methods("PUT", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Delete, h.deleteIdentityProvider)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS")
}
func newHandler(accountManager account.Manager) *handler {
@@ -39,8 +36,16 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllIdentityProviders returns all identity providers for the account
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
providers, err := h.accountManager.GetIdentityProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -55,7 +60,15 @@ func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request
}
// getIdentityProvider returns a specific identity provider
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
@@ -63,7 +76,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, us
return
}
provider, err := h.accountManager.GetIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId)
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -73,7 +86,15 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, us
}
// createIdentityProvider creates a new identity provider
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.IdentityProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -82,7 +103,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request,
idp := fromAPIRequest(&req)
created, err := h.accountManager.CreateIdentityProvider(r.Context(), userAuth.AccountId, userAuth.UserId, idp)
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -92,7 +113,15 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request,
}
// updateIdentityProvider updates an existing identity provider
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
@@ -108,7 +137,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request,
idp := fromAPIRequest(&req)
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId, idp)
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -118,7 +147,15 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request,
}
// deleteIdentityProvider deletes an identity provider
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
@@ -126,7 +163,7 @@ func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request,
return
}
if err := h.accountManager.DeleteIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId); err != nil {
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
@@ -121,7 +120,7 @@ func TestGetAllIdentityProviders(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.getAllIdentityProviders)).Methods("GET")
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -181,7 +180,7 @@ func TestGetIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.getIdentityProvider)).Methods("GET")
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -243,7 +242,7 @@ func TestCreateIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.createIdentityProvider)).Methods("POST")
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -329,7 +328,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.updateIdentityProvider)).Methods("PUT")
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -389,7 +388,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.deleteIdentityProvider)).Methods("DELETE")
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -7,11 +7,7 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -33,12 +29,12 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
}
// AddVersionEndpoint registers the authenticated version endpoint.
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
}
router.HandleFunc("/instance/version", permissionsManager.WithPermission(modules.Settings, operations.Read, h.getVersionInfo)).Methods("GET", "OPTIONS")
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
}
// getInstanceStatus returns the instance status including whether setup is required.
@@ -81,7 +77,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
// getVersionInfo returns version information for NetBird components.
// This endpoint requires authentication.
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)

View File

@@ -10,17 +10,12 @@ import (
"net/mail"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/idp"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -300,15 +295,8 @@ func TestSetup_ManagerError(t *testing.T) {
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
ctrl := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl)
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(w, r, &auth.UserAuth{})
}
}).AnyTimes()
router := mux.NewRouter()
AddVersionEndpoint(manager, router, permissionsManager)
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
@@ -335,15 +323,8 @@ func TestGetVersionInfo_Error(t *testing.T) {
return nil, errors.New("failed to fetch versions")
},
}
ctrl := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl)
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(w, r, &auth.UserAuth{})
}
}).AnyTimes()
router := mux.NewRouter()
AddVersionEndpoint(manager, router, permissionsManager)
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()

View File

@@ -9,10 +9,8 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
@@ -20,7 +18,6 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/networks/types"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -36,16 +33,16 @@ type handler struct {
groupsManager groups.Manager
}
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, permissionsManager permissions.Manager, router *mux.Router) {
addRouterEndpoints(routerManager, permissionsManager, router)
addResourceEndpoints(resourceManager, groupsManager, permissionsManager, router)
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) {
addRouterEndpoints(routerManager, router)
addResourceEndpoints(resourceManager, groupsManager, router)
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getAllNetworks)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Create, networksHandler.createNetwork)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getNetwork)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Update, networksHandler.updateNetwork)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, networksHandler.deleteNetwork)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
}
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler {
@@ -58,32 +55,40 @@ func newHandler(networksManager networks.Manager, resourceManager resources.Mana
}
}
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networks, err := h.networksManager.GetAllNetworks(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
accountID, userID := userAuth.AccountId, userAuth.UserId
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), userAuth.AccountId, userAuth.UserId)
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -92,9 +97,16 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAut
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account))
}
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkRequest
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -103,14 +115,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth
network := &types.Network{}
network.FromAPIRequest(&req)
network.AccountID = userAuth.AccountId
network, err = h.networksManager.CreateNetwork(r.Context(), userAuth.UserId, network)
network.AccountID = accountID
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
account, err := h.accountManager.GetAccount(r.Context(), accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -121,7 +133,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs))
}
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -129,19 +148,19 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *a
return
}
network, err := h.networksManager.GetNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
account, err := h.accountManager.GetAccount(r.Context(), accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -152,7 +171,14 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
}
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -161,7 +187,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth
}
var req api.NetworkRequest
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -171,20 +197,20 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth
network.FromAPIRequest(&req)
network.ID = networkID
network.AccountID = userAuth.AccountId
network, err = h.networksManager.UpdateNetwork(r.Context(), userAuth.UserId, network)
network.AccountID = accountID
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
account, err := h.accountManager.GetAccount(r.Context(), accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -195,7 +221,14 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
}
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -203,7 +236,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth
return
}
err := h.networksManager.DeleteNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -6,13 +6,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -22,14 +19,14 @@ type resourceHandler struct {
groupsManager groups.Manager
}
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, permissionsManager permissions.Manager, router *mux.Router) {
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) {
resourceHandler := newResourceHandler(resourcesManager, groupsManager)
router.HandleFunc("/networks/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInAccount)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInNetwork)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Create, resourceHandler.createResource)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getResource)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Update, resourceHandler.updateResource)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, resourceHandler.deleteResource)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
}
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
@@ -39,15 +36,22 @@ func newResourceHandler(resourceManager resources.Manager, groupsManager groups.
}
}
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -62,14 +66,22 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse)
}
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
accountID, userID := userAuth.AccountId, userAuth.UserId
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -85,9 +97,17 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse)
}
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkResourceRequest
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -97,14 +117,14 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request,
resource.FromAPIRequest(&req)
resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = userAuth.AccountId
resource, err = h.resourceManager.CreateResource(r.Context(), userAuth.UserId, resource)
resource.AccountID = accountID
resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -115,16 +135,23 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request,
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
}
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
resource, err := h.resourceManager.GetResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -135,9 +162,16 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, us
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
}
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkResourceRequest
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -148,14 +182,14 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request,
resource.ID = mux.Vars(r)["resourceId"]
resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = userAuth.AccountId
resource, err = h.resourceManager.UpdateResource(r.Context(), userAuth.UserId, resource)
resource.AccountID = accountID
resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -166,10 +200,17 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request,
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
}
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
err := h.resourceManager.DeleteResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -6,12 +6,9 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -20,14 +17,14 @@ type routersHandler struct {
routersManager routers.Manager
}
func addRouterEndpoints(routersManager routers.Manager, permissionsManager permissions.Manager, router *mux.Router) {
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
routersHandler := newRoutersHandler(routersManager)
router.HandleFunc("/networks/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getAllRouters)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getNetworkRouters)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Create, routersHandler.createRouter)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getRouter)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Update, routersHandler.updateRouter)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, routersHandler.deleteRouter)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
}
func newRoutersHandler(routersManager routers.Manager) *routersHandler {
@@ -36,8 +33,16 @@ func newRoutersHandler(routersManager routers.Manager) *routersHandler {
}
}
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -53,9 +58,17 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, u
util.WriteJSONObject(r.Context(), w, routersResponse)
}
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -69,10 +82,18 @@ func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Reques
util.WriteJSONObject(r.Context(), w, routersResponse)
}
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
var req api.NetworkRouterRequest
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -82,7 +103,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, us
router.FromAPIRequest(&req)
router.NetworkID = networkID
router.AccountID = userAuth.AccountId
router.AccountID = accountID
router.Enabled = true
if err := router.Validate(); err != nil {
@@ -90,7 +111,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, us
return
}
router, err = h.routersManager.CreateRouter(r.Context(), userAuth.UserId, router)
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -99,10 +120,18 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, us
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
router, err := h.routersManager.GetRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -111,9 +140,17 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userA
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkRouterRequest
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -124,14 +161,14 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, us
router.NetworkID = mux.Vars(r)["networkId"]
router.ID = mux.Vars(r)["routerId"]
router.AccountID = userAuth.AccountId
router.AccountID = accountID
if err := router.Validate(); err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
router, err = h.routersManager.UpdateRouter(r.Context(), userAuth.UserId, router)
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -140,10 +177,17 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, us
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
err := h.routersManager.DeleteRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -1,6 +1,7 @@
package peers
import (
"context"
"encoding/json"
"fmt"
"net/http"
@@ -11,15 +12,15 @@ import (
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -34,15 +35,14 @@ type Handler struct {
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
router.HandleFunc("/peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAllPeers)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetPeer)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Update, peersHandler.UpdatePeer)).Methods("PUT", "OPTIONS")
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Delete, peersHandler.DeletePeer)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAccessiblePeers)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/temporary-access", permissionsManager.WithPermission(modules.Peers, operations.Create, peersHandler.CreateTemporaryAccess)).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.ListJobs)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Create, peersHandler.CreateJob)).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.GetJob)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
}
// NewHandler creates a new peers Handler
@@ -54,7 +54,14 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
}
}
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
@@ -66,30 +73,37 @@ func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *au
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
if err := h.accountManager.CreatePeerJob(r.Context(), userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(r.Context(), err, w)
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(ctx, err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
util.WriteJSONObject(ctx, w, resp)
}
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
jobs, err := h.accountManager.GetAllPeerJobs(r.Context(), userAuth.AccountId, userAuth.UserId, peerID)
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
@@ -97,88 +111,79 @@ func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *aut
for _, job := range jobs {
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
respBody = append(respBody, resp)
}
util.WriteJSONObject(r.Context(), w, respBody)
util.WriteJSONObject(ctx, w, respBody)
}
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
jobID := vars["jobId"]
job, err := h.accountManager.GetPeerJobByID(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, jobID)
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
util.WriteJSONObject(ctx, w, resp)
}
// GetPeer handles GET request for a single peer
func (h *Handler) GetPeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
peer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
if peer.ProxyMeta.Embedded {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
return
}
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peerID)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
_, valid := validPeers[peer.ID]
reason := invalidPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
// UpdatePeer handles PUT request to update a peer
func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -187,10 +192,11 @@ func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *a
}
update := &nbpeer.Peer{
ID: peerID,
SSHEnabled: req.SshEnabled,
Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled,
ID: peerID,
SSHEnabled: req.SshEnabled,
Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled,
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
}
@@ -204,41 +210,41 @@ func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *a
if req.Ip != nil {
addr, err := netip.ParseAddr(*req.Ip)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
return
}
if err = h.accountManager.UpdatePeerIP(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, addr); err != nil {
util.WriteError(r.Context(), err, w)
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
util.WriteError(ctx, err, w)
return
}
}
peer, err := h.accountManager.UpdatePeer(r.Context(), userAuth.AccountId, userAuth.UserId, update)
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.networkMapController.GetDNSDomain(settings)
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
@@ -248,8 +254,25 @@ func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
// DeletePeer handles DELETE request to delete a peer
func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, util.EmptyObject{})
}
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -257,34 +280,48 @@ func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *a
return
}
err := h.accountManager.DeletePeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to delete peer: %v", err)
util.WriteError(r.Context(), err, w)
switch r.Method {
case http.MethodDelete:
h.deletePeer(r.Context(), accountID, userID, peerID, w)
return
case http.MethodGet:
h.getPeer(r.Context(), accountID, peerID, userID, w)
return
case http.MethodPut:
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
// GetAllPeers returns a list of all peers associated with a provided account
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nameFilter := r.URL.Query().Get("name")
ipFilter := r.URL.Query().Get("ip")
peers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, nameFilter, ipFilter, true)
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
nameFilter := r.URL.Query().Get("name")
ipFilter := r.URL.Query().Get("ip")
accountID, userID := userAuth.AccountId, userAuth.UserId
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
respBody := make([]*api.PeerBatch, 0, len(peers))
@@ -295,7 +332,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
}
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -319,7 +356,15 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM
}
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -327,22 +372,25 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, use
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
user, err := h.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), userAuth.AccountId, activity.SystemInitiator)
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
if err != nil {
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
// Check if user is an admin/service user through their role
isAdmin := user.Role == types.UserRoleAdmin || user.Role == types.UserRoleOwner
if !isAdmin && !user.IsServiceUser && !userAuth.IsChild {
if !allowed && !userAuth.IsChild {
if account.Settings.RegularUsersViewBlocked {
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
return
@@ -360,7 +408,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, use
}
}
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -369,12 +417,18 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, use
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -383,7 +437,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request,
}
var req api.PeerTemporaryAccessRequest
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return

View File

@@ -19,11 +19,11 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -174,7 +174,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
return nil, fmt.Errorf("user not found")
}
},
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
return peers, nil
},
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
@@ -307,9 +307,9 @@ func TestGetPeers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/peers/", permissions.WrapHandler(p.GetAllPeers)).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.GetPeer)).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -498,7 +498,7 @@ func TestGetAccessiblePeers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", permissions.WrapHandler(p.GetAccessiblePeers)).Methods("GET")
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -582,7 +582,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
rr := httptest.NewRecorder()
router := mux.NewRouter()
router.HandleFunc("/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.ServeHTTP(rr, req)

View File

@@ -14,12 +14,12 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -121,7 +121,7 @@ func TestGetCitiesByCountry(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries/{country}/cities", permissions.WrapHandler(geolocationHandler.getCitiesByCountry)).Methods("GET")
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -214,7 +214,7 @@ func TestGetAllCountries(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries", permissions.WrapHandler(geolocationHandler.getAllCountries)).Methods("GET")
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,12 +6,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -30,8 +30,8 @@ type geolocationsHandler struct {
func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) {
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager)
router.HandleFunc("/locations/countries", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getAllCountries)).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getCitiesByCountry)).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
}
// newGeolocationsHandlerHandler creates a new Geolocations handler
@@ -44,7 +44,12 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
}
// getAllCountries retrieves a list of all countries
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
@@ -65,7 +70,12 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
}
// getCitiesByCountry retrieves a list of cities based on the given country code
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) {
@@ -92,6 +102,27 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
util.WriteJSONObject(r.Context(), w, cities)
}
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
return err
}
accountID, userID := userAuth.AccountId, userAuth.UserId
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
return nil
}
func toCountryResponse(country geolocation.Country) api.Country {
return api.Country{
CountryName: country.CountryName,

View File

@@ -7,13 +7,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -24,13 +21,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
policiesHandler := newHandler(accountManager)
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getAllPolicies)).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Create, policiesHandler.createPolicy)).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Update, policiesHandler.updatePolicy)).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getPolicy)).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, policiesHandler.deletePolicy)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new policies handler
@@ -41,14 +38,22 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllPolicies list for the account
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
listPolicies, err := h.accountManager.ListPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
accountID, userID := userAuth.AccountId, userAuth.UserId
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -68,7 +73,15 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAut
}
// updatePolicy handles update to a policy identified by a given ID
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -76,18 +89,26 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth
return
}
_, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, policyID, false)
h.savePolicy(w, r, accountID, userID, policyID, false)
}
// createPolicy handles policy creation request
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, "", true)
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
h.savePolicy(w, r, accountID, userID, "", true)
}
// savePolicy handles policy creation and update
@@ -282,7 +303,14 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
}
// deletePolicy handles policy deletion request
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -290,7 +318,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth
return
}
if err := h.accountManager.DeletePolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil {
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -299,7 +327,15 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth
}
// getPolicy handles a group Get request identified by ID
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -307,13 +343,13 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *au
return
}
policy, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -13,7 +13,6 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
@@ -112,7 +111,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.getPolicy)).Methods("GET")
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -276,8 +275,8 @@ func TestPoliciesWritePolicy(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/policies", permissions.WrapHandler(p.createPolicy)).Methods("POST")
router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.updatePolicy)).Methods("PUT")
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,13 +6,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -24,13 +21,13 @@ type postureChecksHandler struct {
geolocationManager geolocation.Geolocation
}
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getAllPostureChecks)).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Create, postureCheckHandler.createPostureCheck)).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Update, postureCheckHandler.updatePostureCheck)).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getPostureCheck)).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, postureCheckHandler.deletePostureCheck)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
}
// newPostureChecksHandler creates a new PostureChecks handler
@@ -42,8 +39,15 @@ func newPostureChecksHandler(accountManager account.Manager, geolocationManager
}
// getAllPostureChecks list for the account
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), userAuth.AccountId, userAuth.UserId)
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -58,30 +62,15 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
}
// updatePostureCheck handles update to a posture check identified by a given ID
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
_, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, postureChecksID, false)
}
accountID, userID := userAuth.AccountId, userAuth.UserId
// createPostureCheck handles posture check creation request
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, "", true)
}
// getPostureCheck handles a posture check Get request identified by ID
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -89,7 +78,45 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
return
}
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
p.savePostureChecks(w, r, accountID, userID, postureChecksID, false)
}
// createPostureCheck handles posture check creation request
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
p.savePostureChecks(w, r, accountID, userID, "", true)
}
// getPostureCheck handles a posture check Get request identified by ID
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -99,7 +126,14 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
}
// deletePostureCheck handles posture check deletion request
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -107,7 +141,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http
return
}
if err := p.accountManager.DeletePostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId); err != nil {
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/mock_server"
@@ -184,7 +183,7 @@ func TestGetPostureCheck(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(p.getPostureCheck)).Methods("GET")
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -842,8 +841,8 @@ func TestPostureCheckUpdate(t *testing.T) {
}
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks", permissions.WrapHandler(defaultHandler.createPostureCheck)).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(defaultHandler.updatePostureCheck)).Methods("PUT")
router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -8,12 +8,9 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -29,13 +26,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
routesHandler := newHandler(accountManager)
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getAllRoutes)).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Create, routesHandler.createRoute)).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Update, routesHandler.updateRoute)).Methods("PUT", "OPTIONS")
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getRoute)).Methods("GET", "OPTIONS")
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Delete, routesHandler.deleteRoute)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS")
}
// newHandler returns a new instance of routes handler
@@ -46,8 +43,16 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllRoutes returns the list of routes for the account
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routes, err := h.accountManager.ListRoutes(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -66,9 +71,17 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth
}
// createRoute handles route creation request
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PostApiRoutesJSONRequestBody
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -121,8 +134,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *
skipAutoApply = false
}
newRoute, err := h.accountManager.CreateRoute(r.Context(), userAuth.AccountId, newPrefix, networkType, domains, peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userAuth.UserId, req.KeepRoute, skipAutoApply)
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -172,7 +185,14 @@ func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *
}
// updateRoute handles update to a route identified by a given ID
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
routeID := vars["routeId"]
if len(routeID) == 0 {
@@ -180,7 +200,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *
return
}
_, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -251,7 +271,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *
newRoute.AccessControlGroups = *req.AccessControlGroups
}
err = h.accountManager.SaveRoute(r.Context(), userAuth.AccountId, userAuth.UserId, newRoute)
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -267,14 +287,21 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *
}
// deleteRoute handles route deletion request
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
err := h.accountManager.DeleteRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -284,14 +311,22 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *
}
// getRoute handles a route Get request identified by ID
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
foundRoute, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/util"
@@ -502,10 +501,10 @@ func TestRoutesHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.getRoute)).Methods("GET")
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.deleteRoute)).Methods("DELETE")
router.HandleFunc("/api/routes", permissions.WrapHandler(p.createRoute)).Methods("POST")
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.updateRoute)).Methods("PUT")
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE")
router.HandleFunc("/api/routes", p.createRoute).Methods("POST")
router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -8,12 +8,9 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -24,13 +21,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
keysHandler := newHandler(accountManager)
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getAllSetupKeys)).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Create, keysHandler.createSetupKey)).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getSetupKey)).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Update, keysHandler.updateSetupKey)).Methods("PUT", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Delete, keysHandler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new setup key handler
@@ -41,9 +38,16 @@ func newHandler(accountManager account.Manager) *handler {
}
// createSetupKey is a POST requests that creates a new SetupKey
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiSetupKeysJSONRequestBody{}
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -81,8 +85,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAut
allowExtraDNSLabels = *req.AllowExtraDnsLabels
}
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), userAuth.AccountId, req.Name, types.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userAuth.UserId, ephemeral, allowExtraDNSLabels)
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userID, ephemeral, allowExtraDNSLabels)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -96,7 +100,14 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAut
}
// getSetupKey is a GET request to get a SetupKey by ID
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -104,7 +115,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *
return
}
key, err := h.accountManager.GetSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -114,7 +125,14 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *
}
// updateSetupKey is a PUT request to update server.SetupKey
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -123,7 +141,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAut
}
req := &api.PutApiSetupKeysKeyIdJSONRequestBody{}
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -139,7 +157,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAut
newKey.Revoked = req.Revoked
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), userAuth.AccountId, newKey, userAuth.UserId)
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -148,8 +166,15 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAut
}
// getAllSetupKeys is a GET request that returns a list of SetupKey
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), userAuth.AccountId, userAuth.UserId)
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -163,7 +188,14 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAu
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
}
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -171,7 +203,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAut
return
}
err := h.accountManager.DeleteSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
@@ -172,11 +171,11 @@ func TestSetupKeysHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.getAllSetupKeys)).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.createSetupKey)).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.getSetupKey)).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.updateSetupKey)).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -9,13 +9,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -58,14 +55,14 @@ type invitesHandler struct {
}
// AddInvitesEndpoints registers invite-related endpoints
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
h := &invitesHandler{accountManager: accountManager}
// Authenticated endpoints (require admin)
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
}
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
@@ -82,7 +79,14 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route
}
// listInvites handles GET /api/users/invites
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -98,7 +102,14 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, use
}
// createInvite handles POST /api/users/invites
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.UserInviteCreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -180,12 +191,18 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
}
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
@@ -221,7 +238,14 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request
}
// deleteInvite handles DELETE /api/users/invites/{inviteId}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
@@ -229,7 +253,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, us
return
}
err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -110,11 +110,7 @@ func TestListInvites(t *testing.T) {
})
rr := httptest.NewRecorder()
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.listInvites(rr, req, userAuth)
handler.listInvites(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -239,11 +235,7 @@ func TestCreateInvite(t *testing.T) {
})
rr := httptest.NewRecorder()
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.createInvite(rr, req, userAuth)
handler.createInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -581,11 +573,7 @@ func TestRegenerateInvite(t *testing.T) {
}
rr := httptest.NewRecorder()
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.regenerateInvite(rr, req, userAuth)
handler.regenerateInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -663,11 +651,7 @@ func TestDeleteInvite(t *testing.T) {
}
rr := httptest.NewRecorder()
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.deleteInvite(rr, req, userAuth)
handler.deleteInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
})

View File

@@ -6,12 +6,9 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -22,12 +19,12 @@ type patHandler struct {
accountManager account.Manager
}
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) {
tokenHandler := newPATsHandler(accountManager)
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
}
// newPATsHandler creates a new patHandler HTTP handler
@@ -38,15 +35,22 @@ func newPATsHandler(accountManager account.Manager) *patHandler {
}
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
if len(userID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -61,7 +65,14 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAu
}
// getToken is HTTP GET handler that returns a personal access token for the given user
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -75,7 +86,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *
return
}
pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -85,7 +96,14 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *
}
// createToken is HTTP POST handler that creates a personal access token for the given user
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -94,13 +112,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAut
}
var req api.PostApiUsersUserIdTokensJSONRequestBody
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn)
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -110,7 +128,14 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAut
}
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -124,7 +149,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAut
return
}
err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
@@ -182,10 +181,10 @@ func TestTokenHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.getAllTokens)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.getToken)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.createToken)).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.deleteToken)).Methods("DELETE")
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -8,16 +8,14 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
nbcontext "github.com/netbirdio/netbird/management/server/context"
)
// handler is a handler that returns users of the account
@@ -25,18 +23,18 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
userHandler := newHandler(accountManager)
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers, userHandler.getOwnUser)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser, userHandler.getCurrentUserFallback)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router, permissionsManager)
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router)
}
// newHandler creates a new UsersHandler HTTP handler
@@ -47,12 +45,19 @@ func newHandler(accountManager account.Manager) *handler {
}
// updateUser is a PUT requests to update User data
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -66,11 +71,6 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *a
return
}
if existingUser.AccountID != userAuth.AccountId {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "user not found"), w)
return
}
req := &api.PutApiUsersUserIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -89,7 +89,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *a
return
}
newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
@@ -102,16 +102,23 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
}
// deleteUser is a DELETE request to delete a user
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -119,7 +126,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *a
return
}
err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -129,14 +136,21 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *a
}
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiUsersJSONRequestBody{}
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -157,7 +171,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *a
name = *req.Name
}
newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
Email: email,
Name: name,
Role: req.Role,
@@ -169,18 +183,25 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
}
// getAllUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager.
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -194,7 +215,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *
continue
}
if serviceUser == "" {
users = append(users, toUserResponse(d, userAuth.UserId))
users = append(users, toUserResponse(d, userID))
continue
}
@@ -205,7 +226,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *
return
}
if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(d, userAuth.UserId))
users = append(users, toUserResponse(d, userID))
}
}
@@ -214,12 +235,19 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *
// inviteUser resend invitations to users who haven't activated their accounts,
// prior to the expiration period.
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -227,7 +255,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *a
return
}
err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -236,13 +264,19 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -322,7 +356,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
}
// approveUser is a POST request to approve a user that is pending approval
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -335,6 +369,11 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -346,7 +385,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *
}
// rejectUser is a DELETE request to reject a user that is pending approval
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -359,7 +398,12 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *a
return
}
err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -377,7 +421,7 @@ type passwordChangeRequest struct {
// changePassword is a PUT request to change user's password.
// Only available when embedded IDP is enabled.
// Users can only change their own password.
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -390,13 +434,19 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAut
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req passwordChangeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -404,39 +454,3 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAut
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getCurrentUserFallback(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.PermissionDenied {
return false
}
user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
if userErr != nil {
util.WriteError(r.Context(), userErr, w)
return true
}
util.WriteJSONObject(r.Context(), w, toUserWithPermissionsResponse(user, userAuth.UserId))
return true
}
func (h *handler) getOwnUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.PermissionDenied {
return false
}
if r.URL.Query().Get("service_user") != "" {
return false
}
user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
if userErr != nil {
util.WriteError(r.Context(), userErr, w)
return true
}
util.WriteJSONObject(r.Context(), w, []*api.User{toUserResponse(user.UserInfo, userAuth.UserId)})
return true
}

View File

@@ -15,11 +15,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
roles2 "github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
@@ -39,7 +38,6 @@ var usersTestAccount = &types.Account{
Users: map[string]*types.User{
existingUserID: {
Id: existingUserID,
AccountID: existingAccountID,
Role: "admin",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
@@ -47,7 +45,6 @@ var usersTestAccount = &types.Account{
},
regularUserID: {
Id: regularUserID,
AccountID: existingAccountID,
Role: "user",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
@@ -55,7 +52,6 @@ var usersTestAccount = &types.Account{
},
serviceUserID: {
Id: serviceUserID,
AccountID: existingAccountID,
Role: "user",
IsServiceUser: true,
AutoGroups: []string{"group_1"},
@@ -63,7 +59,6 @@ var usersTestAccount = &types.Account{
},
nonDeletableServiceUserID: {
Id: nonDeletableServiceUserID,
AccountID: existingAccountID,
Role: "admin",
IsServiceUser: true,
NonDeletable: true,
@@ -156,7 +151,7 @@ func initUsersTestData() *handler {
NonDeletable: false,
Issued: "api",
},
Permissions: mergeRolePermissions(roles2.Owner),
Permissions: mergeRolePermissions(roles.Owner),
}, nil
case "regular-user":
return &users.UserInfoWithPermissions{
@@ -170,7 +165,7 @@ func initUsersTestData() *handler {
NonDeletable: false,
Issued: "api",
},
Permissions: mergeRolePermissions(roles2.User),
Permissions: mergeRolePermissions(roles.User),
}, nil
case "admin-user":
@@ -186,7 +181,7 @@ func initUsersTestData() *handler {
LastLogin: time.Time{},
Issued: "api",
},
Permissions: mergeRolePermissions(roles2.Admin),
Permissions: mergeRolePermissions(roles.Admin),
}, nil
case "restricted-user":
return &users.UserInfoWithPermissions{
@@ -201,7 +196,7 @@ func initUsersTestData() *handler {
LastLogin: time.Time{},
Issued: "api",
},
Permissions: mergeRolePermissions(roles2.User),
Permissions: mergeRolePermissions(roles.User),
Restricted: true,
}, nil
}
@@ -237,11 +232,7 @@ func TestGetUsers(t *testing.T) {
AccountId: existingAccountID,
})
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.getAllUsers(recorder, req, userAuth)
userHandler.getAllUsers(recorder, req)
res := recorder.Result()
defer res.Body.Close()
@@ -352,7 +343,7 @@ func TestUpdateUser(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", permissions.WrapHandler(userHandler.updateUser)).Methods("PUT")
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -448,11 +439,7 @@ func TestCreateUser(t *testing.T) {
AccountId: existingAccountID,
})
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.createUser(rr, req, userAuth)
userHandler.createUser(rr, req)
res := rr.Result()
defer res.Body.Close()
@@ -503,11 +490,7 @@ func TestInviteUser(t *testing.T) {
rr := httptest.NewRecorder()
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.inviteUser(rr, req, userAuth)
userHandler.inviteUser(rr, req)
res := rr.Result()
defer res.Body.Close()
@@ -566,11 +549,7 @@ func TestDeleteUser(t *testing.T) {
rr := httptest.NewRecorder()
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.deleteUser(rr, req, userAuth)
userHandler.deleteUser(rr, req)
res := rr.Result()
defer res.Body.Close()
@@ -629,7 +608,7 @@ func TestCurrentUser(t *testing.T) {
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Owner)),
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)),
},
},
},
@@ -648,7 +627,7 @@ func TestCurrentUser(t *testing.T) {
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
},
},
},
@@ -667,7 +646,7 @@ func TestCurrentUser(t *testing.T) {
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Admin)),
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)),
},
},
},
@@ -687,7 +666,7 @@ func TestCurrentUser(t *testing.T) {
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
IsRestricted: true,
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
},
},
},
@@ -703,11 +682,7 @@ func TestCurrentUser(t *testing.T) {
rr := httptest.NewRecorder()
userAuth := &auth.UserAuth{
UserId: tc.requestAuth.UserId,
AccountId: existingAccountID,
}
userHandler.getCurrentUser(rr, req, userAuth)
userHandler.getCurrentUser(rr, req)
res := rr.Result()
defer res.Body.Close()
@@ -727,8 +702,8 @@ func ptr[T any, PT *T](x T) PT {
return &x
}
func mergeRolePermissions(role roles2.RolePermissions) roles2.Permissions {
permissions := roles2.Permissions{}
func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
permissions := roles.Permissions{}
for k := range modules.All {
if rolePermissions, ok := role.Permissions[k]; ok {
@@ -741,7 +716,7 @@ func mergeRolePermissions(role roles2.RolePermissions) roles2.Permissions {
return permissions
}
func stringifyPermissionsKeys(permissions roles2.Permissions) map[string]map[string]bool {
func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool {
modules := make(map[string]map[string]bool)
for module, operations := range permissions {
modules[string(module)] = make(map[string]bool)
@@ -804,7 +779,7 @@ func TestApproveUserEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/approve", permissions.WrapHandler(handler.approveUser)).Methods("POST")
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err)
@@ -862,7 +837,7 @@ func TestRejectUserEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/reject", permissions.WrapHandler(handler.rejectUser)).Methods("DELETE")
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
require.NoError(t, err)
@@ -953,7 +928,7 @@ func TestChangePasswordEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/password", permissions.WrapHandler(handler.changePassword)).Methods("PUT")
router.HandleFunc("/users/{userId}/password", handler.changePassword).Methods("PUT")
reqPath := "/users/" + tc.targetUserID + "/password"
req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody))
@@ -992,7 +967,7 @@ func TestChangePasswordEndpoint_WrongMethod(t *testing.T) {
req = nbcontext.SetUserAuthInRequest(req, userAuth)
rr := httptest.NewRecorder()
handler.changePassword(rr, req, &userAuth)
handler.changePassword(rr, req)
assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
}

View File

@@ -27,7 +27,7 @@ func Test_Accounts_GetAll(t *testing.T) {
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
@@ -233,71 +233,6 @@ func Test_Accounts_Update(t *testing.T) {
}
}
func Test_Accounts_Update_CrossAccountAttack(t *testing.T) {
t.Run("Other user attempts to update testAccount via URL", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(&api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: false,
PeerLoginExpiration: 86400,
},
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account update must be rejected")
})
}
func Test_Accounts_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, false},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, false},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Delete account", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
})
}
}
func Test_Accounts_Delete_CrossAccountAttack(t *testing.T) {
t.Run("Other user attempts to delete testAccount via URL", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account delete must be rejected")
})
}
func stringPointer(s string) *string {
return &s
}

View File

@@ -1,445 +0,0 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Records_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all records", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones/testZoneId/records", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.DNSRecord{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "sub.example.com", got[0].Name)
assert.Equal(t, api.DNSRecordTypeA, got[0].Type)
assert.Equal(t, "1.2.3.4", got[0].Content)
assert.Equal(t, 300, got[0].Ttl)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Records_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
recordId string
expectedStatus int
expectRecord bool
}{
{
name: "Get existing record",
zoneId: "testZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusOK,
expectRecord: true,
},
{
name: "Get non-existing record",
zoneId: "testZoneId",
recordId: "nonExistingRecordId",
expectedStatus: http.StatusNotFound,
expectRecord: false,
},
{
name: "Get record from non-existing zone",
zoneId: "nonExistingZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusNotFound,
expectRecord: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectRecord {
got := &api.DNSRecord{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testRecordId", got.Id)
assert.Equal(t, "sub.example.com", got.Name)
assert.Equal(t, api.DNSRecordTypeA, got.Type)
assert.Equal(t, "1.2.3.4", got.Content)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Records_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
requestBody *api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, record *api.DNSRecord)
}{
{
name: "Create A record",
zoneId: "testZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "new.example.com",
Type: api.DNSRecordTypeA,
Content: "5.6.7.8",
Ttl: 600,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
t.Helper()
assert.NotEmpty(t, record.Id)
assert.Equal(t, "new.example.com", record.Name)
assert.Equal(t, api.DNSRecordTypeA, record.Type)
assert.Equal(t, "5.6.7.8", record.Content)
assert.Equal(t, 600, record.Ttl)
},
},
{
name: "Create CNAME record",
zoneId: "testZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "alias.example.com",
Type: api.DNSRecordTypeCNAME,
Content: "target.example.com",
Ttl: 300,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
t.Helper()
assert.NotEmpty(t, record.Id)
assert.Equal(t, "alias.example.com", record.Name)
assert.Equal(t, api.DNSRecordTypeCNAME, record.Type)
assert.Equal(t, "target.example.com", record.Content)
},
},
{
name: "Create record with invalid content for A type",
zoneId: "testZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "bad.example.com",
Type: api.DNSRecordTypeA,
Content: "not-an-ip",
Ttl: 300,
},
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Create record in non-existing zone",
zoneId: "nonExistingZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "new.example.com",
Type: api.DNSRecordTypeA,
Content: "5.6.7.8",
Ttl: 600,
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
path := strings.Replace("/api/dns/zones/{zoneId}/records", "{zoneId}", tc.zoneId, 1)
req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.DNSRecord{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the created record directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbRecord := testing_tools.VerifyRecordInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbRecord.Name)
assert.Equal(t, got.Content, dbRecord.Content)
assert.Equal(t, got.Ttl, dbRecord.TTL)
}
})
}
}
}
func Test_Records_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
recordId string
requestBody *api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, record *api.DNSRecord)
}{
{
name: "Update record content and TTL",
zoneId: "testZoneId",
recordId: "testRecordId",
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "sub.example.com",
Type: api.DNSRecordTypeA,
Content: "10.20.30.40",
Ttl: 600,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
t.Helper()
assert.Equal(t, "sub.example.com", record.Name)
assert.Equal(t, "10.20.30.40", record.Content)
assert.Equal(t, 600, record.Ttl)
},
},
{
name: "Update non-existing record",
zoneId: "testZoneId",
recordId: "nonExistingRecordId",
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "sub.example.com",
Type: api.DNSRecordTypeA,
Content: "10.20.30.40",
Ttl: 600,
},
expectedStatus: http.StatusNotFound,
},
{
name: "Update record in non-existing zone",
zoneId: "nonExistingZoneId",
recordId: "testRecordId",
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "sub.example.com",
Type: api.DNSRecordTypeA,
Content: "10.20.30.40",
Ttl: 600,
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.DNSRecord{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the updated record directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbRecord := testing_tools.VerifyRecordInDB(t, db, tc.recordId)
assert.Equal(t, "10.20.30.40", dbRecord.Content)
assert.Equal(t, 600, dbRecord.TTL)
}
})
}
}
}
func Test_Records_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
recordId string
expectedStatus int
}{
{
name: "Delete existing record",
zoneId: "testZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing record",
zoneId: "testZoneId",
recordId: "nonExistingRecordId",
expectedStatus: http.StatusNotFound,
},
{
name: "Delete record from non-existing zone",
zoneId: "nonExistingZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify deletion in DB for successful deletes by privileged users
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyRecordNotInDB(t, db, tc.recordId)
}
})
}
}
}

View File

@@ -1,416 +0,0 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Zones_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all zones", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Zone{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "Test Zone", got[0].Name)
assert.Equal(t, "example.com", got[0].Domain)
assert.Equal(t, true, got[0].Enabled)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Zones_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
expectedStatus int
expectZone bool
}{
{
name: "Get existing zone",
zoneId: "testZoneId",
expectedStatus: http.StatusOK,
expectZone: true,
},
{
name: "Get non-existing zone",
zoneId: "nonExistingZoneId",
expectedStatus: http.StatusNotFound,
expectZone: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectZone {
got := &api.Zone{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testZoneId", got.Id)
assert.Equal(t, "Test Zone", got.Name)
assert.Equal(t, "example.com", got.Domain)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Zones_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
enabled := true
disabled := false
tt := []struct {
name string
requestBody *api.PostApiDnsZonesJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, zone *api.Zone)
}{
{
name: "Create zone with valid data",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "New Zone",
Domain: "newzone.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.NotEmpty(t, zone.Id)
assert.Equal(t, "New Zone", zone.Name)
assert.Equal(t, "newzone.com", zone.Domain)
assert.Equal(t, true, zone.Enabled)
assert.Equal(t, false, zone.EnableSearchDomain)
assert.Equal(t, 1, len(zone.DistributionGroups))
},
},
{
name: "Create zone with search domain enabled",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "Search Zone",
Domain: "search.example.com",
Enabled: &enabled,
EnableSearchDomain: true,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.NotEmpty(t, zone.Id)
assert.Equal(t, "Search Zone", zone.Name)
assert.Equal(t, true, zone.EnableSearchDomain)
},
},
{
name: "Create disabled zone",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "Disabled Zone",
Domain: "disabled.example.com",
Enabled: &disabled,
EnableSearchDomain: false,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.NotEmpty(t, zone.Id)
assert.Equal(t, false, zone.Enabled)
},
},
{
name: "Create zone with empty distribution groups",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "No Groups Zone",
Domain: "nogroups.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/zones", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Zone{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the created zone directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbZone := testing_tools.VerifyZoneInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbZone.Name)
assert.Equal(t, got.Domain, dbZone.Domain)
assert.Equal(t, got.Enabled, dbZone.Enabled)
assert.Equal(t, got.EnableSearchDomain, dbZone.EnableSearchDomain)
}
})
}
}
}
func Test_Zones_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
enabled := true
tt := []struct {
name string
zoneId string
requestBody *api.PutApiDnsZonesZoneIdJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, zone *api.Zone)
}{
{
name: "Update zone name and settings",
zoneId: "testZoneId",
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "Updated Zone",
Domain: "example.com",
Enabled: &enabled,
EnableSearchDomain: true,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.Equal(t, "Updated Zone", zone.Name)
assert.Equal(t, "example.com", zone.Domain)
assert.Equal(t, true, zone.EnableSearchDomain)
},
},
{
name: "Update non-existing zone",
zoneId: "nonExistingZoneId",
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "Whatever",
Domain: "whatever.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Zone{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the updated zone directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbZone := testing_tools.VerifyZoneInDB(t, db, tc.zoneId)
assert.Equal(t, "Updated Zone", dbZone.Name)
assert.Equal(t, "example.com", dbZone.Domain)
assert.Equal(t, true, dbZone.Enabled)
assert.Equal(t, true, dbZone.EnableSearchDomain)
}
})
}
}
}
func Test_Zones_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
expectedStatus int
}{
{
name: "Delete existing zone",
zoneId: "testZoneId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing zone",
zoneId: "nonExistingZoneId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify deletion in DB for successful deletes by privileged users
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyZoneNotInDB(t, db, tc.zoneId)
}
})
}
}
}

View File

@@ -78,68 +78,6 @@ func Test_Events_GetAll(t *testing.T) {
}
}
func Test_Events_GetAll_Audit(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all audit events", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, false)
// First, perform a mutation to generate an event (create a group as admin)
groupBody, err := json.Marshal(&api.GroupRequest{Name: "auditTestGroup"})
if err != nil {
t.Fatalf("Failed to marshal group request: %v", err)
}
createReq := testing_tools.BuildRequest(t, groupBody, http.MethodPost, "/api/groups", testing_tools.TestAdminId)
createRecorder := httptest.NewRecorder()
apiHandler.ServeHTTP(createRecorder, createReq)
assert.Equal(t, http.StatusOK, createRecorder.Code, "Failed to create group to generate event")
// Now query audit events
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/audit", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Event{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.GreaterOrEqual(t, len(got), 1, "Expected at least one event after creating a group")
// Verify the group creation event exists
found := false
for _, event := range got {
if event.ActivityCode == "group.add" {
found = true
assert.Equal(t, testing_tools.TestAdminId, event.InitiatorId)
assert.Equal(t, "Group created", event.Activity)
break
}
}
assert.True(t, found, "Expected to find a group.add event")
})
}
}
func Test_Events_GetAll_Empty(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true)

View File

@@ -1,118 +0,0 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Geolocations_GetAllCountries(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all countries", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/locations/countries", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Country{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Geolocations_GetCitiesByCountry(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get cities by country", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "US", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.City{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Geolocations_GetCitiesByCountry_InvalidCode(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "INVALID", 1), testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, true)
}

View File

@@ -26,7 +26,7 @@ func Test_Groups_GetAll(t *testing.T) {
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
@@ -71,7 +71,7 @@ func Test_Groups_GetById(t *testing.T) {
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
@@ -216,6 +216,7 @@ func Test_Groups_Create(t *testing.T) {
}
tc.verifyResponse(t, got)
// Verify group exists in DB
db := testing_tools.GetDB(t, am.GetStore())
dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id)
assert.Equal(t, tc.requestBody.Name, dbGroup.Name)

View File

@@ -1,295 +0,0 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_IdentityProviders_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all identity providers", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/identity-providers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.IdentityProvider{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// The embedded IdP manager is not initialized in the test environment,
// so GetIdentityProviders returns an empty list.
assert.Equal(t, 0, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_IdentityProviders_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
idpId string
expectedStatus int
}{
{
name: "Get existing identity provider",
idpId: "testIdpId",
expectedStatus: http.StatusInternalServerError,
},
{
name: "Get non-existing identity provider",
idpId: "nonExistingIdpId",
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_IdentityProviders_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
requestBody *api.PostApiIdentityProvidersJSONRequestBody
expectedStatus int
}{
{
name: "Create identity provider with valid data",
requestBody: &api.PostApiIdentityProvidersJSONRequestBody{
Type: api.IdentityProviderTypeGoogle,
Name: "New IDP",
ClientId: "newClientId",
ClientSecret: "newClientSecret",
},
// Validation passes but the embedded IdP manager is not initialized,
// so the operation returns an internal server error.
expectedStatus: http.StatusInternalServerError,
},
{
name: "Create identity provider with invalid issuer",
requestBody: &api.PostApiIdentityProvidersJSONRequestBody{
Type: api.IdentityProviderTypeOidc,
Name: "Invalid IDP",
Issuer: "not-a-url",
ClientId: "clientId",
ClientSecret: "clientSecret",
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/identity-providers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_IdentityProviders_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
idpId string
requestBody *api.PutApiIdentityProvidersIdpIdJSONRequestBody
expectedStatus int
}{
{
name: "Update existing identity provider",
idpId: "testIdpId",
requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{
Type: api.IdentityProviderTypeGoogle,
Name: "Updated IDP",
ClientId: "updatedClientId",
ClientSecret: "updatedClientSecret",
},
// Validation passes but the embedded IdP manager is not initialized,
// so the operation returns an internal server error.
expectedStatus: http.StatusInternalServerError,
},
{
name: "Update non-existing identity provider",
idpId: "nonExistingIdpId",
requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{
Type: api.IdentityProviderTypeGoogle,
Name: "Updated IDP",
ClientId: "updatedClientId",
ClientSecret: "updatedClientSecret",
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_IdentityProviders_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
idpId string
expectedStatus int
}{
{
name: "Delete existing identity provider",
idpId: "testIdpId",
expectedStatus: http.StatusInternalServerError,
},
{
name: "Delete non-existing identity provider",
idpId: "nonExistingIdpId",
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}

View File

@@ -1,183 +0,0 @@
//go:build integration
package integration
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// Test_Instance_GetStatus tests the unauthenticated GET /api/instance endpoint.
// This endpoint bypasses auth middleware. With nil idpManager (no embedded IDP),
// SetupRequired should be false.
func Test_Instance_GetStatus(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// The /api/instance endpoint is unauthenticated (bypass path).
// We still pass a token via BuildRequest but the bypass middleware skips auth.
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
assert.Equal(t, http.StatusOK, recorder.Code, "Expected 200 OK for instance status endpoint, got %d: %s", recorder.Code, string(content))
got := &api.InstanceStatus{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// With nil idpManager (no embedded IDP configured), setup is not required.
assert.Equal(t, false, got.SetupRequired, "Expected SetupRequired to be false when embedded IDP is not configured")
}
// Test_Instance_GetStatus_Unauthenticated verifies the endpoint works without any
// valid user token, since it is on the bypass path.
func Test_Instance_GetStatus_Unauthenticated(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// Use an invalid token to confirm the bypass middleware skips auth
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance", testing_tools.InvalidToken)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
assert.Equal(t, http.StatusOK, recorder.Code, "Expected 200 OK for unauthenticated instance status endpoint, got %d: %s", recorder.Code, string(content))
got := &api.InstanceStatus{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, false, got.SetupRequired)
}
// Test_Instance_GetVersionInfo tests the authenticated GET /api/instance/version endpoint.
func Test_Instance_GetVersionInfo(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get version info", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance/version", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.InstanceVersionInfo{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.NotEmpty(t, got.ManagementCurrentVersion, "Expected non-empty current version")
})
}
}
// Test_Instance_Setup tests the unauthenticated POST /api/setup endpoint.
// Since embedded IDP is not configured in the test environment, the setup
// endpoint should return an error (500 Internal Server Error) because the
// instance manager's CreateOwnerUser returns "embedded IDP is not enabled".
func Test_Instance_Setup(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(&api.SetupRequest{
Email: "admin@test.com",
Password: "securepassword123",
Name: "Admin User",
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
// The /api/setup endpoint is unauthenticated (bypass path).
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/setup", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
// Without embedded IDP, CreateOwnerUser returns a plain error (not a status error),
// which the handler maps to 500 Internal Server Error.
assert.Equal(t, http.StatusInternalServerError, recorder.Code,
"Expected 500 when embedded IDP is not configured, got %d: %s", recorder.Code, string(content))
}
// Test_Instance_Setup_Unauthenticated verifies the setup endpoint works (reaches
// the handler) even with an invalid token, since it is on the bypass path.
func Test_Instance_Setup_Unauthenticated(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(&api.SetupRequest{
Email: "admin@test.com",
Password: "securepassword123",
Name: "Admin User",
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/setup", testing_tools.InvalidToken)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
// Even with an invalid token, the bypass middleware skips auth and the handler runs.
// Without embedded IDP, it returns 500.
assert.Equal(t, http.StatusInternalServerError, recorder.Code,
"Expected 500 when embedded IDP is not configured, got %d: %s", recorder.Code, string(content))
}

View File

@@ -1,154 +0,0 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// Note: The integration test infrastructure does not configure an embedded IDP,
// so actual invite operations will return PreconditionFailed (412) for authorized users.
// These tests verify that the permissions layer correctly denies regular users
// before the handler logic is reached.
func Test_Invites_List(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - List invites", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users/invites", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}
func Test_Invites_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Create invite", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
body, err := json.Marshal(&api.UserInviteCreateRequest{
Email: "newuser@test.com",
Name: "New User",
Role: "user",
AutoGroups: []string{},
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/users/invites", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}
func Test_Invites_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Delete invite", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/invites/{inviteId}", "{inviteId}", "someInviteId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}
func Test_Invites_Regenerate(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Regenerate invite", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/invites/{inviteId}/regenerate", "{inviteId}", "someInviteId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}

View File

@@ -45,7 +45,7 @@ func Test_Peers_GetAll(t *testing.T) {
{
name: "Regular service user",
userId: testing_tools.TestServiceUserId,
expectResponse: false,
expectResponse: true,
},
{
name: "Admin service user",
@@ -123,7 +123,7 @@ func Test_Peers_GetById(t *testing.T) {
{
name: "Regular service user",
userId: testing_tools.TestServiceUserId,
expectResponse: false,
expectResponse: true,
},
{
name: "Admin service user",

View File

@@ -1,372 +0,0 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_PostureChecks_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all posture checks", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/posture-checks", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []*api.PostureCheck{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "testPostureCheckId", got[0].Id)
assert.Equal(t, "NetBird Version Check", got[0].Name)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_PostureChecks_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
postureCheckId string
expectedStatus int
expectCheck bool
}{
{
name: "Get existing posture check",
postureCheckId: "testPostureCheckId",
expectedStatus: http.StatusOK,
expectCheck: true,
},
{
name: "Get non-existing posture check",
postureCheckId: "nonExistingPostureCheckId",
expectedStatus: http.StatusNotFound,
expectCheck: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectCheck {
got := &api.PostureCheck{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testPostureCheckId", got.Id)
assert.Equal(t, "NetBird Version Check", got.Name)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_PostureChecks_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
minVersion := "0.32.0"
tt := []struct {
name string
requestBody *api.PostureCheckUpdate
expectedStatus int
verifyResponse func(t *testing.T, check *api.PostureCheck)
}{
{
name: "Create posture check with NB version",
requestBody: &api.PostureCheckUpdate{
Name: "New Version Check",
Description: "check for new version",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: minVersion,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, check *api.PostureCheck) {
t.Helper()
assert.NotEmpty(t, check.Id)
assert.Equal(t, "New Version Check", check.Name)
assert.NotNil(t, check.Checks.NbVersionCheck)
assert.Equal(t, minVersion, check.Checks.NbVersionCheck.MinVersion)
},
},
{
name: "Create posture check with empty name",
requestBody: &api.PostureCheckUpdate{
Name: "",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "0.32.0",
},
},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/posture-checks", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.PostureCheck{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
db := testing_tools.GetDB(t, am.GetStore())
dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbCheck.Name)
}
})
}
}
}
func Test_PostureChecks_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
minVersion := "0.33.0"
tt := []struct {
name string
postureCheckId string
requestBody *api.PostureCheckUpdate
expectedStatus int
verifyResponse func(t *testing.T, check *api.PostureCheck)
}{
{
name: "Update posture check name and version",
postureCheckId: "testPostureCheckId",
requestBody: &api.PostureCheckUpdate{
Name: "Updated Version Check",
Description: "updated description",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: minVersion,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, check *api.PostureCheck) {
t.Helper()
assert.Equal(t, "testPostureCheckId", check.Id)
assert.Equal(t, "Updated Version Check", check.Name)
},
},
{
name: "Update non-existing posture check",
postureCheckId: "nonExistingPostureCheckId",
requestBody: &api.PostureCheckUpdate{
Name: "whatever",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "0.33.0",
},
},
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.PostureCheck{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
db := testing_tools.GetDB(t, am.GetStore())
dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, tc.postureCheckId)
assert.Equal(t, "Updated Version Check", dbCheck.Name)
}
})
}
}
}
func Test_PostureChecks_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
postureCheckId string
expectedStatus int
}{
{
name: "Delete existing posture check",
postureCheckId: "testPostureCheckId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing posture check",
postureCheckId: "nonExistingPostureCheckId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyPostureCheckNotInDB(t, db, tc.postureCheckId)
}
})
}
}
}

View File

@@ -1,306 +0,0 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_ReverseProxy_GetClusters(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get clusters", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/clusters", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.ProxyCluster{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty clusters list when no proxy is connected")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_ReverseProxy_GetAllServices(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all services", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/services", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []*api.Service{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty services list with no services in DB")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_ReverseProxy_CreateService(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Create service", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// Creating a service requires a valid domain entry. Without one, authorized
// users will receive a validation error. This test verifies that the
// permissions layer correctly rejects unauthorized users before handler
// logic runs, and that authorized users reach the handler (even if the
// handler returns an error due to missing domain).
body, err := json.Marshal(&api.ServiceRequest{
Name: "test-service",
Domain: "nonexistent.example.com",
Enabled: true,
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/reverse-proxies/services", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users reach the handler but get a validation error (no valid domain).
// Unauthorized users are rejected by the permissions middleware.
testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, user.expectResponse)
})
}
}
func Test_ReverseProxy_GetServiceById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get service by ID (non-existing)", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/reverse-proxies/services/{serviceId}", "{serviceId}", "nonExistingServiceId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get NotFound for a non-existing service.
// Unauthorized users are rejected by the permissions middleware.
testing_tools.ReadResponse(t, recorder, http.StatusNotFound, user.expectResponse)
})
}
}
func Test_ReverseProxy_DeleteService(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Delete service (non-existing)", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/reverse-proxies/services/{serviceId}", "{serviceId}", "nonExistingServiceId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get NotFound for a non-existing service.
// Unauthorized users are rejected by the permissions middleware.
testing_tools.ReadResponse(t, recorder, http.StatusNotFound, user.expectResponse)
})
}
}
func Test_ReverseProxy_GetAllDomains(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all domains", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/domains", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.ReverseProxyDomain{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty domains list with no domains in DB")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_ReverseProxy_GetAccessLogs(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get proxy access logs", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/proxy", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.ProxyAccessLogsResponse{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got.Data), "Expected empty access logs data")
assert.Equal(t, 0, got.TotalRecords, "Expected zero total records")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}

View File

@@ -1,124 +0,0 @@
//go:build integration
package integration
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
)
func Test_Users_Approve(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
targetUserId string
expectedStatus int
}{
{
name: "Approve pending user",
targetUserId: "pendingUserId",
expectedStatus: http.StatusOK,
},
{
name: "Approve non-existing user",
targetUserId: "nonExistingUserId",
expectedStatus: http.StatusNotFound,
},
{
name: "Approve already active user",
targetUserId: testing_tools.TestUserId,
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/{userId}/approve", "{userId}", tc.targetUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_Users_Reject(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
targetUserId string
expectedStatus int
}{
{
name: "Reject pending user",
targetUserId: "pendingUserId",
expectedStatus: http.StatusOK,
},
{
name: "Reject non-existing user",
targetUserId: "nonExistingUserId",
expectedStatus: http.StatusNotFound,
},
{
name: "Reject non-pending user",
targetUserId: testing_tools.TestUserId,
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/{userId}/reject", "{userId}", tc.targetUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
_, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if expectResponse && tc.expectedStatus == http.StatusOK {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyUserNotInDB(t, db, tc.targetUserId)
}
})
}
}
}

View File

@@ -1,64 +0,0 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Users_GetCurrent(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, true},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, false},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get current user", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users/current", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.User{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, user.userId, got.Id)
assert.NotNil(t, got.IsCurrent)
assert.True(t, *got.IsCurrent)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}

View File

@@ -26,7 +26,7 @@ func Test_Users_GetAll(t *testing.T) {
{"Regular user", testing_tools.TestUserId, true},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
@@ -637,38 +637,6 @@ func Test_PATs_Create(t *testing.T) {
}
}
func Test_Users_Update_CrossAccountAttack(t *testing.T) {
t.Run("Admin attempts to update user from other account", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
body, _ := json.Marshal(&api.UserRequest{
Role: "user",
AutoGroups: []string{},
IsBlocked: true,
})
// TestAdminId belongs to testAccountId, but targets otherUserId which belongs to otherAccountId
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/users/otherUserId", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account user update must be rejected")
})
}
func Test_Users_Delete_CrossAccountAttack(t *testing.T) {
t.Run("Admin attempts to delete service user from other account", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
// TestAdminId belongs to testAccountId, but targets otherServiceUserId which belongs to otherAccountId
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/users/otherServiceUserId", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account user delete must be rejected")
})
}
func Test_PATs_Delete(t *testing.T) {
users := []struct {
name string

View File

@@ -15,4 +15,4 @@ INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NU
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);

Some files were not shown because too many files have changed in this diff Show More