mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-09 17:39:57 +00:00
Compare commits
37 Commits
github-iss
...
refactor/p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b10d74ab8 | ||
|
|
312bcf6398 | ||
|
|
381858a865 | ||
|
|
4051af499a | ||
|
|
d774e9a62a | ||
|
|
a4b55af99c | ||
|
|
470307079b | ||
|
|
0b04c0d03b | ||
|
|
bed8d89d9f | ||
|
|
b65a8bcb9c | ||
|
|
a53c38a6ed | ||
|
|
8a41117403 | ||
|
|
e0063731f2 | ||
|
|
a572d8819f | ||
|
|
d139a4fc42 | ||
|
|
dbe533327f | ||
|
|
4835909a04 | ||
|
|
fb6447e714 | ||
|
|
16d6cc1265 | ||
|
|
53e47da7bd | ||
|
|
ce522ea69b | ||
|
|
5be80e976f | ||
|
|
dd9539a9bf | ||
|
|
8530f9b8fc | ||
|
|
9406de9610 | ||
|
|
9e385eb540 | ||
|
|
e46ea895c1 | ||
|
|
31d901c4b0 | ||
|
|
20e6dff507 | ||
|
|
f5c8a6fe1a | ||
|
|
beee14b9bf | ||
|
|
3013c98ab5 | ||
|
|
3741eb46dd | ||
|
|
d85ee0b5a2 | ||
|
|
da4a0eb68a | ||
|
|
b0ce0048b4 | ||
|
|
32730af33f |
5
.github/issue-resolution/package.json
vendored
5
.github/issue-resolution/package.json
vendored
@@ -1,5 +0,0 @@
|
||||
{
|
||||
"name": "issue-resolution",
|
||||
"private": true,
|
||||
"type": "module"
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
You are a GitHub issue resolution classifier.
|
||||
|
||||
Your job is to decide whether an open GitHub issue is:
|
||||
- AUTO_CLOSE
|
||||
- MANUAL_REVIEW
|
||||
- KEEP_OPEN
|
||||
|
||||
Rules:
|
||||
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
|
||||
- a merged linked PR that clearly resolves the issue, or
|
||||
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
|
||||
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
|
||||
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
|
||||
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
|
||||
5. Do not invent evidence.
|
||||
6. Output valid JSON only.
|
||||
|
||||
Maintainer-authoritative roles:
|
||||
- MEMBER
|
||||
- OWNER
|
||||
- COLLABORATOR
|
||||
|
||||
Partial fixes and multi-item issues:
|
||||
- A merged PR that only addresses SOME items in a multi-item request does NOT resolve the issue. If the issue lists 5 feature requests and a PR fixes 1, the issue is still open.
|
||||
- If a PR description or comment says "partially addresses", "partial fix", or similar, the issue is NOT resolved. Classify as KEEP_OPEN.
|
||||
- If a merged PR addresses the core ask but a later comment objects or reports a regression, classify as MANUAL_REVIEW (not resolved).
|
||||
|
||||
Workarounds vs. actual fixes:
|
||||
- A WORKAROUND is when a user changes their own setup to avoid the problem (editing configs, using a different setting, manual SQL fixes, switching tools, scripts). Workarounds do NOT count as resolution — the underlying issue is still present in the product.
|
||||
- An ACTUAL FIX is when a user reports the problem went away after upgrading to a specific version (e.g., "fixed after updating to v0.65.1") or after a specific PR was merged. This suggests the fix was shipped in the product itself.
|
||||
- A maintainer pointing to an existing alternative feature is NOT the same as fixing the issue. If the reporter never confirmed the alternative works for them, classify as KEEP_OPEN.
|
||||
- If only workarounds exist and no maintainer has confirmed a fix, classify as KEEP_OPEN.
|
||||
- If a user reports an actual fix via a version upgrade but no maintainer confirmed it, classify as MANUAL_REVIEW (not AUTO_CLOSE).
|
||||
|
||||
Stale issues:
|
||||
- An issue with no activity for over 12 months, where a maintainer offered an alternative or asked for more info and the reporter never responded, is a candidate for MANUAL_REVIEW — not necessarily KEEP_OPEN.
|
||||
|
||||
Important:
|
||||
- Later comments outweigh earlier ones.
|
||||
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
|
||||
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.
|
||||
@@ -1,80 +0,0 @@
|
||||
{
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"decision",
|
||||
"reason_code",
|
||||
"confidence",
|
||||
"hard_signals",
|
||||
"contradictions",
|
||||
"summary",
|
||||
"close_comment",
|
||||
"manual_review_note"
|
||||
],
|
||||
"properties": {
|
||||
"decision": {
|
||||
"type": "string",
|
||||
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
|
||||
},
|
||||
"reason_code": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"resolved_by_merged_pr",
|
||||
"maintainer_confirmed_resolved",
|
||||
"duplicate_confirmed",
|
||||
"superseded_confirmed",
|
||||
"likely_fixed_but_unconfirmed",
|
||||
"still_open",
|
||||
"unclear"
|
||||
]
|
||||
},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1
|
||||
},
|
||||
"hard_signals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"required": ["type", "url"],
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"merged_pr",
|
||||
"maintainer_comment",
|
||||
"duplicate_reference",
|
||||
"superseded_reference"
|
||||
]
|
||||
},
|
||||
"url": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"contradictions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"required": ["type", "url"],
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"reporter_still_broken",
|
||||
"later_unresolved_comment",
|
||||
"ambiguous_pr_link",
|
||||
"other"
|
||||
]
|
||||
},
|
||||
"url": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"summary": { "type": "string" },
|
||||
"close_comment": { "type": "string" },
|
||||
"manual_review_note": { "type": "string" }
|
||||
}
|
||||
}
|
||||
213
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
213
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
@@ -1,213 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
|
||||
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
|
||||
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
|
||||
|
||||
const ghHeaders = {
|
||||
Authorization: `Bearer ${process.env.GH_TOKEN}`,
|
||||
Accept: "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
};
|
||||
|
||||
// Use PROJECT_PAT for project board operations, fall back to GH_TOKEN
|
||||
const projectHeaders = {
|
||||
Authorization: `Bearer ${process.env.PROJECT_PAT || process.env.GH_TOKEN}`,
|
||||
Accept: "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
};
|
||||
|
||||
async function rest(url, method = "GET", body) {
|
||||
const res = await fetch(url, {
|
||||
method,
|
||||
headers: ghHeaders,
|
||||
body: body ? JSON.stringify(body) : undefined
|
||||
});
|
||||
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
|
||||
return res.status === 204 ? null : res.json();
|
||||
}
|
||||
|
||||
async function graphql(query, variables) {
|
||||
const res = await fetch("https://api.github.com/graphql", {
|
||||
method: "POST",
|
||||
headers: projectHeaders,
|
||||
body: JSON.stringify({ query, variables })
|
||||
});
|
||||
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
|
||||
const json = await res.json();
|
||||
if (json.errors) throw new Error(JSON.stringify(json.errors));
|
||||
return json.data;
|
||||
}
|
||||
|
||||
async function addLabel(owner, repo, issueNumber, labels) {
|
||||
return rest(
|
||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
|
||||
"POST",
|
||||
{ labels }
|
||||
);
|
||||
}
|
||||
|
||||
async function addComment(owner, repo, issueNumber, body) {
|
||||
return rest(
|
||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
|
||||
"POST",
|
||||
{ body }
|
||||
);
|
||||
}
|
||||
|
||||
async function closeIssue(owner, repo, issueNumber) {
|
||||
return rest(
|
||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
|
||||
"PATCH",
|
||||
{ state: "closed", state_reason: "completed" }
|
||||
);
|
||||
}
|
||||
|
||||
async function getIssueNodeId(owner, repo, issueNumber) {
|
||||
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
|
||||
return issue.node_id;
|
||||
}
|
||||
|
||||
async function addToProject(issueNodeId) {
|
||||
const mutation = `
|
||||
mutation($projectId: ID!, $contentId: ID!) {
|
||||
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
|
||||
item { id }
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
try {
|
||||
const data = await graphql(mutation, {
|
||||
projectId: process.env.PROJECT_ID,
|
||||
contentId: issueNodeId
|
||||
});
|
||||
return data.addProjectV2ItemById.item.id;
|
||||
} catch (err) {
|
||||
console.warn(`[WARN] Could not add to project (needs PAT with project scope): ${err.message}`);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function setTextField(itemId, fieldId, value) {
|
||||
const mutation = `
|
||||
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
|
||||
updateProjectV2ItemFieldValue(input: {
|
||||
projectId: $projectId,
|
||||
itemId: $itemId,
|
||||
fieldId: $fieldId,
|
||||
value: { text: $value }
|
||||
}) {
|
||||
projectV2Item { id }
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
return graphql(mutation, {
|
||||
projectId: process.env.PROJECT_ID,
|
||||
itemId,
|
||||
fieldId,
|
||||
value
|
||||
});
|
||||
}
|
||||
|
||||
async function setNumberField(itemId, fieldId, value) {
|
||||
const mutation = `
|
||||
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: Float!) {
|
||||
updateProjectV2ItemFieldValue(input: {
|
||||
projectId: $projectId,
|
||||
itemId: $itemId,
|
||||
fieldId: $fieldId,
|
||||
value: { number: $value }
|
||||
}) {
|
||||
projectV2Item { id }
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
return graphql(mutation, {
|
||||
projectId: process.env.PROJECT_ID,
|
||||
itemId,
|
||||
fieldId,
|
||||
value
|
||||
});
|
||||
}
|
||||
|
||||
async function setSingleSelectField(itemId, fieldId, optionId) {
|
||||
const mutation = `
|
||||
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $optionId: String!) {
|
||||
updateProjectV2ItemFieldValue(input: {
|
||||
projectId: $projectId,
|
||||
itemId: $itemId,
|
||||
fieldId: $fieldId,
|
||||
value: { singleSelectOptionId: $optionId }
|
||||
}) {
|
||||
projectV2Item { id }
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
return graphql(mutation, {
|
||||
projectId: process.env.PROJECT_ID,
|
||||
itemId,
|
||||
fieldId,
|
||||
optionId
|
||||
});
|
||||
}
|
||||
|
||||
async function addToProjectWithFields(owner, repo, d) {
|
||||
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
|
||||
const itemId = await addToProject(issueNodeId);
|
||||
|
||||
if (itemId) {
|
||||
if (process.env.PROJECT_STATUS_FIELD_ID && process.env.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID) {
|
||||
await setSingleSelectField(itemId, process.env.PROJECT_STATUS_FIELD_ID, process.env.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID);
|
||||
}
|
||||
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
|
||||
await setNumberField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, d.model.confidence);
|
||||
}
|
||||
if (process.env.PROJECT_REASON_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
|
||||
}
|
||||
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
|
||||
}
|
||||
console.log(` → Added to project board (Status: Needs Review)`);
|
||||
}
|
||||
}
|
||||
|
||||
for (const d of decisions) {
|
||||
const [owner, repo] = d.repository.split("/");
|
||||
|
||||
if (d.final_decision === "KEEP_OPEN") {
|
||||
console.log(`#${d.issue_number} → KEEP_OPEN (confidence: ${d.model.confidence}, reason: ${d.model.reason_code})`);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (dryRun) {
|
||||
console.log(`[DRY RUN] #${d.issue_number} → ${d.final_decision} (confidence: ${d.model.confidence}, reason: ${d.model.reason_code})`);
|
||||
// In dry-run: populate project board but don't touch issues
|
||||
if (d.final_decision === "MANUAL_REVIEW" || d.final_decision === "AUTO_CLOSE") {
|
||||
await addToProjectWithFields(owner, repo, d);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (d.final_decision === "AUTO_CLOSE") {
|
||||
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
|
||||
await addComment(owner, repo, d.issue_number, d.model.close_comment);
|
||||
await closeIssue(owner, repo, d.issue_number);
|
||||
await addToProjectWithFields(owner, repo, d);
|
||||
}
|
||||
|
||||
if (d.final_decision === "MANUAL_REVIEW") {
|
||||
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
|
||||
await addToProjectWithFields(owner, repo, d);
|
||||
await addComment(
|
||||
owner,
|
||||
repo,
|
||||
d.issue_number,
|
||||
d.model.manual_review_note ||
|
||||
"This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue."
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,259 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
|
||||
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
|
||||
const systemPrompt = await fs.readFile("prompts/issue-resolution-system.txt", "utf8");
|
||||
const outputSchema = JSON.parse(await fs.readFile("schemas/issue-resolution-output.json", "utf8"));
|
||||
|
||||
function isMaintainerRole(role) {
|
||||
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
|
||||
}
|
||||
|
||||
function preScore(candidate) {
|
||||
let score = 0;
|
||||
const hardSignals = [];
|
||||
const contradictions = [];
|
||||
|
||||
for (const t of candidate.timeline) {
|
||||
const sourceIssue = t.source?.issue;
|
||||
|
||||
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
|
||||
hardSignals.push({
|
||||
type: "merged_pr",
|
||||
url: sourceIssue.html_url
|
||||
});
|
||||
score += 40; // provisional until PR merged state is verified
|
||||
}
|
||||
|
||||
if (["referenced", "connected"].includes(t.event)) {
|
||||
score += 10;
|
||||
}
|
||||
}
|
||||
|
||||
for (const c of candidate.comments) {
|
||||
const body = c.body.toLowerCase();
|
||||
|
||||
if (
|
||||
isMaintainerRole(c.author_association) &&
|
||||
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
|
||||
) {
|
||||
score += 25;
|
||||
hardSignals.push({
|
||||
type: "maintainer_comment",
|
||||
url: c.html_url
|
||||
});
|
||||
}
|
||||
|
||||
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
|
||||
score -= 50;
|
||||
contradictions.push({
|
||||
type: "later_unresolved_comment",
|
||||
url: c.html_url
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return { score, hardSignals, contradictions };
|
||||
}
|
||||
|
||||
// GitHub Models gpt-4o has an 8000 token input limit.
|
||||
// Reserve ~2000 tokens for system prompt + response overhead.
|
||||
// 1 token ~= 4 chars, so cap user message at ~24000 chars.
|
||||
const MAX_USER_MESSAGE_CHARS = 24000;
|
||||
|
||||
function truncate(text, maxChars) {
|
||||
if (text.length <= maxChars) return text;
|
||||
return text.slice(0, maxChars) + "\n\n[... truncated due to length]";
|
||||
}
|
||||
|
||||
function buildUserMessage(candidate, pre) {
|
||||
const { issue, comments, timeline } = candidate;
|
||||
|
||||
const commentBlock = comments
|
||||
.map((c) => `[${c.author_association}] ${c.user} (${c.created_at}):\n${c.body}`)
|
||||
.join("\n---\n");
|
||||
|
||||
const timelineBlock = timeline
|
||||
.filter((t) => ["cross-referenced", "referenced", "connected", "closed", "reopened"].includes(t.event))
|
||||
.map((t) => {
|
||||
let line = `${t.event} (${t.created_at})`;
|
||||
if (t.source?.issue?.html_url) line += ` — ${t.source.issue.html_url}`;
|
||||
if (t.source?.issue?.pull_request?.html_url) line += ` (PR: ${t.source.issue.pull_request.html_url})`;
|
||||
return line;
|
||||
})
|
||||
.join("\n");
|
||||
|
||||
const sections = [
|
||||
`## Issue #${issue.number}: ${issue.title}`,
|
||||
`URL: ${issue.html_url}`,
|
||||
`Created: ${issue.created_at} | Updated: ${issue.updated_at}`,
|
||||
`Labels: ${issue.labels.join(", ") || "none"}`,
|
||||
"",
|
||||
"### Body",
|
||||
truncate(issue.body || "(empty)", 4000),
|
||||
"",
|
||||
"### Comments",
|
||||
commentBlock || "(none)",
|
||||
"",
|
||||
"### Timeline events",
|
||||
timelineBlock || "(none)",
|
||||
];
|
||||
|
||||
if (candidate.linked_prs?.length) {
|
||||
sections.push("");
|
||||
sections.push("### Linked PRs (verified state)");
|
||||
for (const pr of candidate.linked_prs) {
|
||||
const status = pr.merged ? `MERGED (${pr.merged_at})` : pr.state.toUpperCase();
|
||||
sections.push(`- PR #${pr.number}: ${pr.title} — ${status} — ${pr.url}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (pre.hardSignals.length || pre.contradictions.length) {
|
||||
sections.push("");
|
||||
sections.push("### Automated evidence scan");
|
||||
for (const s of pre.hardSignals) {
|
||||
sections.push(`- SIGNAL: ${s.type} — ${s.url}`);
|
||||
}
|
||||
for (const c of pre.contradictions) {
|
||||
sections.push(`- CONTRADICTION: ${c.type} — ${c.url}`);
|
||||
}
|
||||
}
|
||||
|
||||
return truncate(sections.join("\n"), MAX_USER_MESSAGE_CHARS);
|
||||
}
|
||||
|
||||
const MODEL = "gpt-4o-mini";
|
||||
const MAX_RETRIES = 5;
|
||||
|
||||
function sleep(ms) {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
async function callGitHubModel(candidate, pre) {
|
||||
const body = JSON.stringify({
|
||||
model: MODEL,
|
||||
messages: [
|
||||
{ role: "system", content: systemPrompt },
|
||||
{ role: "user", content: buildUserMessage(candidate, pre) },
|
||||
],
|
||||
response_format: {
|
||||
type: "json_schema",
|
||||
json_schema: {
|
||||
name: "issue_resolution",
|
||||
strict: true,
|
||||
schema: outputSchema,
|
||||
},
|
||||
},
|
||||
temperature: 0.1,
|
||||
});
|
||||
|
||||
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
|
||||
const res = await fetch("https://models.inference.ai.azure.com/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${process.env.GH_TOKEN}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body,
|
||||
});
|
||||
|
||||
if (res.status === 429) {
|
||||
const retryAfter = Number(res.headers.get("retry-after")) || 30;
|
||||
if (retryAfter > 120) {
|
||||
console.warn(` [QUOTA EXHAUSTED] API wants ${retryAfter}s wait — skipping remaining issues.`);
|
||||
return null;
|
||||
}
|
||||
console.warn(` [RATE LIMITED] Waiting ${retryAfter}s (attempt ${attempt + 1}/${MAX_RETRIES})...`);
|
||||
await sleep(retryAfter * 1000);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`GitHub Models ${res.status}: ${text}`);
|
||||
}
|
||||
|
||||
const data = await res.json();
|
||||
return JSON.parse(data.choices[0].message.content);
|
||||
}
|
||||
|
||||
throw new Error(`GitHub Models: exceeded ${MAX_RETRIES} retries due to rate limiting`);
|
||||
}
|
||||
|
||||
function enforcePolicy(modelOut, pre) {
|
||||
const approvedReasons = new Set([
|
||||
"resolved_by_merged_pr",
|
||||
"maintainer_confirmed_resolved",
|
||||
"duplicate_confirmed",
|
||||
"superseded_confirmed"
|
||||
]);
|
||||
|
||||
const hasHardSignal =
|
||||
(modelOut.hard_signals || []).some(s =>
|
||||
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
|
||||
) || pre.hardSignals.length > 0;
|
||||
|
||||
const hasContradiction =
|
||||
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
|
||||
|
||||
// Only auto-close with very strict criteria
|
||||
if (
|
||||
modelOut.decision === "AUTO_CLOSE" &&
|
||||
modelOut.confidence >= 0.97 &&
|
||||
approvedReasons.has(modelOut.reason_code) &&
|
||||
hasHardSignal &&
|
||||
!hasContradiction
|
||||
) {
|
||||
return "AUTO_CLOSE";
|
||||
}
|
||||
|
||||
// Downgrade AUTO_CLOSE that didn't pass the gate
|
||||
if (modelOut.decision === "AUTO_CLOSE") {
|
||||
return "MANUAL_REVIEW";
|
||||
}
|
||||
|
||||
// Otherwise trust the model
|
||||
return modelOut.decision;
|
||||
}
|
||||
|
||||
console.log(`Classifying ${candidates.length} candidates with ${MODEL}...\n`);
|
||||
|
||||
// 15 req/min limit → 1 request every 4s. Use 4.5s for safety margin.
|
||||
const PACE_MS = 4500;
|
||||
let lastRequestTime = 0;
|
||||
|
||||
async function paced(fn) {
|
||||
const elapsed = Date.now() - lastRequestTime;
|
||||
if (elapsed < PACE_MS) await sleep(PACE_MS - elapsed);
|
||||
lastRequestTime = Date.now();
|
||||
return fn();
|
||||
}
|
||||
|
||||
const decisions = [];
|
||||
for (const candidate of candidates) {
|
||||
const pre = preScore(candidate);
|
||||
const modelOut = await paced(() => callGitHubModel(candidate, pre));
|
||||
|
||||
if (modelOut === null) {
|
||||
console.warn(`\nQuota exhausted after ${decisions.length} issues. Writing partial results.`);
|
||||
break;
|
||||
}
|
||||
|
||||
const finalDecision = enforcePolicy(modelOut, pre);
|
||||
|
||||
decisions.push({
|
||||
repository: candidate.repository,
|
||||
issue_number: candidate.issue.number,
|
||||
issue_url: candidate.issue.html_url,
|
||||
title: candidate.issue.title,
|
||||
pre_score: pre.score,
|
||||
final_decision: finalDecision,
|
||||
model: modelOut
|
||||
});
|
||||
|
||||
console.log(
|
||||
`#${candidate.issue.number} | pre_score: ${pre.score} | model: ${modelOut.decision} @ ${modelOut.confidence} | final: ${finalDecision} | ${modelOut.reason_code}`
|
||||
);
|
||||
}
|
||||
|
||||
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));
|
||||
console.log(`\nWrote ${decisions.length} decisions to decisions.json`);
|
||||
@@ -1,123 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
|
||||
const token = process.env.GH_TOKEN;
|
||||
const repo = process.env.REPO; // "owner/repo"
|
||||
const maxIssues = Number(process.env.MAX_ISSUES) || 100;
|
||||
|
||||
const headers = {
|
||||
Authorization: `Bearer ${token}`,
|
||||
Accept: "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
};
|
||||
|
||||
async function rest(url) {
|
||||
const res = await fetch(url, { headers });
|
||||
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
|
||||
return res.json();
|
||||
}
|
||||
|
||||
async function restSafe(url) {
|
||||
const res = await fetch(url, { headers });
|
||||
if (!res.ok) return null;
|
||||
return res.json();
|
||||
}
|
||||
|
||||
async function paginate(url, max) {
|
||||
const items = [];
|
||||
let page = 1;
|
||||
while (items.length < max) {
|
||||
const perPage = Math.min(100, max - items.length);
|
||||
const sep = url.includes("?") ? "&" : "?";
|
||||
const batch = await rest(`${url}${sep}per_page=${perPage}&page=${page}`);
|
||||
if (!batch.length) break;
|
||||
items.push(...batch);
|
||||
page++;
|
||||
}
|
||||
return items.slice(0, max);
|
||||
}
|
||||
|
||||
console.log(`Fetching up to ${maxIssues} open issues from ${repo}...`);
|
||||
|
||||
const issues = await paginate(
|
||||
`https://api.github.com/repos/${repo}/issues?state=open&sort=updated&direction=asc`,
|
||||
maxIssues
|
||||
);
|
||||
|
||||
// Filter out pull requests (GitHub API returns PRs as issues too)
|
||||
const realIssues = issues.filter((i) => !i.pull_request);
|
||||
console.log(`Found ${realIssues.length} open issues (excluded PRs).`);
|
||||
|
||||
const candidates = [];
|
||||
for (const issue of realIssues) {
|
||||
const [comments, timeline] = await Promise.all([
|
||||
rest(`https://api.github.com/repos/${repo}/issues/${issue.number}/comments?per_page=100`),
|
||||
rest(`https://api.github.com/repos/${repo}/issues/${issue.number}/timeline?per_page=100`),
|
||||
]);
|
||||
|
||||
candidates.push({
|
||||
repository: repo,
|
||||
issue: {
|
||||
number: issue.number,
|
||||
html_url: issue.html_url,
|
||||
title: issue.title,
|
||||
body: issue.body,
|
||||
created_at: issue.created_at,
|
||||
updated_at: issue.updated_at,
|
||||
labels: issue.labels.map((l) => l.name),
|
||||
},
|
||||
comments: comments.map((c) => ({
|
||||
body: c.body,
|
||||
author_association: c.author_association,
|
||||
html_url: c.html_url,
|
||||
created_at: c.created_at,
|
||||
user: c.user?.login,
|
||||
})),
|
||||
timeline: timeline.map((t) => ({
|
||||
event: t.event,
|
||||
created_at: t.created_at,
|
||||
source: t.source
|
||||
? {
|
||||
issue: {
|
||||
html_url: t.source.issue?.html_url,
|
||||
pull_request: t.source.issue?.pull_request
|
||||
? { html_url: t.source.issue.pull_request.html_url }
|
||||
: undefined,
|
||||
},
|
||||
}
|
||||
: undefined,
|
||||
})),
|
||||
linked_prs: [],
|
||||
});
|
||||
|
||||
// Fetch merge status for cross-referenced PRs
|
||||
const prUrls = new Set();
|
||||
for (const t of timeline) {
|
||||
const prHtml = t.source?.issue?.pull_request?.html_url;
|
||||
if (t.event === "cross-referenced" && prHtml) {
|
||||
prUrls.add(prHtml);
|
||||
}
|
||||
}
|
||||
|
||||
const candidate = candidates[candidates.length - 1];
|
||||
for (const prHtml of prUrls) {
|
||||
// Extract owner/repo and PR number from URL like https://github.com/owner/repo/pull/123
|
||||
const match = prHtml.match(/github\.com\/([^/]+\/[^/]+)\/pull\/(\d+)/);
|
||||
if (!match) continue;
|
||||
const [, prRepo, prNum] = match;
|
||||
const pr = await restSafe(`https://api.github.com/repos/${prRepo}/pulls/${prNum}`);
|
||||
if (!pr) continue;
|
||||
candidate.linked_prs.push({
|
||||
number: pr.number,
|
||||
title: pr.title,
|
||||
url: prHtml,
|
||||
state: pr.state,
|
||||
merged: pr.merged || false,
|
||||
merged_at: pr.merged_at,
|
||||
});
|
||||
}
|
||||
|
||||
console.log(` #${issue.number} — ${comments.length} comments, ${timeline.length} timeline events, ${candidate.linked_prs.length} linked PRs`);
|
||||
}
|
||||
|
||||
await fs.writeFile("candidates.json", JSON.stringify(candidates, null, 2));
|
||||
console.log(`Wrote ${candidates.length} candidates to candidates.json`);
|
||||
63
.github/workflows/issue-resolution-triage.yml
vendored
63
.github/workflows/issue-resolution-triage.yml
vendored
@@ -1,63 +0,0 @@
|
||||
name: issue-resolution-triage
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [github-issue-resolver]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dry_run:
|
||||
description: "If true, do not close issues"
|
||||
required: false
|
||||
default: "true"
|
||||
max_issues:
|
||||
description: "How many issues to process"
|
||||
required: false
|
||||
default: "100"
|
||||
schedule:
|
||||
- cron: "17 2 * * *"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: read
|
||||
models: read
|
||||
|
||||
# todo: remove hardcoded values
|
||||
jobs:
|
||||
triage:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PROJECT_PAT: ${{ secrets.PROJECT_PAT }}
|
||||
DRY_RUN: "true"
|
||||
MAX_ISSUES: "100"
|
||||
REPO: ${{ github.repository }}
|
||||
PROJECT_ID: "PVT_kwDOBfz4Jc4BVeWR"
|
||||
PROJECT_STATUS_FIELD_ID: "PVTSSF_lADOBfz4Jc4BVeWRzhQ56sU"
|
||||
PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: "a55a2be9"
|
||||
PROJECT_CONFIDENCE_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ57x4"
|
||||
PROJECT_REASON_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ5-Lg"
|
||||
PROJECT_EVIDENCE_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ5-Pw"
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: .github/issue-resolution
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
|
||||
- run: node scripts/fetch-candidates.mjs
|
||||
- run: node scripts/classify-candidates.mjs
|
||||
- run: node scripts/apply-decisions.mjs
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: triage-results
|
||||
path: |
|
||||
.github/issue-resolution/candidates.json
|
||||
.github/issue-resolution/decisions.json
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.4"
|
||||
SIGN_PIPE_VER: "v0.1.2"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
@@ -29,7 +31,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -97,7 +98,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
||||
peersmanager := peers.NewManager(store)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||
|
||||
@@ -333,10 +333,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.statusRecorder.MarkSignalConnected()
|
||||
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
||||
relayURLs = override
|
||||
}
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
|
||||
@@ -944,12 +944,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
|
||||
urls := update.Urls
|
||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
||||
urls = override
|
||||
}
|
||||
e.relayManager.UpdateServerURLs(urls)
|
||||
e.relayManager.UpdateServerURLs(update.Urls)
|
||||
|
||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
@@ -57,7 +58,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -1632,7 +1632,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
peersManager := peers.NewManager(store)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
|
||||
@@ -7,8 +7,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func IsForceRelayed() bool {
|
||||
@@ -17,28 +16,3 @@ func IsForceRelayed() bool {
|
||||
}
|
||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||
}
|
||||
|
||||
// OverrideRelayURLs returns the relay server URL list set in
|
||||
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
|
||||
// the override is active. When the env var is unset, the boolean is false
|
||||
// and the caller should keep the list received from the management server.
|
||||
// Intended for lab/debug scenarios where a peer must pin to a specific home
|
||||
// relay regardless of what management offers.
|
||||
func OverrideRelayURLs() ([]string, bool) {
|
||||
raw := os.Getenv(EnvKeyNBHomeRelayServers)
|
||||
if raw == "" {
|
||||
return nil, false
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
urls := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
urls = append(urls, p)
|
||||
}
|
||||
}
|
||||
if len(urls) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return urls, true
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
@@ -38,7 +40,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -305,7 +306,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
peersManager := peers.NewManager(store)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -71,7 +71,7 @@ require (
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/oapi-codegen/runtime v1.1.2
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -453,8 +453,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34 h1:g74mB64wnjCagzE1spKgPfTI/ont1SdSL3uX5bOecgM=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34/go.mod h1:lCOq5d1i19AQjEEW2d7aNK0Nn0KC0MKyfMz/PLwVBFg=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
|
||||
@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
|
||||
// are stored with types that Dex can open.
|
||||
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
|
||||
switch connType {
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||
return "oidc", applyOIDCDefaults(connType, config)
|
||||
default:
|
||||
return connType, config
|
||||
@@ -218,8 +218,6 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
|
||||
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
|
||||
case "okta", "pocketid":
|
||||
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "adfs":
|
||||
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
||||
}
|
||||
|
||||
return augmented
|
||||
|
||||
@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
|
||||
var err error
|
||||
|
||||
switch cfg.Type {
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||
dexType = "oidc"
|
||||
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
|
||||
case "google":
|
||||
@@ -220,8 +220,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "adfs":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
||||
}
|
||||
return encodeConnectorConfig(oidcConfig)
|
||||
}
|
||||
@@ -285,7 +283,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
|
||||
// inferOIDCProviderType infers the specific OIDC provider from connector ID
|
||||
func inferOIDCProviderType(connectorID string) string {
|
||||
connectorIDLower := strings.ToLower(connectorID)
|
||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} {
|
||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
|
||||
if strings.Contains(connectorIDLower, provider) {
|
||||
return provider
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -15,9 +16,11 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -55,6 +58,13 @@ type Controller struct {
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
holder *types.Holder
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
|
||||
compactedNetworkMap bool
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -71,6 +81,29 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||
}
|
||||
|
||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
compactedNetworkMap := true
|
||||
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
|
||||
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
|
||||
if err != nil && len(compactedEnv) > 0 {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
|
||||
}
|
||||
if err == nil && !parsedCompactedNmap {
|
||||
log.WithContext(ctx).Info("disabling compacted mode")
|
||||
compactedNetworkMap = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
expIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
return &Controller{
|
||||
repo: newRepository(store),
|
||||
metrics: nMetrics,
|
||||
@@ -84,6 +117,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
|
||||
holder: types.NewHolder(),
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
|
||||
compactedNetworkMap: compactedNetworkMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,9 +153,17 @@ func (c *Controller) CountStreams() int {
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
@@ -150,6 +197,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
@@ -192,7 +243,16 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountID):
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
@@ -258,6 +318,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return fmt.Errorf("recalculate network map cache: %v", err)
|
||||
}
|
||||
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -307,7 +371,16 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountId):
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -378,9 +451,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return peer, emptyMap, nil, 0, nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
@@ -412,10 +493,20 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -427,6 +518,108 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||
}
|
||||
|
||||
func (c *Controller) getPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
Network: &types.Network{},
|
||||
}
|
||||
}
|
||||
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
c.enrichAccountFromHolder(account)
|
||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||
account := c.getAccountFromHolder(accountId)
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
account.UpdatePeerInNetworkMapCache(peer)
|
||||
}
|
||||
|
||||
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
account.RecalculateNetworkMapCache(validatedPeers)
|
||||
c.updateAccountInHolder(account)
|
||||
}
|
||||
|
||||
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||
return err
|
||||
}
|
||||
c.recalculateNetworkMapCache(account, validatedPeers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
||||
_, ok := c.expNewNetworkMapAIDs[accountId]
|
||||
return c.expNewNetworkMap || ok
|
||||
}
|
||||
|
||||
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
||||
a := c.holder.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
c.holder.AddAccount(account)
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache = a.NetworkMapCache
|
||||
if account.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
||||
return c.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
|
||||
a := c.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
}
|
||||
|
||||
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
||||
if settings == nil {
|
||||
@@ -563,7 +756,16 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peers by ids: %w", err)
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
||||
}
|
||||
|
||||
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
}
|
||||
@@ -573,6 +775,14 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
|
||||
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
|
||||
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
|
||||
}
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -607,6 +817,19 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||
continue
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
@@ -649,11 +872,21 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||
} else {
|
||||
account.InjectProxyPolicies(ctx)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||
|
||||
DnsForwarderPort = nbdns.ForwarderServerPort
|
||||
OldForwarderPort = nbdns.ForwarderClientPort
|
||||
DnsForwarderPortMinVersion = "v0.59.0"
|
||||
|
||||
@@ -16,9 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -38,17 +35,15 @@ type Manager interface {
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
accountManager account.Manager
|
||||
|
||||
networkMapController network_map.Controller
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
|
||||
func NewManager(store store.Store) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,28 +60,10 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
|
||||
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
}
|
||||
|
||||
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
}
|
||||
|
||||
|
||||
@@ -4,20 +4,29 @@ package permissions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// AuthErrorHandler is called when an auth error occurs during permission validation.
|
||||
// If it returns true, the error is considered handled and the default error response is skipped.
|
||||
type AuthErrorHandler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool
|
||||
|
||||
type Manager interface {
|
||||
WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth), authErrHandler ...AuthErrorHandler) http.HandlerFunc
|
||||
ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error)
|
||||
ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool
|
||||
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
|
||||
@@ -36,6 +45,51 @@ func NewManager(store store.Store) Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPermission wraps an HTTP handler with permission checking logic.
|
||||
// An optional AuthErrorHandler can be provided to intercept auth errors before the default response is written.
|
||||
func (m *managerImpl) WithPermission(
|
||||
module modules.Module,
|
||||
operation operations.Operation,
|
||||
handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth),
|
||||
authErrHandler ...AuthErrorHandler,
|
||||
) http.HandlerFunc {
|
||||
var onAuthErr AuthErrorHandler
|
||||
if len(authErrHandler) > 0 {
|
||||
onAuthErr = authErrHandler[0]
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allowed, err := m.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, module, operation)
|
||||
if err != nil {
|
||||
if onAuthErr != nil && onAuthErr(w, r, &userAuth, err) {
|
||||
return
|
||||
}
|
||||
log.WithContext(r.Context()).Errorf("failed to validate permissions for user %s on account %s: %v", userAuth.UserId, userAuth.AccountId, err)
|
||||
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
permErr := status.NewPermissionDeniedError()
|
||||
if onAuthErr != nil && onAuthErr(w, r, &userAuth, permErr) {
|
||||
return
|
||||
}
|
||||
log.WithContext(r.Context()).Tracef("user %s on account %s is not allowed to %s in %s", userAuth.UserId, userAuth.AccountId, operation, module)
|
||||
util.WriteError(r.Context(), permErr, w)
|
||||
return
|
||||
}
|
||||
|
||||
handlerFunc(w, r, &userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) ValidateUserPermissions(
|
||||
ctx context.Context,
|
||||
accountID string,
|
||||
@@ -68,10 +122,6 @@ func (m *managerImpl) ValidateUserPermissions(
|
||||
return false, err
|
||||
}
|
||||
|
||||
if operation == operations.Read && user.IsServiceUser {
|
||||
return true, nil // this should be replaced by proper granular access role
|
||||
}
|
||||
|
||||
role, ok := roles.RolesMap[user.Role]
|
||||
if !ok {
|
||||
return false, status.NewUserRoleNotFoundError(string(user.Role))
|
||||
@@ -127,3 +177,17 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR
|
||||
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// WrapHandler wraps a handler that expects UserAuth with context extraction.
|
||||
// Unlike WithPermission, it does not perform any permission checks.
|
||||
func WrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
h(w, r, &userAuth)
|
||||
}
|
||||
}
|
||||
@@ -5,15 +5,18 @@
|
||||
package permissions
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
"context"
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
account "github.com/netbirdio/netbird/management/server/account"
|
||||
modules "github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
types "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
@@ -108,3 +111,22 @@ func (mr *MockManagerMockRecorder) ValidateUserPermissions(ctx, accountID, userI
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserPermissions", reflect.TypeOf((*MockManager)(nil).ValidateUserPermissions), ctx, accountID, userID, module, operation)
|
||||
}
|
||||
|
||||
// WithPermission mocks base method.
|
||||
func (m *MockManager) WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(http.ResponseWriter, *http.Request, *auth.UserAuth), authErrHandler ...AuthErrorHandler) http.HandlerFunc {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{module, operation, handlerFunc}
|
||||
for _, a := range authErrHandler {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "WithPermission", varargs...)
|
||||
ret0, _ := ret[0].(http.HandlerFunc)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// WithPermission indicates an expected call of WithPermission.
|
||||
func (mr *MockManagerMockRecorder) WithPermission(module, operation, handlerFunc interface{}, authErrHandler ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{module, operation, handlerFunc}, authErrHandler...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithPermission", reflect.TypeOf((*MockManager)(nil).WithPermission), varargs...)
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package roles
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ var Admin = RolePermissions{
|
||||
modules.Accounts: {
|
||||
operations.Read: true,
|
||||
operations.Create: false,
|
||||
operations.Update: false,
|
||||
operations.Update: true,
|
||||
operations.Delete: false,
|
||||
},
|
||||
},
|
||||
@@ -1,7 +1,7 @@
|
||||
package roles
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package roles
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package roles
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package roles
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package roles
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -15,21 +18,15 @@ type handler struct {
|
||||
manager accesslogs.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var filter accesslogs.AccessLogFilter
|
||||
filter.ParseFromRequest(r)
|
||||
|
||||
|
||||
@@ -9,25 +9,19 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
geo geolocation.Geolocation
|
||||
cleanupCancel context.CancelFunc
|
||||
store store.Store
|
||||
geo geolocation.Geolocation
|
||||
cleanupCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
|
||||
func NewManager(store store.Store, geo geolocation.Geolocation) accesslogs.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
geo: geo,
|
||||
store: store,
|
||||
geo: geo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,14 +57,6 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
|
||||
|
||||
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
|
||||
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, 0, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, 0, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,11 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -17,15 +20,15 @@ type handler struct {
|
||||
manager Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Create, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS") // TODO: this should be a POST
|
||||
}
|
||||
|
||||
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
|
||||
@@ -56,13 +59,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||
return resp
|
||||
}
|
||||
|
||||
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -77,13 +74,7 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, ret)
|
||||
}
|
||||
|
||||
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiReverseProxiesDomainsJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -99,13 +90,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
|
||||
}
|
||||
|
||||
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
domainID := mux.Vars(r)["domainId"]
|
||||
if domainID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||
@@ -120,13 +105,7 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
domainID := mux.Vars(r)["domainId"]
|
||||
if domainID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||
|
||||
@@ -11,11 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type store interface {
|
||||
@@ -37,32 +33,22 @@ type proxyManager interface {
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyManager proxyManager
|
||||
permissionsManager permissions.Manager
|
||||
accountManager account.Manager
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyManager proxyManager
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
|
||||
func NewManager(store store, proxyMgr proxyManager, accountManager account.Manager) Manager {
|
||||
return Manager{
|
||||
store: store,
|
||||
proxyManager: proxyMgr,
|
||||
validator: domain.Validator{Resolver: net.DefaultResolver},
|
||||
permissionsManager: permissionsManager,
|
||||
accountManager: accountManager,
|
||||
store: store,
|
||||
proxyManager: proxyMgr,
|
||||
validator: domain.Validator{Resolver: net.DefaultResolver},
|
||||
accountManager: accountManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
domains, err := m.store.ListCustomDomains(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list custom domains: %w", err)
|
||||
@@ -118,14 +104,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
}
|
||||
|
||||
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
@@ -159,14 +137,6 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
}
|
||||
|
||||
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get domain from store: %w", err)
|
||||
@@ -183,21 +153,6 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
|
||||
}
|
||||
|
||||
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"domainID": domainID,
|
||||
}).WithError(err).Error("validate domain")
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
log.WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"domainID": domainID,
|
||||
}).WithError(err).Error("validate domain")
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"domainID": domainID,
|
||||
|
||||
@@ -23,7 +23,7 @@ type Proxy struct {
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Capabilities Capabilities `gorm:"embedded"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -6,12 +6,14 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -30,25 +32,19 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
|
||||
}
|
||||
|
||||
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
|
||||
domainmanager.RegisterEndpoints(domainRouter, domainManager)
|
||||
domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager)
|
||||
|
||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager)
|
||||
|
||||
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/clusters", permissionsManager.WithPermission(modules.Services, operations.Read, h.getClusters)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -63,13 +59,7 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, apiServices)
|
||||
}
|
||||
|
||||
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.ServiceRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -77,12 +67,13 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
service := new(rpservice.Service)
|
||||
var err error
|
||||
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = service.Validate(); err != nil {
|
||||
if err := service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -96,13 +87,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
serviceID := mux.Vars(r)["serviceId"]
|
||||
if serviceID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||
@@ -118,13 +103,7 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
serviceID := mux.Vars(r)["serviceId"]
|
||||
if serviceID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||
@@ -139,12 +118,13 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
service := new(rpservice.Service)
|
||||
service.ID = serviceID
|
||||
var err error
|
||||
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = service.Validate(); err != nil {
|
||||
if err := service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -158,13 +138,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
serviceID := mux.Vars(r)["serviceId"]
|
||||
if serviceID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||
@@ -179,13 +153,7 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
@@ -86,18 +85,17 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
|
||||
mgr := &Manager{
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
permissionsManager: permissions.NewManager(testStore),
|
||||
proxyController: mockCtrl,
|
||||
capabilities: mockCaps,
|
||||
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
proxyController: mockCtrl,
|
||||
capabilities: mockCaps,
|
||||
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
|
||||
|
||||
@@ -21,9 +21,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -82,24 +79,22 @@ type CapabilityProvider interface {
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyController proxy.Controller
|
||||
capabilities CapabilityProvider
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeReaper *exposeReaper
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
proxyController proxy.Controller
|
||||
capabilities CapabilityProvider
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeReaper *exposeReaper
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
|
||||
func NewManager(store store.Store, accountManager account.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
|
||||
mgr := &Manager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyController: proxyController,
|
||||
capabilities: capabilities,
|
||||
clusterDeriver: clusterDeriver,
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
proxyController: proxyController,
|
||||
capabilities: capabilities,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
return mgr
|
||||
@@ -112,26 +107,10 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetActiveProxyClusters(ctx)
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
@@ -185,14 +164,6 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
||||
}
|
||||
|
||||
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
@@ -206,14 +177,6 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
|
||||
}
|
||||
|
||||
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -224,7 +187,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
err := m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
@@ -491,14 +454,6 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
@@ -785,16 +740,8 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes.
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var s *service.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
@@ -825,16 +772,8 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var services []*service.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
@@ -1119,7 +1058,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
|
||||
}
|
||||
groupIDs := make([]string, 0, len(groupNames))
|
||||
for _, groupName := range groupNames {
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
|
||||
}
|
||||
|
||||
@@ -23,9 +23,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -700,12 +697,10 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
|
||||
require.NoError(t, err)
|
||||
|
||||
permsMgr := permissions.NewManager(testStore)
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
@@ -718,10 +713,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
require.NoError(t, err)
|
||||
|
||||
mgr := &Manager{
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
permissionsManager: permsMgr,
|
||||
proxyController: proxyController,
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
proxyController: proxyController,
|
||||
clusterDeriver: &testClusterDeriver{
|
||||
domains: []string{"test.netbird.io"},
|
||||
},
|
||||
@@ -1130,7 +1124,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPerms := permissions.NewMockManager(ctrl)
|
||||
mockAcct := account.NewMockManager(ctrl)
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
@@ -1141,10 +1134,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
mgr := &Manager{
|
||||
store: sqlStore,
|
||||
permissionsManager: mockPerms,
|
||||
accountManager: mockAcct,
|
||||
proxyController: proxyController,
|
||||
store: sqlStore,
|
||||
accountManager: mockAcct,
|
||||
proxyController: proxyController,
|
||||
}
|
||||
|
||||
service := &rpservice.Service{
|
||||
@@ -1167,9 +1159,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
|
||||
|
||||
mockPerms.EXPECT().
|
||||
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
|
||||
Return(true, nil)
|
||||
mockAcct.EXPECT().
|
||||
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
|
||||
mockAcct.EXPECT().
|
||||
|
||||
@@ -6,8 +6,11 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -17,25 +20,19 @@ type handler struct {
|
||||
manager zones.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -50,13 +47,7 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, apiZones)
|
||||
}
|
||||
|
||||
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiDnsZonesJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -66,7 +57,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
zone := new(zones.Zone)
|
||||
zone.FromAPIRequest(&req)
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
if err := zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -80,13 +71,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -102,13 +87,7 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -116,7 +95,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
@@ -125,7 +104,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
zone.FromAPIRequest(&req)
|
||||
zone.ID = zoneID
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
if err := zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -139,20 +118,14 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||
if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -7,62 +7,34 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
dnsDomain string
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
dnsDomain string
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
|
||||
func NewManager(store store.Store, accountManager account.Manager, dnsDomain string) zones.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
dnsDomain: dnsDomain,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var err error
|
||||
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -102,14 +74,6 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string,
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get zone: %w", err)
|
||||
@@ -150,14 +114,6 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
|
||||
@@ -13,9 +13,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -29,7 +26,7 @@ const (
|
||||
testDNSDomain = "netbird.selfhosted"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -49,23 +46,17 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccoun
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
dnsDomain: testDNSDomain,
|
||||
}
|
||||
manager := NewManager(testStore, mockAccountManager, testDNSDomain).(*managerImpl)
|
||||
|
||||
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
return manager, testStore, mockAccountManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -77,10 +68,6 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
err = testStore.CreateZone(ctx, zone2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
@@ -88,43 +75,13 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
assert.Equal(t, zone2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -132,10 +89,6 @@ func TestManagerImpl_GetZone(t *testing.T) {
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, zone.ID, result.ID)
|
||||
@@ -143,29 +96,13 @@ func TestManagerImpl_GetZone(t *testing.T) {
|
||||
assert.Equal(t, zone.Domain, result.Domain)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -177,10 +114,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
@@ -199,31 +132,8 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("invalid group", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -233,17 +143,13 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{"invalid-group"},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("duplicate domain", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -259,10 +165,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -273,7 +175,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("peer DNS domain conflict", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -291,10 +193,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -305,7 +203,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("default DNS domain conflict", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -317,10 +215,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -335,7 +229,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -352,10 +246,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
@@ -375,7 +265,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("domain change not allowed", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -392,10 +282,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -405,31 +291,8 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: testZoneID,
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -439,10 +302,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -453,7 +312,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success with records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -469,10 +328,6 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCallCount := 0
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCallCount++
|
||||
@@ -493,7 +348,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("success without records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -501,10 +356,6 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
@@ -522,31 +373,11 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
@@ -6,8 +6,11 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -17,25 +20,19 @@ type handler struct {
|
||||
manager records.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -56,13 +53,7 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, apiRecords)
|
||||
}
|
||||
|
||||
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -78,7 +69,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
record := new(records.Record)
|
||||
record.FromAPIRequest(&req)
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
if err := record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -92,13 +83,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -120,13 +105,7 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -140,7 +119,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
@@ -149,7 +128,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
record.FromAPIRequest(&req)
|
||||
record.ID = recordID
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
if err := record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -163,13 +142,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -182,7 +155,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||
if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,64 +9,36 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
|
||||
func NewManager(store store.Store, accountManager account.Manager) records.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
|
||||
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
@@ -101,18 +73,11 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
var record *records.Record
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
@@ -160,18 +125,11 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var record *records.Record
|
||||
var zone *zones.Zone
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
|
||||
@@ -12,12 +12,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,7 +23,7 @@ const (
|
||||
testGroupID = "test-group-id"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -51,22 +47,17 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_serv
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
}
|
||||
manager := NewManager(testStore, mockAccountManager).(*managerImpl)
|
||||
|
||||
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
return manager, testStore, zone, mockAccountManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -78,10 +69,6 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
@@ -89,43 +76,13 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
assert.Equal(t, record2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -133,10 +90,6 @@ func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record.ID, result.ID)
|
||||
@@ -146,29 +99,13 @@ func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
assert.Equal(t, record.TTL, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success - A record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -179,10 +116,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
@@ -202,7 +135,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("success - AAAA record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -213,10 +146,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
@@ -231,7 +160,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("success - CNAME record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -242,10 +171,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
@@ -259,32 +184,8 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record name not in zone", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -295,10 +196,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -306,7 +203,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("duplicate record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -321,10 +218,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -332,7 +225,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -347,10 +240,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -362,7 +251,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -378,10 +267,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
TTL: 600, // Changed TTL
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
@@ -400,7 +285,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("update only TTL - no validation", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -416,10 +301,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
TTL: 600, // Only TTL changed
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
// Event should be stored
|
||||
}
|
||||
@@ -430,33 +311,8 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
assert.Equal(t, 600, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: testRecordID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -468,17 +324,13 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("update creates duplicate", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -498,10 +350,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
@@ -513,7 +361,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -521,10 +369,6 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
@@ -542,31 +386,11 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
manager, _, zone, _, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
@@ -233,7 +233,7 @@ func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
|
||||
|
||||
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||
return Create(s, func() accesslogs.Manager {
|
||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.GeoLocationManager())
|
||||
accessLogManager.StartPeriodicCleanup(
|
||||
context.Background(),
|
||||
s.Config.ReverseProxy.AccessLogRetentionDays,
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
@@ -27,7 +28,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
)
|
||||
@@ -82,13 +82,13 @@ func (s *BaseServer) SettingsManager() settings.Manager {
|
||||
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
|
||||
}
|
||||
|
||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
|
||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, idpConfig)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) PeersManager() peers.Manager {
|
||||
return Create(s, func() peers.Manager {
|
||||
manager := peers.NewManager(s.Store(), s.PermissionsManager())
|
||||
manager := peers.NewManager(s.Store())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
manager.SetNetworkMapController(s.NetworkMapController())
|
||||
manager.SetIntegratedPeerValidator(s.IntegratedValidator())
|
||||
@@ -161,43 +161,43 @@ func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||
|
||||
func (s *BaseServer) GroupsManager() groups.Manager {
|
||||
return Create(s, func() groups.Manager {
|
||||
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
|
||||
return groups.NewManager(s.Store(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||
return Create(s, func() resources.Manager {
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||
return resources.NewManager(s.Store(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RoutesManager() routers.Manager {
|
||||
return Create(s, func() routers.Manager {
|
||||
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
|
||||
return routers.NewManager(s.Store(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) NetworksManager() networks.Manager {
|
||||
return Create(s, func() networks.Manager {
|
||||
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
return networks.NewManager(s.Store(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ZonesManager() zones.Manager {
|
||||
return Create(s, func() zones.Manager {
|
||||
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
|
||||
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.DNSDomain())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RecordsManager() records.Manager {
|
||||
return Create(s, func() records.Manager {
|
||||
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||
return recordsManager.NewManager(s.Store(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ServiceManager() service.Manager {
|
||||
return Create(s, func() service.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
|
||||
|
||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return Create(s, func() *manager.Manager {
|
||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
|
||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.AccountManager())
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
@@ -39,9 +40,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -282,22 +280,14 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
|
||||
// User that performs the update has to belong to the account.
|
||||
// Returns an updated Settings
|
||||
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var oldSettings *types.Settings
|
||||
var updateAccountPeers bool
|
||||
var groupChangesAffectPeers bool
|
||||
var reloadReverseProxy bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var groupsUpdated bool
|
||||
var err error
|
||||
|
||||
oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
@@ -725,15 +715,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return err
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
||||
}
|
||||
|
||||
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||
@@ -976,6 +957,10 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, fmt.Errorf("user %s does not belong to account %s", userID, accountID)
|
||||
}
|
||||
|
||||
key := user.IntegrationReference.CacheKey(accountID, userID)
|
||||
ud, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
if err != nil {
|
||||
@@ -1287,41 +1272,16 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
|
||||
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountMeta returns the account metadata associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
||||
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||
log.Errorf("failed to get account onboarding for account %s: %v", accountID, err)
|
||||
@@ -1338,15 +1298,6 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
|
||||
@@ -1405,9 +1356,8 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper
|
||||
// User account association is already validated above by GetUserByUserID
|
||||
|
||||
if !user.IsServiceUser && userAuth.Invited {
|
||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||
@@ -1849,13 +1799,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -2197,14 +2140,6 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate user permissions: %w", err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update peer IP transaction: %w", err)
|
||||
|
||||
@@ -60,7 +60,7 @@ type Manager interface {
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -75,7 +75,7 @@ type Manager interface {
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||
|
||||
@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
|
||||
}
|
||||
|
||||
// GetGroupByName mocks base method.
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
|
||||
ret0, _ := ret[0].(*types.Group)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupByName indicates an expected call of GetGroupByName.
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
|
||||
}
|
||||
|
||||
// GetIdentityProvider mocks base method.
|
||||
@@ -946,18 +946,18 @@ func (mr *MockManagerMockRecorder) GetPeerNetwork(ctx, peerID interface{}) *gomo
|
||||
}
|
||||
|
||||
// GetPeers mocks base method.
|
||||
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*peer.Peer, error) {
|
||||
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter)
|
||||
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter, all)
|
||||
ret0, _ := ret[0].([]*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeers indicates an expected call of GetPeers.
|
||||
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter, all interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter, all)
|
||||
}
|
||||
|
||||
// GetPolicy mocks base method.
|
||||
|
||||
@@ -22,9 +22,11 @@ import (
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
@@ -49,7 +51,6 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -408,7 +409,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -1171,6 +1172,11 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
@@ -1226,6 +1232,11 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
@@ -1264,6 +1275,11 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
@@ -1317,6 +1333,11 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
@@ -1377,6 +1398,11 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
@@ -1608,6 +1634,75 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
}
|
||||
|
||||
func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
||||
},
|
||||
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
||||
Routes: map[route.ID]*route.Route{
|
||||
"route-1": {
|
||||
ID: "route-1",
|
||||
Network: prefix,
|
||||
NetID: "network-1",
|
||||
Description: "network-1",
|
||||
Peer: "peer-1",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
"route-2": {
|
||||
ID: "route-2",
|
||||
Network: prefix2,
|
||||
NetID: "network-2",
|
||||
Description: "network-2",
|
||||
Peer: "peer-2",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
"route-3": {
|
||||
ID: "route-3",
|
||||
Network: prefix,
|
||||
NetID: "network-1",
|
||||
Description: "network-1",
|
||||
Peer: "peer-2",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
for _, r := range routes {
|
||||
routeIDs[r.ID] = struct{}{}
|
||||
}
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
|
||||
func TestAccount_Copy(t *testing.T) {
|
||||
account := &types.Account{
|
||||
Id: "account1",
|
||||
@@ -1730,7 +1825,9 @@ func TestAccount_Copy(t *testing.T) {
|
||||
AccountID: "account1",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
account.InitOnce()
|
||||
err := hasNilField(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -3051,7 +3148,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
peersManager := peers.NewManager(store)
|
||||
|
||||
proxyManager := proxy.NewMockManager(ctrl)
|
||||
proxyManager.EXPECT().
|
||||
@@ -3068,7 +3165,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
|
||||
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -3079,7 +3176,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil))
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, proxyManager, nil))
|
||||
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
@@ -3157,13 +3254,6 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
return manager, updateManager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
|
||||
// wrappers wait for an expected update message. Sized for slow CI runners
|
||||
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
|
||||
// seconds. Only runs down on failure; passing tests return immediately
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
@@ -3182,7 +3272,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
@@ -22,14 +20,6 @@ const (
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -39,18 +29,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -79,16 +78,6 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
if len(dnsSettings.DisabledManagementGroups) != 1 {
|
||||
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
|
||||
}
|
||||
|
||||
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
|
||||
if err == nil {
|
||||
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
|
||||
}
|
||||
|
||||
s, ok := status.FromError(err)
|
||||
if !ok && s.Type() != status.PermissionDenied {
|
||||
t.Errorf("returned error should be Permission Denied, got err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveDNSSettings(t *testing.T) {
|
||||
@@ -223,7 +212,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
// return empty extra settings for expected calls to UpdateAccountPeers
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
peersManager := peers.NewManager(store)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -234,7 +223,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
|
||||
|
||||
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
}
|
||||
@@ -458,7 +447,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -478,7 +467,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -518,7 +507,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -9,11 +9,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func isEnabled() bool {
|
||||
@@ -23,14 +20,6 @@ func isEnabled() bool {
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
@@ -32,13 +30,24 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
|
||||
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper
|
||||
// This method is called from authenticated/authorized handlers, so we just validate
|
||||
// that the user exists and is part of the account
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
if user == nil {
|
||||
return status.NewUserNotFoundError(userID)
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsBlocked() {
|
||||
return status.NewUserBlockedError()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -61,27 +70,17 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
}
|
||||
|
||||
// CreateGroup object of the peers
|
||||
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,19 +124,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
|
||||
// UpdateGroup object of the peers
|
||||
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -196,33 +187,24 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -243,6 +225,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -264,21 +247,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -311,6 +287,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -416,14 +393,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*types.Group
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -620,7 +619,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -638,7 +637,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -656,7 +655,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -689,7 +688,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -730,7 +729,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -757,18 +756,17 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving a group linked to network router should update account peers and send peer update
|
||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
groupsManager := groups.NewManager(manager.Store, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, manager)
|
||||
networksManager := networks.NewManager(manager.Store, resourcesManager, routersManager, manager)
|
||||
|
||||
network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{
|
||||
ID: "network_test",
|
||||
@@ -804,7 +802,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -6,9 +6,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -25,31 +22,21 @@ type Manager interface {
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
accountManager account.Manager
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
type mockManager struct {
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
|
||||
func NewManager(store store.Store, accountManager account.Manager) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
accountManager: accountManager,
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting account groups: %w", err)
|
||||
@@ -73,14 +60,6 @@ func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID str
|
||||
}
|
||||
|
||||
func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error adding resource to group: %w", err)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -31,10 +32,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
|
||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
@@ -124,25 +123,25 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router, permissionsManager)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
users.AddInvitesEndpoints(accountManager, router)
|
||||
users.AddEndpoints(accountManager, router, permissionsManager)
|
||||
users.AddInvitesEndpoints(accountManager, router, permissionsManager)
|
||||
users.AddPublicInvitesEndpoints(accountManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router)
|
||||
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router, permissionsManager)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router, permissionsManager)
|
||||
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router, permissionsManager)
|
||||
policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router)
|
||||
groups.AddEndpoints(accountManager, router)
|
||||
routes.AddEndpoints(accountManager, router)
|
||||
dns.AddEndpoints(accountManager, router)
|
||||
events.AddEndpoints(accountManager, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||
zonesManager.RegisterEndpoints(router, zManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
groups.AddEndpoints(accountManager, router, permissionsManager)
|
||||
routes.AddEndpoints(accountManager, router, permissionsManager)
|
||||
dns.AddEndpoints(accountManager, router, permissionsManager)
|
||||
events.AddEndpoints(accountManager, router, permissionsManager)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, permissionsManager, router)
|
||||
zonesManager.RegisterEndpoints(router, zManager, permissionsManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager, permissionsManager)
|
||||
idp.AddEndpoints(accountManager, router, permissionsManager)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router, permissionsManager)
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
}
|
||||
|
||||
@@ -12,10 +12,13 @@ import (
|
||||
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -40,11 +43,11 @@ type handler struct {
|
||||
settingsManager settings.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
accountsHandler := newHandler(accountManager, settingsManager)
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Update, accountsHandler.updateAccount)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Delete, accountsHandler.deleteAccount)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", permissionsManager.WithPermission(modules.Accounts, operations.Read, accountsHandler.getAllAccounts)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new handler HTTP handler
|
||||
@@ -99,7 +102,7 @@ func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID st
|
||||
}
|
||||
|
||||
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
|
||||
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "")
|
||||
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "", true)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "get peer count: %v", err)
|
||||
}
|
||||
@@ -136,34 +139,26 @@ func calculateRequiredAddresses(peerCount int) int64 {
|
||||
}
|
||||
|
||||
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
|
||||
settings, err := h.settingsManager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID)
|
||||
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||
resp := toAccountResponse(userAuth.AccountId, settings, meta, onboarding)
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
@@ -233,24 +228,15 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
|
||||
}
|
||||
|
||||
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
_, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
accountID := vars["accountId"]
|
||||
if len(accountID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
|
||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
accountID := mux.Vars(r)["accountId"]
|
||||
if accountID != userAuth.AccountId {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiAccountsAccountIdJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -267,7 +253,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
|
||||
return
|
||||
}
|
||||
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
|
||||
if err := h.validateNetworkRange(r.Context(), accountID, userAuth.UserId, prefix); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -282,19 +268,19 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
|
||||
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userAuth.UserId, onboarding)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userAuth.UserId, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -306,21 +292,14 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// deleteAccount is a HTTP DELETE handler to delete an account
|
||||
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
accountID := mux.Vars(r)["accountId"]
|
||||
if accountID != userAuth.AccountId {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetAccountID := vars["accountId"]
|
||||
if len(targetAccountID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
|
||||
err := h.accountManager.DeleteAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -290,8 +291,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
|
||||
router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT")
|
||||
router.HandleFunc("/api/accounts", permissions.WrapHandler(handler.getAllAccounts)).Methods("GET")
|
||||
router.HandleFunc("/api/accounts/{accountId}", permissions.WrapHandler(handler.updateAccount)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -5,11 +5,13 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -19,15 +21,15 @@ type dnsSettingsHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
addDNSSettingEndpoint(accountManager, router)
|
||||
addDNSNameserversEndpoint(accountManager, router)
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
addDNSSettingEndpoint(accountManager, router, permissionsManager)
|
||||
addDNSNameserversEndpoint(accountManager, router, permissionsManager)
|
||||
}
|
||||
|
||||
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
|
||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Read, dnsSettingsHandler.getDNSSettings)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Update, dnsSettingsHandler.updateDNSSettings)).Methods("PUT", "OPTIONS")
|
||||
}
|
||||
|
||||
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
|
||||
@@ -36,17 +38,8 @@ func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler {
|
||||
}
|
||||
|
||||
// getDNSSettings returns the DNS settings for the account
|
||||
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
|
||||
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -60,17 +53,9 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
// updateDNSSettings handles update to DNS settings of an account
|
||||
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PutApiDnsSettingsJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -80,7 +65,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re
|
||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId, updateDNSSettings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -115,8 +116,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
|
||||
router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
|
||||
router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.getDNSSettings)).Methods("GET")
|
||||
router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.updateDNSSettings)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,13 +23,13 @@ type nameserversHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
nameserversHandler := newNameserversHandler(accountManager)
|
||||
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getAllNameservers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Create, nameserversHandler.createNameserverGroup)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Update, nameserversHandler.updateNameserverGroup)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getNameserverGroup)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Delete, nameserversHandler.deleteNameserverGroup)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newNameserversHandler returns a new instance of nameserversHandler handler
|
||||
@@ -36,17 +38,8 @@ func newNameserversHandler(accountManager account.Manager) *nameserversHandler {
|
||||
}
|
||||
|
||||
// getAllNameservers returns the list of nameserver groups for the account
|
||||
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
|
||||
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -61,17 +54,9 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
|
||||
}
|
||||
|
||||
// createNameserverGroup handles nameserver group creation request
|
||||
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiDnsNameserversJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -83,7 +68,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), userAuth.AccountId, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userAuth.UserId, req.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -95,15 +80,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// updateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
@@ -111,7 +88,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -135,7 +112,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, updatedNSGroup)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -147,22 +124,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// deleteNameserverGroup handles nameserver group deletion request
|
||||
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
|
||||
err := h.accountManager.DeleteNameServerGroup(r.Context(), userAuth.AccountId, nsGroupID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -172,22 +141,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// getNameserverGroup handles a nameserver group Get request identified by ID
|
||||
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, nsGroupID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -201,10 +202,10 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
|
||||
router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.getNameserverGroup)).Methods("GET")
|
||||
router.HandleFunc("/api/dns/nameservers", permissions.WrapHandler(p.createNameserverGroup)).Methods("POST")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.deleteNameserverGroup)).Methods("DELETE")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.updateNameserverGroup)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -5,11 +5,13 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -19,10 +21,10 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
eventsHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events/audit", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events/audit", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new events handler
|
||||
@@ -31,17 +33,8 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllEvents list of the given account
|
||||
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
|
||||
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -196,7 +197,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
|
||||
router.HandleFunc("/api/events/", permissions.WrapHandler(handler.getAllEvents)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -7,11 +7,13 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -19,46 +21,45 @@ import (
|
||||
|
||||
// handler is a handler that returns groups of the account
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
groupsHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS")
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
groupsHandler := newHandler(accountManager, permissionsManager)
|
||||
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getGroup)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Delete, groupsHandler.deleteGroup)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new groups handler
|
||||
func newHandler(accountManager account.Manager) *handler {
|
||||
func newHandler(accountManager account.Manager, permissionsManager permissions.Manager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// getAllGroups list for the account
|
||||
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) canReadPeers(r *http.Request, userAuth *auth.UserAuth) bool {
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read)
|
||||
return err == nil && allowed
|
||||
}
|
||||
|
||||
// getAllGroups list for the account
|
||||
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
// Check if filtering by name
|
||||
groupName := r.URL.Query().Get("name")
|
||||
if groupName != "" {
|
||||
// Get single group by name
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -71,13 +72,13 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Get all groups
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -92,15 +93,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// updateGroup handles update to a group identified by a given ID
|
||||
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
groupID, ok := vars["groupId"]
|
||||
if !ok {
|
||||
@@ -112,13 +105,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||
existingGroup, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -166,13 +159,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
IntegrationReference: existingGroup.IntegrationReference,
|
||||
}
|
||||
|
||||
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
|
||||
if err := h.accountManager.UpdateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, userAuth.AccountId, err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -182,17 +175,9 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createGroup handles group creation request
|
||||
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiGroupsJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -226,13 +211,13 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
Issued: types.GroupIssuedAPI,
|
||||
}
|
||||
|
||||
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
|
||||
err = h.accountManager.CreateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -242,22 +227,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// deleteGroup handles group deletion request
|
||||
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
|
||||
err := h.accountManager.DeleteGroup(r.Context(), userAuth.AccountId, userAuth.UserId, groupID)
|
||||
if err != nil {
|
||||
wrappedErr, ok := err.(interface{ Unwrap() []error })
|
||||
if ok && len(wrappedErr.Unwrap()) > 0 {
|
||||
@@ -273,34 +250,26 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getGroup returns a group
|
||||
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||
group, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
|
||||
|
||||
}
|
||||
|
||||
func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group {
|
||||
|
||||
@@ -13,10 +13,14 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
@@ -33,8 +37,18 @@ var TestPeers = map[string]*nbpeer.Peer{
|
||||
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
||||
}
|
||||
|
||||
func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler {
|
||||
t.Helper()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
permissionsManagerMock.EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
return &handler{
|
||||
permissionsManager: permissionsManagerMock,
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error {
|
||||
if !strings.HasPrefix(group.ID, "id-") {
|
||||
@@ -71,14 +85,14 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
|
||||
return groups, nil
|
||||
},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
|
||||
if groupName == "All" {
|
||||
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "unknown group name")
|
||||
},
|
||||
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
|
||||
return maps.Values(TestPeers), nil
|
||||
},
|
||||
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
||||
@@ -128,7 +142,7 @@ func TestGetGroup(t *testing.T) {
|
||||
Name: "Group",
|
||||
}
|
||||
|
||||
p := initGroupTestData(group)
|
||||
p := initGroupTestData(t, group)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -141,7 +155,7 @@ func TestGetGroup(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
|
||||
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.getGroup)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -254,7 +268,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initGroupTestData()
|
||||
p := initGroupTestData(t)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -267,8 +281,8 @@ func TestWriteGroup(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
|
||||
router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT")
|
||||
router.HandleFunc("/api/groups", permissions.WrapHandler(p.createGroup)).Methods("POST")
|
||||
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.updateGroup)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -332,7 +346,7 @@ func TestGetAllGroups(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initGroupTestData()
|
||||
p := initGroupTestData(t)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -345,7 +359,7 @@ func TestGetAllGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET")
|
||||
router.HandleFunc("/api/groups", permissions.WrapHandler(p.getAllGroups)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -414,7 +428,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initGroupTestData()
|
||||
p := initGroupTestData(t)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -426,7 +440,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
AccountId: "test_id",
|
||||
})
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
|
||||
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.deleteGroup)).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -6,9 +6,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -20,13 +23,13 @@ type handler struct {
|
||||
}
|
||||
|
||||
// AddEndpoints registers identity provider endpoints
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
h := newHandler(accountManager)
|
||||
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getAllIdentityProviders)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Create, h.createIdentityProvider)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getIdentityProvider)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Update, h.updateIdentityProvider)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Delete, h.deleteIdentityProvider)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newHandler(accountManager account.Manager) *handler {
|
||||
@@ -36,16 +39,8 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllIdentityProviders returns all identity providers for the account
|
||||
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
|
||||
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
providers, err := h.accountManager.GetIdentityProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -60,15 +55,7 @@ func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// getIdentityProvider returns a specific identity provider
|
||||
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
@@ -76,7 +63,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID)
|
||||
provider, err := h.accountManager.GetIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -86,15 +73,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createIdentityProvider creates a new identity provider
|
||||
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.IdentityProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -103,7 +82,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
idp := fromAPIRequest(&req)
|
||||
|
||||
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp)
|
||||
created, err := h.accountManager.CreateIdentityProvider(r.Context(), userAuth.AccountId, userAuth.UserId, idp)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -113,15 +92,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// updateIdentityProvider updates an existing identity provider
|
||||
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
@@ -137,7 +108,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
idp := fromAPIRequest(&req)
|
||||
|
||||
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp)
|
||||
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId, idp)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -147,15 +118,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// deleteIdentityProvider deletes an identity provider
|
||||
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
@@ -163,7 +126,7 @@ func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil {
|
||||
if err := h.accountManager.DeleteIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -120,7 +121,7 @@ func TestGetAllIdentityProviders(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET")
|
||||
router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.getAllIdentityProviders)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -180,7 +181,7 @@ func TestGetIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET")
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.getIdentityProvider)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -242,7 +243,7 @@ func TestCreateIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST")
|
||||
router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.createIdentityProvider)).Methods("POST")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -328,7 +329,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT")
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.updateIdentityProvider)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -388,7 +389,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE")
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.deleteIdentityProvider)).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -7,7 +7,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -29,12 +33,12 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
}
|
||||
|
||||
// AddVersionEndpoint registers the authenticated version endpoint.
|
||||
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
instanceManager: instanceManager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/instance/version", permissionsManager.WithPermission(modules.Settings, operations.Read, h.getVersionInfo)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// getInstanceStatus returns the instance status including whether setup is required.
|
||||
@@ -77,7 +81,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getVersionInfo returns version information for NetBird components.
|
||||
// This endpoint requires authentication.
|
||||
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
|
||||
|
||||
@@ -10,12 +10,17 @@ import (
|
||||
"net/mail"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -295,8 +300,15 @@ func TestSetup_ManagerError(t *testing.T) {
|
||||
|
||||
func TestGetVersionInfo_Success(t *testing.T) {
|
||||
manager := &mockInstanceManager{}
|
||||
ctrl := gomock.NewController(t)
|
||||
permissionsManager := permissions.NewMockManager(ctrl)
|
||||
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(w, r, &auth.UserAuth{})
|
||||
}
|
||||
}).AnyTimes()
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
AddVersionEndpoint(manager, router, permissionsManager)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -323,8 +335,15 @@ func TestGetVersionInfo_Error(t *testing.T) {
|
||||
return nil, errors.New("failed to fetch versions")
|
||||
},
|
||||
}
|
||||
ctrl := gomock.NewController(t)
|
||||
permissionsManager := permissions.NewMockManager(ctrl)
|
||||
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(w, r, &auth.UserAuth{})
|
||||
}
|
||||
}).AnyTimes()
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
AddVersionEndpoint(manager, router, permissionsManager)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -33,16 +36,16 @@ type handler struct {
|
||||
groupsManager groups.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) {
|
||||
addRouterEndpoints(routerManager, router)
|
||||
addResourceEndpoints(resourceManager, groupsManager, router)
|
||||
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
addRouterEndpoints(routerManager, permissionsManager, router)
|
||||
addResourceEndpoints(resourceManager, groupsManager, permissionsManager, router)
|
||||
|
||||
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
|
||||
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getAllNetworks)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Create, networksHandler.createNetwork)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getNetwork)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Update, networksHandler.updateNetwork)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, networksHandler.deleteNetwork)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler {
|
||||
@@ -55,40 +58,32 @@ func newHandler(networksManager networks.Manager, resourceManager resources.Mana
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networks, err := h.networksManager.GetAllNetworks(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
|
||||
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID)
|
||||
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID)
|
||||
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -97,16 +92,9 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account))
|
||||
}
|
||||
|
||||
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.NetworkRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -115,14 +103,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
network := &types.Network{}
|
||||
network.FromAPIRequest(&req)
|
||||
|
||||
network.AccountID = accountID
|
||||
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network)
|
||||
network.AccountID = userAuth.AccountId
|
||||
network, err = h.networksManager.CreateNetwork(r.Context(), userAuth.UserId, network)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -133,14 +121,7 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs))
|
||||
}
|
||||
|
||||
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@@ -148,19 +129,19 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID)
|
||||
network, err := h.networksManager.GetNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -171,14 +152,7 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
|
||||
}
|
||||
|
||||
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@@ -187,7 +161,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.NetworkRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -197,20 +171,20 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
network.FromAPIRequest(&req)
|
||||
|
||||
network.ID = networkID
|
||||
network.AccountID = accountID
|
||||
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network)
|
||||
network.AccountID = userAuth.AccountId
|
||||
network, err = h.networksManager.UpdateNetwork(r.Context(), userAuth.UserId, network)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -221,14 +195,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
|
||||
}
|
||||
|
||||
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@@ -236,7 +203,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID)
|
||||
err := h.networksManager.DeleteNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -6,10 +6,13 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -19,14 +22,14 @@ type resourceHandler struct {
|
||||
groupsManager groups.Manager
|
||||
}
|
||||
|
||||
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) {
|
||||
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
resourceHandler := newResourceHandler(resourcesManager, groupsManager)
|
||||
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/networks/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInAccount)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInNetwork)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Create, resourceHandler.createResource)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getResource)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Update, resourceHandler.updateResource)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, resourceHandler.deleteResource)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
|
||||
@@ -36,22 +39,15 @@ func newResourceHandler(resourceManager resources.Manager, groupsManager groups.
|
||||
}
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID)
|
||||
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -66,22 +62,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resourcesResponse)
|
||||
}
|
||||
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -97,17 +85,9 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
|
||||
util.WriteJSONObject(r.Context(), w, resourcesResponse)
|
||||
}
|
||||
|
||||
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.NetworkResourceRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -117,14 +97,14 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
|
||||
resource.FromAPIRequest(&req)
|
||||
|
||||
resource.NetworkID = mux.Vars(r)["networkId"]
|
||||
resource.AccountID = accountID
|
||||
resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource)
|
||||
resource.AccountID = userAuth.AccountId
|
||||
resource, err = h.resourceManager.CreateResource(r.Context(), userAuth.UserId, resource)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -135,23 +115,16 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
|
||||
resource, err := h.resourceManager.GetResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -162,16 +135,9 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
|
||||
}
|
||||
|
||||
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.NetworkResourceRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -182,14 +148,14 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
resource.ID = mux.Vars(r)["resourceId"]
|
||||
resource.NetworkID = mux.Vars(r)["networkId"]
|
||||
resource.AccountID = accountID
|
||||
resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource)
|
||||
resource.AccountID = userAuth.AccountId
|
||||
resource, err = h.resourceManager.UpdateResource(r.Context(), userAuth.UserId, resource)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -200,17 +166,10 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
|
||||
}
|
||||
|
||||
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID)
|
||||
err := h.resourceManager.DeleteResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -6,9 +6,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -17,14 +20,14 @@ type routersHandler struct {
|
||||
routersManager routers.Manager
|
||||
}
|
||||
|
||||
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
|
||||
func addRouterEndpoints(routersManager routers.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
routersHandler := newRoutersHandler(routersManager)
|
||||
router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/networks/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getAllRouters)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getNetworkRouters)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Create, routersHandler.createRouter)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getRouter)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Update, routersHandler.updateRouter)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, routersHandler.deleteRouter)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newRoutersHandler(routersManager routers.Manager) *routersHandler {
|
||||
@@ -33,16 +36,8 @@ func newRoutersHandler(routersManager routers.Manager) *routersHandler {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
|
||||
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -58,17 +53,9 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, routersResponse)
|
||||
}
|
||||
|
||||
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
|
||||
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -82,18 +69,10 @@ func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Reques
|
||||
util.WriteJSONObject(r.Context(), w, routersResponse)
|
||||
}
|
||||
|
||||
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
var req api.NetworkRouterRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -103,7 +82,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
router.FromAPIRequest(&req)
|
||||
|
||||
router.NetworkID = networkID
|
||||
router.AccountID = accountID
|
||||
router.AccountID = userAuth.AccountId
|
||||
router.Enabled = true
|
||||
|
||||
if err := router.Validate(); err != nil {
|
||||
@@ -111,7 +90,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
|
||||
router, err = h.routersManager.CreateRouter(r.Context(), userAuth.UserId, router)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -120,18 +99,10 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
router, err := h.routersManager.GetRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -140,17 +111,9 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.NetworkRouterRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -161,14 +124,14 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
router.NetworkID = mux.Vars(r)["networkId"]
|
||||
router.ID = mux.Vars(r)["routerId"]
|
||||
router.AccountID = accountID
|
||||
router.AccountID = userAuth.AccountId
|
||||
|
||||
if err := router.Validate(); err != nil {
|
||||
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
|
||||
router, err = h.routersManager.UpdateRouter(r.Context(), userAuth.UserId, router)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -177,17 +140,10 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
err := h.routersManager.DeleteRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package peers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -12,15 +11,15 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -35,14 +34,15 @@ type Handler struct {
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
|
||||
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
|
||||
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAllPeers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetPeer)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Update, peersHandler.UpdatePeer)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Delete, peersHandler.DeletePeer)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/accessible-peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAccessiblePeers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/temporary-access", permissionsManager.WithPermission(modules.Peers, operations.Create, peersHandler.CreateTemporaryAccess)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.ListJobs)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Create, peersHandler.CreateJob)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.GetJob)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
@@ -54,14 +54,7 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
|
||||
@@ -73,37 +66,30 @@ func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
if err := h.accountManager.CreatePeerJob(r.Context(), userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
|
||||
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
|
||||
jobs, err := h.accountManager.GetAllPeerJobs(r.Context(), userAuth.AccountId, userAuth.UserId, peerID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -111,79 +97,88 @@ func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
|
||||
for _, job := range jobs {
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
respBody = append(respBody, resp)
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, respBody)
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
jobID := vars["jobId"]
|
||||
|
||||
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
|
||||
job, err := h.accountManager.GetPeerJobByID(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, jobID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
|
||||
// GetPeer handles GET request for a single peer
|
||||
func (h *Handler) GetPeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if peer.ProxyMeta.Embedded {
|
||||
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
||||
grps, _ := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peerID)
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
||||
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
reason := invalidPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
// UpdatePeer handles PUT request to update a peer
|
||||
func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -192,11 +187,10 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
}
|
||||
|
||||
update := &nbpeer.Peer{
|
||||
ID: peerID,
|
||||
SSHEnabled: req.SshEnabled,
|
||||
Name: req.Name,
|
||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||
|
||||
ID: peerID,
|
||||
SSHEnabled: req.SshEnabled,
|
||||
Name: req.Name,
|
||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
|
||||
}
|
||||
|
||||
@@ -210,41 +204,41 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
if req.Ip != nil {
|
||||
addr, err := netip.ParseAddr(*req.Ip)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
if err = h.accountManager.UpdatePeerIP(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, addr); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
|
||||
peer, err := h.accountManager.UpdatePeer(r.Context(), userAuth.AccountId, userAuth.UserId, update)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peer.ID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
|
||||
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -254,25 +248,8 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(ctx, w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
// DeletePeer handles DELETE request to delete a peer
|
||||
func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@@ -280,48 +257,34 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodDelete:
|
||||
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||
err := h.accountManager.DeletePeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
case http.MethodGet:
|
||||
h.getPeer(r.Context(), accountID, peerID, userID, w)
|
||||
return
|
||||
case http.MethodPut:
|
||||
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
|
||||
return
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// GetAllPeers returns a list of all peers associated with a provided account
|
||||
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nameFilter := r.URL.Query().Get("name")
|
||||
ipFilter := r.URL.Query().Get("ip")
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter)
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, nameFilter, ipFilter, true)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator)
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, _ := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
@@ -332,7 +295,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -356,15 +319,7 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM
|
||||
}
|
||||
|
||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@@ -372,25 +327,22 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed && !userAuth.IsChild {
|
||||
// Check if user is an admin/service user through their role
|
||||
isAdmin := user.Role == types.UserRoleAdmin || user.Role == types.UserRoleOwner
|
||||
|
||||
if !isAdmin && !user.IsServiceUser && !userAuth.IsChild {
|
||||
if account.Settings.RegularUsersViewBlocked {
|
||||
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||
return
|
||||
@@ -408,7 +360,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -417,18 +369,12 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@@ -437,7 +383,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
var req api.PeerTemporaryAccessRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
|
||||
@@ -19,11 +19,11 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -174,7 +174,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
},
|
||||
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
|
||||
return peers, nil
|
||||
},
|
||||
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
|
||||
@@ -307,9 +307,9 @@ func TestGetPeers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
||||
router.HandleFunc("/api/peers/", permissions.WrapHandler(p.GetAllPeers)).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.GetPeer)).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -498,7 +498,7 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}/accessible-peers", permissions.WrapHandler(p.GetAccessiblePeers)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -582,7 +582,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
||||
router.HandleFunc("/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
|
||||
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
|
||||
@@ -14,12 +14,12 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -121,7 +121,7 @@ func TestGetCitiesByCountry(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
|
||||
router.HandleFunc("/api/locations/countries/{country}/cities", permissions.WrapHandler(geolocationHandler.getCitiesByCountry)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -214,7 +214,7 @@ func TestGetAllCountries(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
|
||||
router.HandleFunc("/api/locations/countries", permissions.WrapHandler(geolocationHandler.getAllCountries)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -30,8 +30,8 @@ type geolocationsHandler struct {
|
||||
|
||||
func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager)
|
||||
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getAllCountries)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries/{country}/cities", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getCitiesByCountry)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newGeolocationsHandlerHandler creates a new Geolocations handler
|
||||
@@ -44,12 +44,7 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
|
||||
}
|
||||
|
||||
// getAllCountries retrieves a list of all countries
|
||||
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if l.geolocationManager == nil {
|
||||
// TODO: update error message to include geo db self hosted doc link when ready
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
@@ -70,12 +65,7 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
|
||||
// getCitiesByCountry retrieves a list of cities based on the given country code
|
||||
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
countryCode := vars["country"]
|
||||
if !countryCodeRegex.MatchString(countryCode) {
|
||||
@@ -102,27 +92,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
|
||||
util.WriteJSONObject(r.Context(), w, cities)
|
||||
}
|
||||
|
||||
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toCountryResponse(country geolocation.Country) api.Country {
|
||||
return api.Country{
|
||||
CountryName: country.CountryName,
|
||||
|
||||
@@ -7,10 +7,13 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,13 +24,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
policiesHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getAllPolicies)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Create, policiesHandler.createPolicy)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Update, policiesHandler.updatePolicy)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getPolicy)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, policiesHandler.deletePolicy)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new policies handler
|
||||
@@ -38,22 +41,14 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllPolicies list for the account
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
listPolicies, err := h.accountManager.ListPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -73,15 +68,7 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// updatePolicy handles update to a policy identified by a given ID
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@@ -89,26 +76,18 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||
_, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
h.savePolicy(w, r, accountID, userID, policyID, false)
|
||||
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, policyID, false)
|
||||
}
|
||||
|
||||
// createPolicy handles policy creation request
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
h.savePolicy(w, r, accountID, userID, "", true)
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, "", true)
|
||||
}
|
||||
|
||||
// savePolicy handles policy creation and update
|
||||
@@ -303,14 +282,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
}
|
||||
|
||||
// deletePolicy handles policy deletion request
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@@ -318,7 +290,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
|
||||
if err := h.accountManager.DeletePolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -327,15 +299,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getPolicy handles a group Get request identified by ID
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@@ -343,13 +307,13 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -111,7 +112,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
|
||||
router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.getPolicy)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -275,8 +276,8 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
|
||||
router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT")
|
||||
router.HandleFunc("/api/policies", permissions.WrapHandler(p.createPolicy)).Methods("POST")
|
||||
router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.updatePolicy)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -6,10 +6,13 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,13 +24,13 @@ type postureChecksHandler struct {
|
||||
geolocationManager geolocation.Geolocation
|
||||
}
|
||||
|
||||
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
|
||||
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
|
||||
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getAllPostureChecks)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Create, postureCheckHandler.createPostureCheck)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Update, postureCheckHandler.updatePostureCheck)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getPostureCheck)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, postureCheckHandler.deletePostureCheck)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newPostureChecksHandler creates a new PostureChecks handler
|
||||
@@ -39,15 +42,8 @@ func newPostureChecksHandler(accountManager account.Manager, geolocationManager
|
||||
}
|
||||
|
||||
// getAllPostureChecks list for the account
|
||||
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
|
||||
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -62,15 +58,7 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// updatePostureCheck handles update to a posture check identified by a given ID
|
||||
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@@ -78,37 +66,22 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||
_, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
p.savePostureChecks(w, r, accountID, userID, postureChecksID, false)
|
||||
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, postureChecksID, false)
|
||||
}
|
||||
|
||||
// createPostureCheck handles posture check creation request
|
||||
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
p.savePostureChecks(w, r, accountID, userID, "", true)
|
||||
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, "", true)
|
||||
}
|
||||
|
||||
// getPostureCheck handles a posture check Get request identified by ID
|
||||
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@@ -116,7 +89,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -126,14 +99,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
}
|
||||
|
||||
// deletePostureCheck handles posture check deletion request
|
||||
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@@ -141,7 +107,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http
|
||||
return
|
||||
}
|
||||
|
||||
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
|
||||
if err := p.accountManager.DeletePostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
@@ -183,7 +184,7 @@ func TestGetPostureCheck(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(p.getPostureCheck)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -841,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST")
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT")
|
||||
router.HandleFunc("/api/posture-checks", permissions.WrapHandler(defaultHandler.createPostureCheck)).Methods("POST")
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(defaultHandler.updatePostureCheck)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -8,9 +8,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -26,13 +29,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
routesHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getAllRoutes)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Create, routesHandler.createRoute)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Update, routesHandler.updateRoute)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getRoute)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Delete, routesHandler.deleteRoute)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler returns a new instance of routes handler
|
||||
@@ -43,16 +46,8 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllRoutes returns the list of routes for the account
|
||||
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
|
||||
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -71,17 +66,9 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createRoute handles route creation request
|
||||
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiRoutesJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -134,8 +121,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
skipAutoApply = false
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply)
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), userAuth.AccountId, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userAuth.UserId, req.KeepRoute, skipAutoApply)
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -185,14 +172,7 @@ func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *
|
||||
}
|
||||
|
||||
// updateRoute handles update to a route identified by a given ID
|
||||
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
routeID := vars["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
@@ -200,7 +180,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
_, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -271,7 +251,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
newRoute.AccessControlGroups = *req.AccessControlGroups
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
|
||||
err = h.accountManager.SaveRoute(r.Context(), userAuth.AccountId, userAuth.UserId, newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -287,21 +267,14 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// deleteRoute handles route deletion request
|
||||
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
err := h.accountManager.DeleteRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -311,22 +284,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getRoute handles a route Get request identified by ID
|
||||
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
@@ -501,10 +502,10 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
|
||||
router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE")
|
||||
router.HandleFunc("/api/routes", p.createRoute).Methods("POST")
|
||||
router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT")
|
||||
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.getRoute)).Methods("GET")
|
||||
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.deleteRoute)).Methods("DELETE")
|
||||
router.HandleFunc("/api/routes", permissions.WrapHandler(p.createRoute)).Methods("POST")
|
||||
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.updateRoute)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -8,9 +8,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,13 +24,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
keysHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getAllSetupKeys)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Create, keysHandler.createSetupKey)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getSetupKey)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Update, keysHandler.updateSetupKey)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Delete, keysHandler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new setup key handler
|
||||
@@ -38,16 +41,9 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// createSetupKey is a POST requests that creates a new SetupKey
|
||||
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
req := &api.PostApiSetupKeysJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -85,8 +81,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
allowExtraDNSLabels = *req.AllowExtraDnsLabels
|
||||
}
|
||||
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, userID, ephemeral, allowExtraDNSLabels)
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), userAuth.AccountId, req.Name, types.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, userAuth.UserId, ephemeral, allowExtraDNSLabels)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -100,14 +96,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getSetupKey is a GET request to get a SetupKey by ID
|
||||
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@@ -115,7 +104,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -125,14 +114,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// updateSetupKey is a PUT request to update server.SetupKey
|
||||
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@@ -141,7 +123,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
req := &api.PutApiSetupKeysKeyIdJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -157,7 +139,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
newKey.Revoked = req.Revoked
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), userAuth.AccountId, newKey, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -166,15 +148,8 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getAllSetupKeys is a GET request that returns a list of SetupKey
|
||||
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
|
||||
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -188,14 +163,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||
}
|
||||
|
||||
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@@ -203,7 +171,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID)
|
||||
err := h.accountManager.DeleteSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -171,11 +172,11 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.getAllSetupKeys)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.createSetupKey)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.getSetupKey)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.updateSetupKey)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -9,10 +9,13 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -55,14 +58,14 @@ type invitesHandler struct {
|
||||
}
|
||||
|
||||
// AddInvitesEndpoints registers invite-related endpoints
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Authenticated endpoints (require admin)
|
||||
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
|
||||
@@ -79,14 +82,7 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route
|
||||
}
|
||||
|
||||
// listInvites handles GET /api/users/invites
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -102,14 +98,7 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createInvite handles POST /api/users/invites
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.UserInviteCreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -191,18 +180,12 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
@@ -238,14 +221,7 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// deleteInvite handles DELETE /api/users/invites/{inviteId}
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
@@ -253,7 +229,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -110,7 +110,11 @@ func TestListInvites(t *testing.T) {
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.listInvites(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.listInvites(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -235,7 +239,11 @@ func TestCreateInvite(t *testing.T) {
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.createInvite(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.createInvite(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -573,7 +581,11 @@ func TestRegenerateInvite(t *testing.T) {
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.regenerateInvite(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.regenerateInvite(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -651,7 +663,11 @@ func TestDeleteInvite(t *testing.T) {
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.deleteInvite(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.deleteInvite(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
})
|
||||
|
||||
@@ -6,9 +6,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -19,12 +22,12 @@ type patHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
tokenHandler := newPATsHandler(accountManager)
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newPATsHandler creates a new patHandler HTTP handler
|
||||
@@ -35,22 +38,15 @@ func newPATsHandler(accountManager account.Manager) *patHandler {
|
||||
}
|
||||
|
||||
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -65,14 +61,7 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -86,7 +75,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -96,14 +85,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -112,13 +94,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.PostApiUsersUserIdTokensJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -128,14 +110,7 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -149,7 +124,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -181,10 +182,10 @@ func TestTokenHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.getAllTokens)).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.getToken)).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.createToken)).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.deleteToken)).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -8,14 +8,16 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
// handler is a handler that returns users of the account
|
||||
@@ -23,18 +25,18 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
userHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers, userHandler.getOwnUser)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser, userHandler.getCurrentUserFallback)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router, permissionsManager)
|
||||
}
|
||||
|
||||
// newHandler creates a new UsersHandler HTTP handler
|
||||
@@ -45,19 +47,12 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// updateUser is a PUT requests to update User data
|
||||
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -71,6 +66,11 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if existingUser.AccountID != userAuth.AccountId {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "user not found"), w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PutApiUsersUserIdJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -89,7 +89,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{
|
||||
Id: targetUserID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
@@ -102,23 +102,16 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
|
||||
}
|
||||
|
||||
// deleteUser is a DELETE request to delete a user
|
||||
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -126,7 +119,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
|
||||
err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -136,21 +129,14 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -171,7 +157,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: req.Role,
|
||||
@@ -183,25 +169,18 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
|
||||
}
|
||||
|
||||
// getAllUsers returns a list of users of the account this user belongs to.
|
||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -215,7 +194,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
users = append(users, toUserResponse(d, userAuth.UserId))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -226,7 +205,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
users = append(users, toUserResponse(d, userAuth.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,19 +214,12 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// inviteUser resend invitations to users who haven't activated their accounts,
|
||||
// prior to the expiration period.
|
||||
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -255,7 +227,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
|
||||
err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -264,19 +236,13 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
|
||||
user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -356,7 +322,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
}
|
||||
|
||||
// approveUser is a POST request to approve a user that is pending approval
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -369,11 +335,6 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -385,7 +346,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// rejectUser is a DELETE request to reject a user that is pending approval
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -398,12 +359,7 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -421,7 +377,7 @@ type passwordChangeRequest struct {
|
||||
// changePassword is a PUT request to change user's password.
|
||||
// Only available when embedded IDP is enabled.
|
||||
// Users can only change their own password.
|
||||
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -434,19 +390,13 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req passwordChangeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
|
||||
err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -454,3 +404,39 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func (h *handler) getCurrentUserFallback(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.PermissionDenied {
|
||||
return false
|
||||
}
|
||||
|
||||
user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
|
||||
if userErr != nil {
|
||||
util.WriteError(r.Context(), userErr, w)
|
||||
return true
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toUserWithPermissionsResponse(user, userAuth.UserId))
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *handler) getOwnUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.PermissionDenied {
|
||||
return false
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("service_user") != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
|
||||
if userErr != nil {
|
||||
util.WriteError(r.Context(), userErr, w)
|
||||
return true
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, []*api.User{toUserResponse(user.UserInfo, userAuth.UserId)})
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -15,10 +15,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
roles2 "github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
@@ -38,6 +39,7 @@ var usersTestAccount = &types.Account{
|
||||
Users: map[string]*types.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
AccountID: existingAccountID,
|
||||
Role: "admin",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
@@ -45,6 +47,7 @@ var usersTestAccount = &types.Account{
|
||||
},
|
||||
regularUserID: {
|
||||
Id: regularUserID,
|
||||
AccountID: existingAccountID,
|
||||
Role: "user",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
@@ -52,6 +55,7 @@ var usersTestAccount = &types.Account{
|
||||
},
|
||||
serviceUserID: {
|
||||
Id: serviceUserID,
|
||||
AccountID: existingAccountID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
AutoGroups: []string{"group_1"},
|
||||
@@ -59,6 +63,7 @@ var usersTestAccount = &types.Account{
|
||||
},
|
||||
nonDeletableServiceUserID: {
|
||||
Id: nonDeletableServiceUserID,
|
||||
AccountID: existingAccountID,
|
||||
Role: "admin",
|
||||
IsServiceUser: true,
|
||||
NonDeletable: true,
|
||||
@@ -151,7 +156,7 @@ func initUsersTestData() *handler {
|
||||
NonDeletable: false,
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.Owner),
|
||||
Permissions: mergeRolePermissions(roles2.Owner),
|
||||
}, nil
|
||||
case "regular-user":
|
||||
return &users.UserInfoWithPermissions{
|
||||
@@ -165,7 +170,7 @@ func initUsersTestData() *handler {
|
||||
NonDeletable: false,
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
Permissions: mergeRolePermissions(roles2.User),
|
||||
}, nil
|
||||
|
||||
case "admin-user":
|
||||
@@ -181,7 +186,7 @@ func initUsersTestData() *handler {
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.Admin),
|
||||
Permissions: mergeRolePermissions(roles2.Admin),
|
||||
}, nil
|
||||
case "restricted-user":
|
||||
return &users.UserInfoWithPermissions{
|
||||
@@ -196,7 +201,7 @@ func initUsersTestData() *handler {
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
Permissions: mergeRolePermissions(roles2.User),
|
||||
Restricted: true,
|
||||
}, nil
|
||||
}
|
||||
@@ -232,7 +237,11 @@ func TestGetUsers(t *testing.T) {
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.getAllUsers(recorder, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.getAllUsers(recorder, req, userAuth)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -343,7 +352,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
|
||||
router.HandleFunc("/api/users/{userId}", permissions.WrapHandler(userHandler.updateUser)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -439,7 +448,11 @@ func TestCreateUser(t *testing.T) {
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.createUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.createUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -490,7 +503,11 @@ func TestInviteUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.inviteUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.inviteUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -549,7 +566,11 @@ func TestDeleteUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.deleteUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.deleteUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -608,7 +629,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)),
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Owner)),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -627,7 +648,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -646,7 +667,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)),
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Admin)),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -666,7 +687,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
IsRestricted: true,
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -682,7 +703,11 @@ func TestCurrentUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.getCurrentUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: tc.requestAuth.UserId,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.getCurrentUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -702,8 +727,8 @@ func ptr[T any, PT *T](x T) PT {
|
||||
return &x
|
||||
}
|
||||
|
||||
func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
|
||||
permissions := roles.Permissions{}
|
||||
func mergeRolePermissions(role roles2.RolePermissions) roles2.Permissions {
|
||||
permissions := roles2.Permissions{}
|
||||
|
||||
for k := range modules.All {
|
||||
if rolePermissions, ok := role.Permissions[k]; ok {
|
||||
@@ -716,7 +741,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
|
||||
return permissions
|
||||
}
|
||||
|
||||
func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool {
|
||||
func stringifyPermissionsKeys(permissions roles2.Permissions) map[string]map[string]bool {
|
||||
modules := make(map[string]map[string]bool)
|
||||
for module, operations := range permissions {
|
||||
modules[string(module)] = make(map[string]bool)
|
||||
@@ -779,7 +804,7 @@ func TestApproveUserEndpoint(t *testing.T) {
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
|
||||
router.HandleFunc("/users/{userId}/approve", permissions.WrapHandler(handler.approveUser)).Methods("POST")
|
||||
|
||||
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
||||
require.NoError(t, err)
|
||||
@@ -837,7 +862,7 @@ func TestRejectUserEndpoint(t *testing.T) {
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
|
||||
router.HandleFunc("/users/{userId}/reject", permissions.WrapHandler(handler.rejectUser)).Methods("DELETE")
|
||||
|
||||
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
|
||||
require.NoError(t, err)
|
||||
@@ -928,7 +953,7 @@ func TestChangePasswordEndpoint(t *testing.T) {
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/password", handler.changePassword).Methods("PUT")
|
||||
router.HandleFunc("/users/{userId}/password", permissions.WrapHandler(handler.changePassword)).Methods("PUT")
|
||||
|
||||
reqPath := "/users/" + tc.targetUserID + "/password"
|
||||
req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody))
|
||||
@@ -967,7 +992,7 @@ func TestChangePasswordEndpoint_WrongMethod(t *testing.T) {
|
||||
req = nbcontext.SetUserAuthInRequest(req, userAuth)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.changePassword(rr, req)
|
||||
handler.changePassword(rr, req, &userAuth)
|
||||
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func Test_Accounts_GetAll(t *testing.T) {
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
@@ -233,6 +233,71 @@ func Test_Accounts_Update(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Accounts_Update_CrossAccountAttack(t *testing.T) {
|
||||
t.Run("Other user attempts to update testAccount via URL", func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(&api.AccountRequest{
|
||||
Settings: api.AccountSettings{
|
||||
PeerLoginExpirationEnabled: false,
|
||||
PeerLoginExpiration: 86400,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account update must be rejected")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Accounts_Delete(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, false},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, false},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Delete account", func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Accounts_Delete_CrossAccountAttack(t *testing.T) {
|
||||
t.Run("Other user attempts to delete testAccount via URL", func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
|
||||
|
||||
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account delete must be rejected")
|
||||
})
|
||||
}
|
||||
|
||||
func stringPointer(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
@@ -0,0 +1,445 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_Records_GetAll(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all records", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones/testZoneId/records", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.DNSRecord{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(got))
|
||||
assert.Equal(t, "sub.example.com", got[0].Name)
|
||||
assert.Equal(t, api.DNSRecordTypeA, got[0].Type)
|
||||
assert.Equal(t, "1.2.3.4", got[0].Content)
|
||||
assert.Equal(t, 300, got[0].Ttl)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Records_GetById(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
recordId string
|
||||
expectedStatus int
|
||||
expectRecord bool
|
||||
}{
|
||||
{
|
||||
name: "Get existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "testRecordId",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectRecord: true,
|
||||
},
|
||||
{
|
||||
name: "Get non-existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "nonExistingRecordId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectRecord: false,
|
||||
},
|
||||
{
|
||||
name: "Get record from non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
recordId: "testRecordId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectRecord: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
|
||||
|
||||
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
|
||||
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expectRecord {
|
||||
got := &api.DNSRecord{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.Equal(t, "testRecordId", got.Id)
|
||||
assert.Equal(t, "sub.example.com", got.Name)
|
||||
assert.Equal(t, api.DNSRecordTypeA, got.Type)
|
||||
assert.Equal(t, "1.2.3.4", got.Content)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Records_Create(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
requestBody *api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, record *api.DNSRecord)
|
||||
}{
|
||||
{
|
||||
name: "Create A record",
|
||||
zoneId: "testZoneId",
|
||||
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "new.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "5.6.7.8",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, record.Id)
|
||||
assert.Equal(t, "new.example.com", record.Name)
|
||||
assert.Equal(t, api.DNSRecordTypeA, record.Type)
|
||||
assert.Equal(t, "5.6.7.8", record.Content)
|
||||
assert.Equal(t, 600, record.Ttl)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create CNAME record",
|
||||
zoneId: "testZoneId",
|
||||
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "alias.example.com",
|
||||
Type: api.DNSRecordTypeCNAME,
|
||||
Content: "target.example.com",
|
||||
Ttl: 300,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, record.Id)
|
||||
assert.Equal(t, "alias.example.com", record.Name)
|
||||
assert.Equal(t, api.DNSRecordTypeCNAME, record.Type)
|
||||
assert.Equal(t, "target.example.com", record.Content)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create record with invalid content for A type",
|
||||
zoneId: "testZoneId",
|
||||
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "bad.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "not-an-ip",
|
||||
Ttl: 300,
|
||||
},
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
},
|
||||
{
|
||||
name: "Create record in non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "new.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "5.6.7.8",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
path := strings.Replace("/api/dns/zones/{zoneId}/records", "{zoneId}", tc.zoneId, 1)
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.DNSRecord{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the created record directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbRecord := testing_tools.VerifyRecordInDB(t, db, got.Id)
|
||||
assert.Equal(t, got.Name, dbRecord.Name)
|
||||
assert.Equal(t, got.Content, dbRecord.Content)
|
||||
assert.Equal(t, got.Ttl, dbRecord.TTL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Records_Update(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
recordId string
|
||||
requestBody *api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, record *api.DNSRecord)
|
||||
}{
|
||||
{
|
||||
name: "Update record content and TTL",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "testRecordId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
|
||||
Name: "sub.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "10.20.30.40",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "sub.example.com", record.Name)
|
||||
assert.Equal(t, "10.20.30.40", record.Content)
|
||||
assert.Equal(t, 600, record.Ttl)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update non-existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "nonExistingRecordId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
|
||||
Name: "sub.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "10.20.30.40",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Update record in non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
recordId: "testRecordId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
|
||||
Name: "sub.example.com",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Content: "10.20.30.40",
|
||||
Ttl: 600,
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
|
||||
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.DNSRecord{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the updated record directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbRecord := testing_tools.VerifyRecordInDB(t, db, tc.recordId)
|
||||
assert.Equal(t, "10.20.30.40", dbRecord.Content)
|
||||
assert.Equal(t, 600, dbRecord.TTL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Records_Delete(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
recordId string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Delete existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "testRecordId",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete non-existing record",
|
||||
zoneId: "testZoneId",
|
||||
recordId: "nonExistingRecordId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Delete record from non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
recordId: "testRecordId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
|
||||
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
|
||||
// Verify deletion in DB for successful deletes by privileged users
|
||||
if tc.expectedStatus == http.StatusOK && user.expectResponse {
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
testing_tools.VerifyRecordNotInDB(t, db, tc.recordId)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,416 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_Zones_GetAll(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all zones", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.Zone{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(got))
|
||||
assert.Equal(t, "Test Zone", got[0].Name)
|
||||
assert.Equal(t, "example.com", got[0].Domain)
|
||||
assert.Equal(t, true, got[0].Enabled)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Zones_GetById(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
expectedStatus int
|
||||
expectZone bool
|
||||
}{
|
||||
{
|
||||
name: "Get existing zone",
|
||||
zoneId: "testZoneId",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectZone: true,
|
||||
},
|
||||
{
|
||||
name: "Get non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectZone: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expectZone {
|
||||
got := &api.Zone{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.Equal(t, "testZoneId", got.Id)
|
||||
assert.Equal(t, "Test Zone", got.Name)
|
||||
assert.Equal(t, "example.com", got.Domain)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Zones_Create(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
enabled := true
|
||||
disabled := false
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
requestBody *api.PostApiDnsZonesJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, zone *api.Zone)
|
||||
}{
|
||||
{
|
||||
name: "Create zone with valid data",
|
||||
requestBody: &api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "New Zone",
|
||||
Domain: "newzone.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, zone *api.Zone) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, zone.Id)
|
||||
assert.Equal(t, "New Zone", zone.Name)
|
||||
assert.Equal(t, "newzone.com", zone.Domain)
|
||||
assert.Equal(t, true, zone.Enabled)
|
||||
assert.Equal(t, false, zone.EnableSearchDomain)
|
||||
assert.Equal(t, 1, len(zone.DistributionGroups))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create zone with search domain enabled",
|
||||
requestBody: &api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "Search Zone",
|
||||
Domain: "search.example.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, zone *api.Zone) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, zone.Id)
|
||||
assert.Equal(t, "Search Zone", zone.Name)
|
||||
assert.Equal(t, true, zone.EnableSearchDomain)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create disabled zone",
|
||||
requestBody: &api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "Disabled Zone",
|
||||
Domain: "disabled.example.com",
|
||||
Enabled: &disabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, zone *api.Zone) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, zone.Id)
|
||||
assert.Equal(t, false, zone.Enabled)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create zone with empty distribution groups",
|
||||
requestBody: &api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "No Groups Zone",
|
||||
Domain: "nogroups.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{},
|
||||
},
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/zones", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.Zone{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the created zone directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbZone := testing_tools.VerifyZoneInDB(t, db, got.Id)
|
||||
assert.Equal(t, got.Name, dbZone.Name)
|
||||
assert.Equal(t, got.Domain, dbZone.Domain)
|
||||
assert.Equal(t, got.Enabled, dbZone.Enabled)
|
||||
assert.Equal(t, got.EnableSearchDomain, dbZone.EnableSearchDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Zones_Update(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
enabled := true
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
requestBody *api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, zone *api.Zone)
|
||||
}{
|
||||
{
|
||||
name: "Update zone name and settings",
|
||||
zoneId: "testZoneId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
|
||||
Name: "Updated Zone",
|
||||
Domain: "example.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, zone *api.Zone) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "Updated Zone", zone.Name)
|
||||
assert.Equal(t, "example.com", zone.Domain)
|
||||
assert.Equal(t, true, zone.EnableSearchDomain)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
|
||||
Name: "Whatever",
|
||||
Domain: "whatever.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.Zone{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the updated zone directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbZone := testing_tools.VerifyZoneInDB(t, db, tc.zoneId)
|
||||
assert.Equal(t, "Updated Zone", dbZone.Name)
|
||||
assert.Equal(t, "example.com", dbZone.Domain)
|
||||
assert.Equal(t, true, dbZone.Enabled)
|
||||
assert.Equal(t, true, dbZone.EnableSearchDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Zones_Delete(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
zoneId string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Delete existing zone",
|
||||
zoneId: "testZoneId",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete non-existing zone",
|
||||
zoneId: "nonExistingZoneId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
|
||||
// Verify deletion in DB for successful deletes by privileged users
|
||||
if tc.expectedStatus == http.StatusOK && user.expectResponse {
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
testing_tools.VerifyZoneNotInDB(t, db, tc.zoneId)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -78,6 +78,68 @@ func Test_Events_GetAll(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Events_GetAll_Audit(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all audit events", func(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, false)
|
||||
|
||||
// First, perform a mutation to generate an event (create a group as admin)
|
||||
groupBody, err := json.Marshal(&api.GroupRequest{Name: "auditTestGroup"})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal group request: %v", err)
|
||||
}
|
||||
createReq := testing_tools.BuildRequest(t, groupBody, http.MethodPost, "/api/groups", testing_tools.TestAdminId)
|
||||
createRecorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(createRecorder, createReq)
|
||||
assert.Equal(t, http.StatusOK, createRecorder.Code, "Failed to create group to generate event")
|
||||
|
||||
// Now query audit events
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/audit", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.Event{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.GreaterOrEqual(t, len(got), 1, "Expected at least one event after creating a group")
|
||||
|
||||
// Verify the group creation event exists
|
||||
found := false
|
||||
for _, event := range got {
|
||||
if event.ActivityCode == "group.add" {
|
||||
found = true
|
||||
assert.Equal(t, testing_tools.TestAdminId, event.InitiatorId)
|
||||
assert.Equal(t, "Group created", event.Activity)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected to find a group.add event")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Events_GetAll_Empty(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true)
|
||||
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_Geolocations_GetAllCountries(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all countries", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/locations/countries", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.Country{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, len(got))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Geolocations_GetCitiesByCountry(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get cities by country", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "US", 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.City{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, len(got))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Geolocations_GetCitiesByCountry_InvalidCode(t *testing.T) {
|
||||
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "INVALID", 1), testing_tools.TestAdminId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, true)
|
||||
}
|
||||
@@ -26,7 +26,7 @@ func Test_Groups_GetAll(t *testing.T) {
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
@@ -71,7 +71,7 @@ func Test_Groups_GetById(t *testing.T) {
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
@@ -216,7 +216,6 @@ func Test_Groups_Create(t *testing.T) {
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify group exists in DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id)
|
||||
assert.Equal(t, tc.requestBody.Name, dbGroup.Name)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user