mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-11 10:29:58 +00:00
Compare commits
4 Commits
github-iss
...
feature/us
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5a8dbef89b | ||
|
|
569ebb400b | ||
|
|
8ec17daf3a | ||
|
|
8bccbf9304 |
5
.github/issue-resolution/package.json
vendored
5
.github/issue-resolution/package.json
vendored
@@ -1,5 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "issue-resolution",
|
|
||||||
"private": true,
|
|
||||||
"type": "module"
|
|
||||||
}
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
You are a GitHub issue resolution classifier.
|
|
||||||
|
|
||||||
Your job is to decide whether an open GitHub issue is:
|
|
||||||
- AUTO_CLOSE
|
|
||||||
- MANUAL_REVIEW
|
|
||||||
- KEEP_OPEN
|
|
||||||
|
|
||||||
Rules:
|
|
||||||
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
|
|
||||||
- a merged linked PR that clearly resolves the issue, or
|
|
||||||
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
|
|
||||||
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
|
|
||||||
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
|
|
||||||
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
|
|
||||||
5. Do not invent evidence.
|
|
||||||
6. Output valid JSON only.
|
|
||||||
|
|
||||||
Maintainer-authoritative roles:
|
|
||||||
- MEMBER
|
|
||||||
- OWNER
|
|
||||||
- COLLABORATOR
|
|
||||||
|
|
||||||
Partial fixes and multi-item issues:
|
|
||||||
- A merged PR that only addresses SOME items in a multi-item request does NOT resolve the issue. If the issue lists 5 feature requests and a PR fixes 1, the issue is still open.
|
|
||||||
- If a PR description or comment says "partially addresses", "partial fix", or similar, the issue is NOT resolved. Classify as KEEP_OPEN.
|
|
||||||
- If a merged PR addresses the core ask but a later comment objects or reports a regression, classify as MANUAL_REVIEW (not resolved).
|
|
||||||
|
|
||||||
Workarounds vs. actual fixes:
|
|
||||||
- A WORKAROUND is when a user changes their own setup to avoid the problem (editing configs, using a different setting, manual SQL fixes, switching tools, scripts). Workarounds do NOT count as resolution — the underlying issue is still present in the product.
|
|
||||||
- An ACTUAL FIX is when a user reports the problem went away after upgrading to a specific version (e.g., "fixed after updating to v0.65.1") or after a specific PR was merged. This suggests the fix was shipped in the product itself.
|
|
||||||
- A maintainer pointing to an existing alternative feature is NOT the same as fixing the issue. If the reporter never confirmed the alternative works for them, classify as KEEP_OPEN.
|
|
||||||
- If only workarounds exist and no maintainer has confirmed a fix, classify as KEEP_OPEN.
|
|
||||||
- If a user reports an actual fix via a version upgrade but no maintainer confirmed it, classify as MANUAL_REVIEW (not AUTO_CLOSE).
|
|
||||||
|
|
||||||
Stale issues:
|
|
||||||
- An issue with no activity for over 12 months, where a maintainer offered an alternative or asked for more info and the reporter never responded, is a candidate for MANUAL_REVIEW — not necessarily KEEP_OPEN.
|
|
||||||
|
|
||||||
Important:
|
|
||||||
- Later comments outweigh earlier ones.
|
|
||||||
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
|
|
||||||
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
{
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"decision",
|
|
||||||
"reason_code",
|
|
||||||
"confidence",
|
|
||||||
"hard_signals",
|
|
||||||
"contradictions",
|
|
||||||
"summary",
|
|
||||||
"close_comment",
|
|
||||||
"manual_review_note"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"decision": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
|
|
||||||
},
|
|
||||||
"reason_code": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"resolved_by_merged_pr",
|
|
||||||
"maintainer_confirmed_resolved",
|
|
||||||
"duplicate_confirmed",
|
|
||||||
"superseded_confirmed",
|
|
||||||
"likely_fixed_but_unconfirmed",
|
|
||||||
"still_open",
|
|
||||||
"unclear"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"confidence": {
|
|
||||||
"type": "number",
|
|
||||||
"minimum": 0,
|
|
||||||
"maximum": 1
|
|
||||||
},
|
|
||||||
"hard_signals": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": ["type", "url"],
|
|
||||||
"properties": {
|
|
||||||
"type": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"merged_pr",
|
|
||||||
"maintainer_comment",
|
|
||||||
"duplicate_reference",
|
|
||||||
"superseded_reference"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"url": { "type": "string" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"contradictions": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": ["type", "url"],
|
|
||||||
"properties": {
|
|
||||||
"type": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"reporter_still_broken",
|
|
||||||
"later_unresolved_comment",
|
|
||||||
"ambiguous_pr_link",
|
|
||||||
"other"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"url": { "type": "string" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"summary": { "type": "string" },
|
|
||||||
"close_comment": { "type": "string" },
|
|
||||||
"manual_review_note": { "type": "string" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
213
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
213
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
@@ -1,213 +0,0 @@
|
|||||||
import fs from "node:fs/promises";
|
|
||||||
|
|
||||||
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
|
|
||||||
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
|
|
||||||
|
|
||||||
const ghHeaders = {
|
|
||||||
Authorization: `Bearer ${process.env.GH_TOKEN}`,
|
|
||||||
Accept: "application/vnd.github+json",
|
|
||||||
"X-GitHub-Api-Version": "2022-11-28",
|
|
||||||
};
|
|
||||||
|
|
||||||
// Use PROJECT_PAT for project board operations, fall back to GH_TOKEN
|
|
||||||
const projectHeaders = {
|
|
||||||
Authorization: `Bearer ${process.env.PROJECT_PAT || process.env.GH_TOKEN}`,
|
|
||||||
Accept: "application/vnd.github+json",
|
|
||||||
"X-GitHub-Api-Version": "2022-11-28",
|
|
||||||
};
|
|
||||||
|
|
||||||
async function rest(url, method = "GET", body) {
|
|
||||||
const res = await fetch(url, {
|
|
||||||
method,
|
|
||||||
headers: ghHeaders,
|
|
||||||
body: body ? JSON.stringify(body) : undefined
|
|
||||||
});
|
|
||||||
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
|
|
||||||
return res.status === 204 ? null : res.json();
|
|
||||||
}
|
|
||||||
|
|
||||||
async function graphql(query, variables) {
|
|
||||||
const res = await fetch("https://api.github.com/graphql", {
|
|
||||||
method: "POST",
|
|
||||||
headers: projectHeaders,
|
|
||||||
body: JSON.stringify({ query, variables })
|
|
||||||
});
|
|
||||||
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
|
|
||||||
const json = await res.json();
|
|
||||||
if (json.errors) throw new Error(JSON.stringify(json.errors));
|
|
||||||
return json.data;
|
|
||||||
}
|
|
||||||
|
|
||||||
async function addLabel(owner, repo, issueNumber, labels) {
|
|
||||||
return rest(
|
|
||||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
|
|
||||||
"POST",
|
|
||||||
{ labels }
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function addComment(owner, repo, issueNumber, body) {
|
|
||||||
return rest(
|
|
||||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
|
|
||||||
"POST",
|
|
||||||
{ body }
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function closeIssue(owner, repo, issueNumber) {
|
|
||||||
return rest(
|
|
||||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
|
|
||||||
"PATCH",
|
|
||||||
{ state: "closed", state_reason: "completed" }
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function getIssueNodeId(owner, repo, issueNumber) {
|
|
||||||
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
|
|
||||||
return issue.node_id;
|
|
||||||
}
|
|
||||||
|
|
||||||
async function addToProject(issueNodeId) {
|
|
||||||
const mutation = `
|
|
||||||
mutation($projectId: ID!, $contentId: ID!) {
|
|
||||||
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
|
|
||||||
item { id }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`;
|
|
||||||
|
|
||||||
try {
|
|
||||||
const data = await graphql(mutation, {
|
|
||||||
projectId: process.env.PROJECT_ID,
|
|
||||||
contentId: issueNodeId
|
|
||||||
});
|
|
||||||
return data.addProjectV2ItemById.item.id;
|
|
||||||
} catch (err) {
|
|
||||||
console.warn(`[WARN] Could not add to project (needs PAT with project scope): ${err.message}`);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function setTextField(itemId, fieldId, value) {
|
|
||||||
const mutation = `
|
|
||||||
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
|
|
||||||
updateProjectV2ItemFieldValue(input: {
|
|
||||||
projectId: $projectId,
|
|
||||||
itemId: $itemId,
|
|
||||||
fieldId: $fieldId,
|
|
||||||
value: { text: $value }
|
|
||||||
}) {
|
|
||||||
projectV2Item { id }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`;
|
|
||||||
|
|
||||||
return graphql(mutation, {
|
|
||||||
projectId: process.env.PROJECT_ID,
|
|
||||||
itemId,
|
|
||||||
fieldId,
|
|
||||||
value
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async function setNumberField(itemId, fieldId, value) {
|
|
||||||
const mutation = `
|
|
||||||
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: Float!) {
|
|
||||||
updateProjectV2ItemFieldValue(input: {
|
|
||||||
projectId: $projectId,
|
|
||||||
itemId: $itemId,
|
|
||||||
fieldId: $fieldId,
|
|
||||||
value: { number: $value }
|
|
||||||
}) {
|
|
||||||
projectV2Item { id }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`;
|
|
||||||
|
|
||||||
return graphql(mutation, {
|
|
||||||
projectId: process.env.PROJECT_ID,
|
|
||||||
itemId,
|
|
||||||
fieldId,
|
|
||||||
value
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async function setSingleSelectField(itemId, fieldId, optionId) {
|
|
||||||
const mutation = `
|
|
||||||
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $optionId: String!) {
|
|
||||||
updateProjectV2ItemFieldValue(input: {
|
|
||||||
projectId: $projectId,
|
|
||||||
itemId: $itemId,
|
|
||||||
fieldId: $fieldId,
|
|
||||||
value: { singleSelectOptionId: $optionId }
|
|
||||||
}) {
|
|
||||||
projectV2Item { id }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`;
|
|
||||||
|
|
||||||
return graphql(mutation, {
|
|
||||||
projectId: process.env.PROJECT_ID,
|
|
||||||
itemId,
|
|
||||||
fieldId,
|
|
||||||
optionId
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async function addToProjectWithFields(owner, repo, d) {
|
|
||||||
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
|
|
||||||
const itemId = await addToProject(issueNodeId);
|
|
||||||
|
|
||||||
if (itemId) {
|
|
||||||
if (process.env.PROJECT_STATUS_FIELD_ID && process.env.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID) {
|
|
||||||
await setSingleSelectField(itemId, process.env.PROJECT_STATUS_FIELD_ID, process.env.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID);
|
|
||||||
}
|
|
||||||
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
|
|
||||||
await setNumberField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, d.model.confidence);
|
|
||||||
}
|
|
||||||
if (process.env.PROJECT_REASON_FIELD_ID) {
|
|
||||||
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
|
|
||||||
}
|
|
||||||
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
|
|
||||||
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
|
|
||||||
}
|
|
||||||
console.log(` → Added to project board (Status: Needs Review)`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const d of decisions) {
|
|
||||||
const [owner, repo] = d.repository.split("/");
|
|
||||||
|
|
||||||
if (d.final_decision === "KEEP_OPEN") {
|
|
||||||
console.log(`#${d.issue_number} → KEEP_OPEN (confidence: ${d.model.confidence}, reason: ${d.model.reason_code})`);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dryRun) {
|
|
||||||
console.log(`[DRY RUN] #${d.issue_number} → ${d.final_decision} (confidence: ${d.model.confidence}, reason: ${d.model.reason_code})`);
|
|
||||||
// In dry-run: populate project board but don't touch issues
|
|
||||||
if (d.final_decision === "MANUAL_REVIEW" || d.final_decision === "AUTO_CLOSE") {
|
|
||||||
await addToProjectWithFields(owner, repo, d);
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (d.final_decision === "AUTO_CLOSE") {
|
|
||||||
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
|
|
||||||
await addComment(owner, repo, d.issue_number, d.model.close_comment);
|
|
||||||
await closeIssue(owner, repo, d.issue_number);
|
|
||||||
await addToProjectWithFields(owner, repo, d);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (d.final_decision === "MANUAL_REVIEW") {
|
|
||||||
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
|
|
||||||
await addToProjectWithFields(owner, repo, d);
|
|
||||||
await addComment(
|
|
||||||
owner,
|
|
||||||
repo,
|
|
||||||
d.issue_number,
|
|
||||||
d.model.manual_review_note ||
|
|
||||||
"This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue."
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,259 +0,0 @@
|
|||||||
import fs from "node:fs/promises";
|
|
||||||
|
|
||||||
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
|
|
||||||
const systemPrompt = await fs.readFile("prompts/issue-resolution-system.txt", "utf8");
|
|
||||||
const outputSchema = JSON.parse(await fs.readFile("schemas/issue-resolution-output.json", "utf8"));
|
|
||||||
|
|
||||||
function isMaintainerRole(role) {
|
|
||||||
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
|
|
||||||
}
|
|
||||||
|
|
||||||
function preScore(candidate) {
|
|
||||||
let score = 0;
|
|
||||||
const hardSignals = [];
|
|
||||||
const contradictions = [];
|
|
||||||
|
|
||||||
for (const t of candidate.timeline) {
|
|
||||||
const sourceIssue = t.source?.issue;
|
|
||||||
|
|
||||||
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
|
|
||||||
hardSignals.push({
|
|
||||||
type: "merged_pr",
|
|
||||||
url: sourceIssue.html_url
|
|
||||||
});
|
|
||||||
score += 40; // provisional until PR merged state is verified
|
|
||||||
}
|
|
||||||
|
|
||||||
if (["referenced", "connected"].includes(t.event)) {
|
|
||||||
score += 10;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const c of candidate.comments) {
|
|
||||||
const body = c.body.toLowerCase();
|
|
||||||
|
|
||||||
if (
|
|
||||||
isMaintainerRole(c.author_association) &&
|
|
||||||
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
|
|
||||||
) {
|
|
||||||
score += 25;
|
|
||||||
hardSignals.push({
|
|
||||||
type: "maintainer_comment",
|
|
||||||
url: c.html_url
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
|
|
||||||
score -= 50;
|
|
||||||
contradictions.push({
|
|
||||||
type: "later_unresolved_comment",
|
|
||||||
url: c.html_url
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return { score, hardSignals, contradictions };
|
|
||||||
}
|
|
||||||
|
|
||||||
// GitHub Models gpt-4o has an 8000 token input limit.
|
|
||||||
// Reserve ~2000 tokens for system prompt + response overhead.
|
|
||||||
// 1 token ~= 4 chars, so cap user message at ~24000 chars.
|
|
||||||
const MAX_USER_MESSAGE_CHARS = 24000;
|
|
||||||
|
|
||||||
function truncate(text, maxChars) {
|
|
||||||
if (text.length <= maxChars) return text;
|
|
||||||
return text.slice(0, maxChars) + "\n\n[... truncated due to length]";
|
|
||||||
}
|
|
||||||
|
|
||||||
function buildUserMessage(candidate, pre) {
|
|
||||||
const { issue, comments, timeline } = candidate;
|
|
||||||
|
|
||||||
const commentBlock = comments
|
|
||||||
.map((c) => `[${c.author_association}] ${c.user} (${c.created_at}):\n${c.body}`)
|
|
||||||
.join("\n---\n");
|
|
||||||
|
|
||||||
const timelineBlock = timeline
|
|
||||||
.filter((t) => ["cross-referenced", "referenced", "connected", "closed", "reopened"].includes(t.event))
|
|
||||||
.map((t) => {
|
|
||||||
let line = `${t.event} (${t.created_at})`;
|
|
||||||
if (t.source?.issue?.html_url) line += ` — ${t.source.issue.html_url}`;
|
|
||||||
if (t.source?.issue?.pull_request?.html_url) line += ` (PR: ${t.source.issue.pull_request.html_url})`;
|
|
||||||
return line;
|
|
||||||
})
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
const sections = [
|
|
||||||
`## Issue #${issue.number}: ${issue.title}`,
|
|
||||||
`URL: ${issue.html_url}`,
|
|
||||||
`Created: ${issue.created_at} | Updated: ${issue.updated_at}`,
|
|
||||||
`Labels: ${issue.labels.join(", ") || "none"}`,
|
|
||||||
"",
|
|
||||||
"### Body",
|
|
||||||
truncate(issue.body || "(empty)", 4000),
|
|
||||||
"",
|
|
||||||
"### Comments",
|
|
||||||
commentBlock || "(none)",
|
|
||||||
"",
|
|
||||||
"### Timeline events",
|
|
||||||
timelineBlock || "(none)",
|
|
||||||
];
|
|
||||||
|
|
||||||
if (candidate.linked_prs?.length) {
|
|
||||||
sections.push("");
|
|
||||||
sections.push("### Linked PRs (verified state)");
|
|
||||||
for (const pr of candidate.linked_prs) {
|
|
||||||
const status = pr.merged ? `MERGED (${pr.merged_at})` : pr.state.toUpperCase();
|
|
||||||
sections.push(`- PR #${pr.number}: ${pr.title} — ${status} — ${pr.url}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pre.hardSignals.length || pre.contradictions.length) {
|
|
||||||
sections.push("");
|
|
||||||
sections.push("### Automated evidence scan");
|
|
||||||
for (const s of pre.hardSignals) {
|
|
||||||
sections.push(`- SIGNAL: ${s.type} — ${s.url}`);
|
|
||||||
}
|
|
||||||
for (const c of pre.contradictions) {
|
|
||||||
sections.push(`- CONTRADICTION: ${c.type} — ${c.url}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return truncate(sections.join("\n"), MAX_USER_MESSAGE_CHARS);
|
|
||||||
}
|
|
||||||
|
|
||||||
const MODEL = "gpt-4o-mini";
|
|
||||||
const MAX_RETRIES = 5;
|
|
||||||
|
|
||||||
function sleep(ms) {
|
|
||||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
|
||||||
}
|
|
||||||
|
|
||||||
async function callGitHubModel(candidate, pre) {
|
|
||||||
const body = JSON.stringify({
|
|
||||||
model: MODEL,
|
|
||||||
messages: [
|
|
||||||
{ role: "system", content: systemPrompt },
|
|
||||||
{ role: "user", content: buildUserMessage(candidate, pre) },
|
|
||||||
],
|
|
||||||
response_format: {
|
|
||||||
type: "json_schema",
|
|
||||||
json_schema: {
|
|
||||||
name: "issue_resolution",
|
|
||||||
strict: true,
|
|
||||||
schema: outputSchema,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
temperature: 0.1,
|
|
||||||
});
|
|
||||||
|
|
||||||
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
|
|
||||||
const res = await fetch("https://models.inference.ai.azure.com/chat/completions", {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
Authorization: `Bearer ${process.env.GH_TOKEN}`,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
body,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (res.status === 429) {
|
|
||||||
const retryAfter = Number(res.headers.get("retry-after")) || 30;
|
|
||||||
if (retryAfter > 120) {
|
|
||||||
console.warn(` [QUOTA EXHAUSTED] API wants ${retryAfter}s wait — skipping remaining issues.`);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
console.warn(` [RATE LIMITED] Waiting ${retryAfter}s (attempt ${attempt + 1}/${MAX_RETRIES})...`);
|
|
||||||
await sleep(retryAfter * 1000);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!res.ok) {
|
|
||||||
const text = await res.text();
|
|
||||||
throw new Error(`GitHub Models ${res.status}: ${text}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
const data = await res.json();
|
|
||||||
return JSON.parse(data.choices[0].message.content);
|
|
||||||
}
|
|
||||||
|
|
||||||
throw new Error(`GitHub Models: exceeded ${MAX_RETRIES} retries due to rate limiting`);
|
|
||||||
}
|
|
||||||
|
|
||||||
function enforcePolicy(modelOut, pre) {
|
|
||||||
const approvedReasons = new Set([
|
|
||||||
"resolved_by_merged_pr",
|
|
||||||
"maintainer_confirmed_resolved",
|
|
||||||
"duplicate_confirmed",
|
|
||||||
"superseded_confirmed"
|
|
||||||
]);
|
|
||||||
|
|
||||||
const hasHardSignal =
|
|
||||||
(modelOut.hard_signals || []).some(s =>
|
|
||||||
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
|
|
||||||
) || pre.hardSignals.length > 0;
|
|
||||||
|
|
||||||
const hasContradiction =
|
|
||||||
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
|
|
||||||
|
|
||||||
// Only auto-close with very strict criteria
|
|
||||||
if (
|
|
||||||
modelOut.decision === "AUTO_CLOSE" &&
|
|
||||||
modelOut.confidence >= 0.97 &&
|
|
||||||
approvedReasons.has(modelOut.reason_code) &&
|
|
||||||
hasHardSignal &&
|
|
||||||
!hasContradiction
|
|
||||||
) {
|
|
||||||
return "AUTO_CLOSE";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Downgrade AUTO_CLOSE that didn't pass the gate
|
|
||||||
if (modelOut.decision === "AUTO_CLOSE") {
|
|
||||||
return "MANUAL_REVIEW";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise trust the model
|
|
||||||
return modelOut.decision;
|
|
||||||
}
|
|
||||||
|
|
||||||
console.log(`Classifying ${candidates.length} candidates with ${MODEL}...\n`);
|
|
||||||
|
|
||||||
// 15 req/min limit → 1 request every 4s. Use 4.5s for safety margin.
|
|
||||||
const PACE_MS = 4500;
|
|
||||||
let lastRequestTime = 0;
|
|
||||||
|
|
||||||
async function paced(fn) {
|
|
||||||
const elapsed = Date.now() - lastRequestTime;
|
|
||||||
if (elapsed < PACE_MS) await sleep(PACE_MS - elapsed);
|
|
||||||
lastRequestTime = Date.now();
|
|
||||||
return fn();
|
|
||||||
}
|
|
||||||
|
|
||||||
const decisions = [];
|
|
||||||
for (const candidate of candidates) {
|
|
||||||
const pre = preScore(candidate);
|
|
||||||
const modelOut = await paced(() => callGitHubModel(candidate, pre));
|
|
||||||
|
|
||||||
if (modelOut === null) {
|
|
||||||
console.warn(`\nQuota exhausted after ${decisions.length} issues. Writing partial results.`);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const finalDecision = enforcePolicy(modelOut, pre);
|
|
||||||
|
|
||||||
decisions.push({
|
|
||||||
repository: candidate.repository,
|
|
||||||
issue_number: candidate.issue.number,
|
|
||||||
issue_url: candidate.issue.html_url,
|
|
||||||
title: candidate.issue.title,
|
|
||||||
pre_score: pre.score,
|
|
||||||
final_decision: finalDecision,
|
|
||||||
model: modelOut
|
|
||||||
});
|
|
||||||
|
|
||||||
console.log(
|
|
||||||
`#${candidate.issue.number} | pre_score: ${pre.score} | model: ${modelOut.decision} @ ${modelOut.confidence} | final: ${finalDecision} | ${modelOut.reason_code}`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));
|
|
||||||
console.log(`\nWrote ${decisions.length} decisions to decisions.json`);
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
import fs from "node:fs/promises";
|
|
||||||
|
|
||||||
const token = process.env.GH_TOKEN;
|
|
||||||
const repo = process.env.REPO; // "owner/repo"
|
|
||||||
const maxIssues = Number(process.env.MAX_ISSUES) || 100;
|
|
||||||
|
|
||||||
const headers = {
|
|
||||||
Authorization: `Bearer ${token}`,
|
|
||||||
Accept: "application/vnd.github+json",
|
|
||||||
"X-GitHub-Api-Version": "2022-11-28",
|
|
||||||
};
|
|
||||||
|
|
||||||
async function rest(url) {
|
|
||||||
const res = await fetch(url, { headers });
|
|
||||||
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
|
|
||||||
return res.json();
|
|
||||||
}
|
|
||||||
|
|
||||||
async function restSafe(url) {
|
|
||||||
const res = await fetch(url, { headers });
|
|
||||||
if (!res.ok) return null;
|
|
||||||
return res.json();
|
|
||||||
}
|
|
||||||
|
|
||||||
async function paginate(url, max) {
|
|
||||||
const items = [];
|
|
||||||
let page = 1;
|
|
||||||
while (items.length < max) {
|
|
||||||
const perPage = Math.min(100, max - items.length);
|
|
||||||
const sep = url.includes("?") ? "&" : "?";
|
|
||||||
const batch = await rest(`${url}${sep}per_page=${perPage}&page=${page}`);
|
|
||||||
if (!batch.length) break;
|
|
||||||
items.push(...batch);
|
|
||||||
page++;
|
|
||||||
}
|
|
||||||
return items.slice(0, max);
|
|
||||||
}
|
|
||||||
|
|
||||||
console.log(`Fetching up to ${maxIssues} open issues from ${repo}...`);
|
|
||||||
|
|
||||||
const issues = await paginate(
|
|
||||||
`https://api.github.com/repos/${repo}/issues?state=open&sort=updated&direction=asc`,
|
|
||||||
maxIssues
|
|
||||||
);
|
|
||||||
|
|
||||||
// Filter out pull requests (GitHub API returns PRs as issues too)
|
|
||||||
const realIssues = issues.filter((i) => !i.pull_request);
|
|
||||||
console.log(`Found ${realIssues.length} open issues (excluded PRs).`);
|
|
||||||
|
|
||||||
const candidates = [];
|
|
||||||
for (const issue of realIssues) {
|
|
||||||
const [comments, timeline] = await Promise.all([
|
|
||||||
rest(`https://api.github.com/repos/${repo}/issues/${issue.number}/comments?per_page=100`),
|
|
||||||
rest(`https://api.github.com/repos/${repo}/issues/${issue.number}/timeline?per_page=100`),
|
|
||||||
]);
|
|
||||||
|
|
||||||
candidates.push({
|
|
||||||
repository: repo,
|
|
||||||
issue: {
|
|
||||||
number: issue.number,
|
|
||||||
html_url: issue.html_url,
|
|
||||||
title: issue.title,
|
|
||||||
body: issue.body,
|
|
||||||
created_at: issue.created_at,
|
|
||||||
updated_at: issue.updated_at,
|
|
||||||
labels: issue.labels.map((l) => l.name),
|
|
||||||
},
|
|
||||||
comments: comments.map((c) => ({
|
|
||||||
body: c.body,
|
|
||||||
author_association: c.author_association,
|
|
||||||
html_url: c.html_url,
|
|
||||||
created_at: c.created_at,
|
|
||||||
user: c.user?.login,
|
|
||||||
})),
|
|
||||||
timeline: timeline.map((t) => ({
|
|
||||||
event: t.event,
|
|
||||||
created_at: t.created_at,
|
|
||||||
source: t.source
|
|
||||||
? {
|
|
||||||
issue: {
|
|
||||||
html_url: t.source.issue?.html_url,
|
|
||||||
pull_request: t.source.issue?.pull_request
|
|
||||||
? { html_url: t.source.issue.pull_request.html_url }
|
|
||||||
: undefined,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
: undefined,
|
|
||||||
})),
|
|
||||||
linked_prs: [],
|
|
||||||
});
|
|
||||||
|
|
||||||
// Fetch merge status for cross-referenced PRs
|
|
||||||
const prUrls = new Set();
|
|
||||||
for (const t of timeline) {
|
|
||||||
const prHtml = t.source?.issue?.pull_request?.html_url;
|
|
||||||
if (t.event === "cross-referenced" && prHtml) {
|
|
||||||
prUrls.add(prHtml);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const candidate = candidates[candidates.length - 1];
|
|
||||||
for (const prHtml of prUrls) {
|
|
||||||
// Extract owner/repo and PR number from URL like https://github.com/owner/repo/pull/123
|
|
||||||
const match = prHtml.match(/github\.com\/([^/]+\/[^/]+)\/pull\/(\d+)/);
|
|
||||||
if (!match) continue;
|
|
||||||
const [, prRepo, prNum] = match;
|
|
||||||
const pr = await restSafe(`https://api.github.com/repos/${prRepo}/pulls/${prNum}`);
|
|
||||||
if (!pr) continue;
|
|
||||||
candidate.linked_prs.push({
|
|
||||||
number: pr.number,
|
|
||||||
title: pr.title,
|
|
||||||
url: prHtml,
|
|
||||||
state: pr.state,
|
|
||||||
merged: pr.merged || false,
|
|
||||||
merged_at: pr.merged_at,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
console.log(` #${issue.number} — ${comments.length} comments, ${timeline.length} timeline events, ${candidate.linked_prs.length} linked PRs`);
|
|
||||||
}
|
|
||||||
|
|
||||||
await fs.writeFile("candidates.json", JSON.stringify(candidates, null, 2));
|
|
||||||
console.log(`Wrote ${candidates.length} candidates to candidates.json`);
|
|
||||||
63
.github/workflows/issue-resolution-triage.yml
vendored
63
.github/workflows/issue-resolution-triage.yml
vendored
@@ -1,63 +0,0 @@
|
|||||||
name: issue-resolution-triage
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [github-issue-resolver]
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
dry_run:
|
|
||||||
description: "If true, do not close issues"
|
|
||||||
required: false
|
|
||||||
default: "true"
|
|
||||||
max_issues:
|
|
||||||
description: "How many issues to process"
|
|
||||||
required: false
|
|
||||||
default: "100"
|
|
||||||
schedule:
|
|
||||||
- cron: "17 2 * * *"
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
issues: write
|
|
||||||
pull-requests: read
|
|
||||||
models: read
|
|
||||||
|
|
||||||
# todo: remove hardcoded values
|
|
||||||
jobs:
|
|
||||||
triage:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
env:
|
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
PROJECT_PAT: ${{ secrets.PROJECT_PAT }}
|
|
||||||
DRY_RUN: "true"
|
|
||||||
MAX_ISSUES: "100"
|
|
||||||
REPO: ${{ github.repository }}
|
|
||||||
PROJECT_ID: "PVT_kwDOBfz4Jc4BVeWR"
|
|
||||||
PROJECT_STATUS_FIELD_ID: "PVTSSF_lADOBfz4Jc4BVeWRzhQ56sU"
|
|
||||||
PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: "a55a2be9"
|
|
||||||
PROJECT_CONFIDENCE_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ57x4"
|
|
||||||
PROJECT_REASON_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ5-Lg"
|
|
||||||
PROJECT_EVIDENCE_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ5-Pw"
|
|
||||||
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
working-directory: .github/issue-resolution
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: "20"
|
|
||||||
|
|
||||||
- run: node scripts/fetch-candidates.mjs
|
|
||||||
- run: node scripts/classify-candidates.mjs
|
|
||||||
- run: node scripts/apply-decisions.mjs
|
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v4
|
|
||||||
if: always()
|
|
||||||
with:
|
|
||||||
name: triage-results
|
|
||||||
path: |
|
|
||||||
.github/issue-resolution/candidates.json
|
|
||||||
.github/issue-resolution/decisions.json
|
|
||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.4"
|
SIGN_PIPE_VER: "v0.1.2"
|
||||||
GORELEASER_VER: "v2.14.3"
|
GORELEASER_VER: "v2.14.3"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
2
Makefile
2
Makefile
@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
|||||||
$(GOLANGCI_LINT):
|
$(GOLANGCI_LINT):
|
||||||
@echo "Installing golangci-lint..."
|
@echo "Installing golangci-lint..."
|
||||||
@mkdir -p ./bin
|
@mkdir -p ./bin
|
||||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
|
|
||||||
# Lint only changed files (fast, for pre-push)
|
# Lint only changed files (fast, for pre-push)
|
||||||
lint: $(GOLANGCI_LINT)
|
lint: $(GOLANGCI_LINT)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
@@ -16,7 +15,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -28,7 +26,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
types "github.com/netbirdio/netbird/upload-server/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -71,30 +68,7 @@ type Client struct {
|
|||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
|
||||||
stateMu sync.RWMutex
|
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
config *profilemanager.Config
|
|
||||||
cacheDir string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
|
|
||||||
c.stateMu.Lock()
|
|
||||||
defer c.stateMu.Unlock()
|
|
||||||
c.config = cfg
|
|
||||||
c.cacheDir = cacheDir
|
|
||||||
c.connectClient = cc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
|
|
||||||
c.stateMu.RLock()
|
|
||||||
defer c.stateMu.RUnlock()
|
|
||||||
return c.config, c.cacheDir, c.connectClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) getConnectClient() *internal.ConnectClient {
|
|
||||||
c.stateMu.RLock()
|
|
||||||
defer c.stateMu.RUnlock()
|
|
||||||
return c.connectClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -119,7 +93,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
cacheDir := platformFiles.CacheDir()
|
|
||||||
|
|
||||||
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -151,9 +124,8 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
c.setState(cfg, cacheDir, connectClient)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -163,7 +135,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
cacheDir := platformFiles.CacheDir()
|
|
||||||
|
|
||||||
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -186,9 +157,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
c.setState(cfg, cacheDir, connectClient)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -203,12 +173,11 @@ func (c *Client) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) RenewTun(fd int) error {
|
func (c *Client) RenewTun(fd int) error {
|
||||||
cc := c.getConnectClient()
|
if c.connectClient == nil {
|
||||||
if cc == nil {
|
|
||||||
return fmt.Errorf("engine not running")
|
return fmt.Errorf("engine not running")
|
||||||
}
|
}
|
||||||
|
|
||||||
e := cc.Engine()
|
e := c.connectClient.Engine()
|
||||||
if e == nil {
|
if e == nil {
|
||||||
return fmt.Errorf("engine not initialized")
|
return fmt.Errorf("engine not initialized")
|
||||||
}
|
}
|
||||||
@@ -216,73 +185,6 @@ func (c *Client) RenewTun(fd int) error {
|
|||||||
return e.RenewTun(fd)
|
return e.RenewTun(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
|
|
||||||
// It works both with and without a running engine.
|
|
||||||
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
|
|
||||||
cfg, cacheDir, cc := c.stateSnapshot()
|
|
||||||
|
|
||||||
// If the engine hasn't been started, load config from disk
|
|
||||||
if cfg == nil {
|
|
||||||
var err error
|
|
||||||
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
|
||||||
ConfigPath: platformFiles.ConfigurationFilePath(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("load config: %w", err)
|
|
||||||
}
|
|
||||||
cacheDir = platformFiles.CacheDir()
|
|
||||||
}
|
|
||||||
|
|
||||||
deps := debug.GeneratorDependencies{
|
|
||||||
InternalConfig: cfg,
|
|
||||||
StatusRecorder: c.recorder,
|
|
||||||
TempDir: cacheDir,
|
|
||||||
}
|
|
||||||
|
|
||||||
if cc != nil {
|
|
||||||
resp, err := cc.GetLatestSyncResponse()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("get latest sync response: %v", err)
|
|
||||||
}
|
|
||||||
deps.SyncResponse = resp
|
|
||||||
|
|
||||||
if e := cc.Engine(); e != nil {
|
|
||||||
if cm := e.GetClientMetrics(); cm != nil {
|
|
||||||
deps.ClientMetrics = cm
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bundleGenerator := debug.NewBundleGenerator(
|
|
||||||
deps,
|
|
||||||
debug.BundleConfig{
|
|
||||||
Anonymize: anonymize,
|
|
||||||
IncludeSystemInfo: true,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
path, err := bundleGenerator.Generate()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := os.Remove(path); err != nil {
|
|
||||||
log.Errorf("failed to remove debug bundle file: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("debug bundle uploaded with key %s", key)
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
// SetTraceLogLevel configure the logger to trace level
|
||||||
func (c *Client) SetTraceLogLevel() {
|
func (c *Client) SetTraceLogLevel() {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
@@ -312,13 +214,12 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Networks() *NetworkArray {
|
func (c *Client) Networks() *NetworkArray {
|
||||||
cc := c.getConnectClient()
|
if c.connectClient == nil {
|
||||||
if cc == nil {
|
|
||||||
log.Error("not connected")
|
log.Error("not connected")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := cc.Engine()
|
engine := c.connectClient.Engine()
|
||||||
if engine == nil {
|
if engine == nil {
|
||||||
log.Error("could not get engine")
|
log.Error("could not get engine")
|
||||||
return nil
|
return nil
|
||||||
@@ -399,7 +300,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||||
client := c.getConnectClient()
|
client := c.connectClient
|
||||||
if client == nil {
|
if client == nil {
|
||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,5 +7,4 @@ package android
|
|||||||
type PlatformFiles interface {
|
type PlatformFiles interface {
|
||||||
ConfigurationFilePath() string
|
ConfigurationFilePath() string
|
||||||
StateFilePath() string
|
StateFilePath() string
|
||||||
CacheDir() string
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
|
||||||
// its wg interface into firewalld's "trusted" zone. This is required because
|
|
||||||
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
|
||||||
// versions, which returns EPERM to any other process that tries to insert
|
|
||||||
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
|
||||||
// itself add the accept rules to its own chains by trusting the interface.
|
|
||||||
package firewalld
|
|
||||||
|
|
||||||
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
|
||||||
// should bypass firewalld filtering.
|
|
||||||
const TrustedZone = "trusted"
|
|
||||||
@@ -1,260 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package firewalld
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
dbusDest = "org.fedoraproject.FirewallD1"
|
|
||||||
dbusPath = "/org/fedoraproject/FirewallD1"
|
|
||||||
dbusRootIface = "org.fedoraproject.FirewallD1"
|
|
||||||
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
|
||||||
|
|
||||||
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
|
||||||
errAlreadyEnabled = "ALREADY_ENABLED"
|
|
||||||
errUnknownIface = "UNKNOWN_INTERFACE"
|
|
||||||
errNotEnabled = "NOT_ENABLED"
|
|
||||||
|
|
||||||
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
|
||||||
// A fresh context is created for each call so a slow DBus probe can't
|
|
||||||
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
|
||||||
callTimeout = 3 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
|
||||||
|
|
||||||
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
|
||||||
// Info level only for the first successful add per process; repeat adds
|
|
||||||
// from other init paths are quieter.
|
|
||||||
trustLogOnce sync.Once
|
|
||||||
|
|
||||||
parentCtxMu sync.RWMutex
|
|
||||||
parentCtx context.Context = context.Background()
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetParentContext installs a parent context whose cancellation aborts any
|
|
||||||
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
|
||||||
// always uses a fresh Background-rooted timeout so cleanup can still run
|
|
||||||
// during engine shutdown when the engine context is already cancelled.
|
|
||||||
func SetParentContext(ctx context.Context) {
|
|
||||||
parentCtxMu.Lock()
|
|
||||||
parentCtx = ctx
|
|
||||||
parentCtxMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func getParentContext() context.Context {
|
|
||||||
parentCtxMu.RLock()
|
|
||||||
defer parentCtxMu.RUnlock()
|
|
||||||
return parentCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
|
||||||
// running. It is idempotent and best-effort: errors are returned so callers
|
|
||||||
// can log, but a non-running firewalld is not an error. Only the first
|
|
||||||
// successful call per process logs at Info. Respects the parent context set
|
|
||||||
// via SetParentContext so startup-time cancellation unblocks it.
|
|
||||||
func TrustInterface(iface string) error {
|
|
||||||
parent := getParentContext()
|
|
||||||
if !isRunning(parent) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := addTrusted(parent, iface); err != nil {
|
|
||||||
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
|
||||||
}
|
|
||||||
trustLogOnce.Do(func() {
|
|
||||||
log.Infof("added %s to firewalld trusted zone", iface)
|
|
||||||
})
|
|
||||||
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
|
||||||
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
|
||||||
// during shutdown after the engine context has been cancelled.
|
|
||||||
func UntrustInterface(iface string) error {
|
|
||||||
if !isRunning(context.Background()) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := removeTrusted(context.Background(), iface); err != nil {
|
|
||||||
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
|
||||||
return context.WithTimeout(parent, callTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRunning(parent context.Context) bool {
|
|
||||||
ctx, cancel := newCallContext(parent)
|
|
||||||
ok, err := isRunningDBus(ctx)
|
|
||||||
cancel()
|
|
||||||
if err == nil {
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
ctx, cancel = newCallContext(parent)
|
|
||||||
defer cancel()
|
|
||||||
return isRunningCLI(ctx)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func addTrusted(parent context.Context, iface string) error {
|
|
||||||
ctx, cancel := newCallContext(parent)
|
|
||||||
err := addDBus(ctx, iface)
|
|
||||||
cancel()
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if !errors.Is(err, errDBusUnavailable) {
|
|
||||||
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
|
||||||
}
|
|
||||||
ctx, cancel = newCallContext(parent)
|
|
||||||
defer cancel()
|
|
||||||
return addCLI(ctx, iface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeTrusted(parent context.Context, iface string) error {
|
|
||||||
ctx, cancel := newCallContext(parent)
|
|
||||||
err := removeDBus(ctx, iface)
|
|
||||||
cancel()
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if !errors.Is(err, errDBusUnavailable) {
|
|
||||||
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
|
||||||
}
|
|
||||||
ctx, cancel = newCallContext(parent)
|
|
||||||
defer cancel()
|
|
||||||
return removeCLI(ctx, iface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRunningDBus(ctx context.Context) (bool, error) {
|
|
||||||
conn, err := dbus.SystemBus()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
|
||||||
}
|
|
||||||
obj := conn.Object(dbusDest, dbusPath)
|
|
||||||
|
|
||||||
var zone string
|
|
||||||
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
|
||||||
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRunningCLI(ctx context.Context) bool {
|
|
||||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func addDBus(ctx context.Context, iface string) error {
|
|
||||||
conn, err := dbus.SystemBus()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
|
||||||
}
|
|
||||||
obj := conn.Object(dbusDest, dbusPath)
|
|
||||||
|
|
||||||
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
|
||||||
if call.Err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
|
||||||
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
|
||||||
if move.Err != nil {
|
|
||||||
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeDBus(ctx context.Context, iface string) error {
|
|
||||||
conn, err := dbus.SystemBus()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
|
||||||
}
|
|
||||||
obj := conn.Object(dbusDest, dbusPath)
|
|
||||||
|
|
||||||
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
|
||||||
if call.Err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addCLI(ctx context.Context, iface string) error {
|
|
||||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
|
||||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// --change-interface (no --permanent) binds the interface for the
|
|
||||||
// current runtime only; we do not want membership to persist across
|
|
||||||
// reboots because netbird re-asserts it on every startup.
|
|
||||||
out, err := exec.CommandContext(ctx,
|
|
||||||
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
|
||||||
).CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeCLI(ctx context.Context, iface string) error {
|
|
||||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
|
||||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := exec.CommandContext(ctx,
|
|
||||||
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
|
||||||
).CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
msg := strings.TrimSpace(string(out))
|
|
||||||
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func dbusErrContains(err error, code string) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var de dbus.Error
|
|
||||||
if errors.As(err, &de) {
|
|
||||||
for _, b := range de.Body {
|
|
||||||
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Contains(err.Error(), code)
|
|
||||||
}
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package firewalld
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDBusErrContains(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
code string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"nil error", nil, errZoneAlreadySet, false},
|
|
||||||
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
|
||||||
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
|
||||||
{
|
|
||||||
"dbus.Error body match",
|
|
||||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
|
||||||
errZoneAlreadySet,
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dbus.Error body miss",
|
|
||||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
|
||||||
errAlreadyEnabled,
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dbus.Error non-string body falls back to Error()",
|
|
||||||
dbus.Error{Name: "x", Body: []any{123}},
|
|
||||||
"x",
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got := dbusErrContains(tc.err, tc.code)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
|
|
||||||
package firewalld
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
|
||||||
// runs on Linux.
|
|
||||||
func SetParentContext(context.Context) {
|
|
||||||
// intentionally empty: firewalld is a Linux-only daemon
|
|
||||||
}
|
|
||||||
|
|
||||||
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
|
||||||
// runs on Linux.
|
|
||||||
func TrustInterface(string) error {
|
|
||||||
// intentionally empty: firewalld is a Linux-only daemon
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
|
||||||
// runs on Linux.
|
|
||||||
func UntrustInterface(string) error {
|
|
||||||
// intentionally empty: firewalld is a Linux-only daemon
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -87,12 +86,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trust after all fatal init steps so a later failure doesn't leave the
|
|
||||||
// interface in firewalld's trusted zone without a corresponding Close.
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
go func() {
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
@@ -198,12 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
|
||||||
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
|
||||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// attempt to delete state only if all other operations succeeded
|
// attempt to delete state only if all other operations succeeded
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
@@ -230,11 +217,6 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -218,10 +217,6 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
@@ -41,8 +40,6 @@ const (
|
|||||||
chainNameForward = "FORWARD"
|
chainNameForward = "FORWARD"
|
||||||
chainNameMangleForward = "netbird-mangle-forward"
|
chainNameMangleForward = "netbird-mangle-forward"
|
||||||
|
|
||||||
firewalldTableName = "firewalld"
|
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
userDataAcceptInputRule = "inputaccept"
|
userDataAcceptInputRule = "inputaccept"
|
||||||
@@ -136,10 +133,6 @@ func (r *router) Reset() error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeNatPreroutingRules(); err != nil {
|
if err := r.removeNatPreroutingRules(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||||
}
|
}
|
||||||
@@ -287,10 +280,6 @@ func (r *router) createContainers() error {
|
|||||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to refresh rules: %s", err)
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
}
|
}
|
||||||
@@ -1330,13 +1319,6 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
|
||||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
|
||||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
|
||||||
if chain.Table.Name == firewalldTableName {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip all iptables-managed tables in the ip family
|
// Skip all iptables-managed tables in the ip family
|
||||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -3,9 +3,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,9 +16,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
}
|
}
|
||||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -30,8 +24,5 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AllowNetbird()
|
return m.nativeFirewall.AllowNetbird()
|
||||||
}
|
}
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
Name() string
|
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
GetWGDevice() *wgdevice.Device
|
GetWGDevice() *wgdevice.Device
|
||||||
|
|||||||
@@ -31,20 +31,12 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
|||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
NameFunc func() string
|
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() wgaddr.Address
|
AddressFunc func() wgaddr.Address
|
||||||
GetWGDeviceFunc func() *wgdevice.Device
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) Name() string {
|
|
||||||
if i.NameFunc == nil {
|
|
||||||
return "wgtest"
|
|
||||||
}
|
|
||||||
return i.NameFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||||
if i.GetWGDeviceFunc == nil {
|
if i.GetWGDeviceFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -239,12 +239,8 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
|||||||
ipv6Count++
|
ipv6Count++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||||
// routing-correctness checks above are the real assertions; the counts
|
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||||
// are a sanity bound to catch a totally silent path.
|
|
||||||
minDelivered := packetsPerFamily * 80 / 100
|
|
||||||
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
|
||||||
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||||
|
|||||||
@@ -217,6 +217,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
|||||||
// Close closes the tunnel interface
|
// Close closes the tunnel interface
|
||||||
func (w *WGIface) Close() error {
|
func (w *WGIface) Close() error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
@@ -224,15 +225,7 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release w.mu before calling w.tun.Close(): the underlying
|
if err := w.tun.Close(); err != nil {
|
||||||
// wireguard-go device.Close() waits for its send/receive goroutines
|
|
||||||
// to drain. Some of those goroutines re-enter WGIface methods that
|
|
||||||
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
|
|
||||||
// holding the mutex here would deadlock the shutdown path.
|
|
||||||
tun := w.tun
|
|
||||||
w.mu.Unlock()
|
|
||||||
|
|
||||||
if err := tun.Close(); err != nil {
|
|
||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iface
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
// fakeTunDevice implements WGTunDevice and lets the test control when
|
|
||||||
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
|
|
||||||
// until its goroutines drain. Some of those goroutines (e.g. the packet
|
|
||||||
// filter DNS hook in client/internal/dns) call back into WGIface, so if
|
|
||||||
// WGIface.Close() held w.mu across tun.Close() the shutdown would
|
|
||||||
// deadlock.
|
|
||||||
type fakeTunDevice struct {
|
|
||||||
closeStarted chan struct{}
|
|
||||||
unblockClose chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
|
|
||||||
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
|
|
||||||
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
|
|
||||||
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
|
|
||||||
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
|
|
||||||
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
|
|
||||||
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
|
|
||||||
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
|
|
||||||
|
|
||||||
func (f *fakeTunDevice) Close() error {
|
|
||||||
close(f.closeStarted)
|
|
||||||
<-f.unblockClose
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeProxyFactory struct{}
|
|
||||||
|
|
||||||
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
|
|
||||||
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
|
|
||||||
func (fakeProxyFactory) Free() error { return nil }
|
|
||||||
|
|
||||||
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
|
|
||||||
// that surfaces as a macOS test-timeout in
|
|
||||||
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
|
|
||||||
// while waiting for the wireguard-go device goroutines to finish, and
|
|
||||||
// one of those goroutines (the DNS filter hook) calls back into
|
|
||||||
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
|
|
||||||
// the lock before tun.Close() returns control.
|
|
||||||
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
|
|
||||||
tun := &fakeTunDevice{
|
|
||||||
closeStarted: make(chan struct{}),
|
|
||||||
unblockClose: make(chan struct{}),
|
|
||||||
}
|
|
||||||
w := &WGIface{
|
|
||||||
tun: tun,
|
|
||||||
wgProxyFactory: fakeProxyFactory{},
|
|
||||||
}
|
|
||||||
|
|
||||||
closeDone := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
closeDone <- w.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-tun.closeStarted:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
close(tun.unblockClose)
|
|
||||||
t.Fatal("tun.Close() was never invoked")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate the WireGuard read goroutine calling back into WGIface
|
|
||||||
// via the packet filter's DNS hook. If Close() still held w.mu
|
|
||||||
// during tun.Close(), this would block until the test timeout.
|
|
||||||
getDeviceDone := make(chan struct{})
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
_ = w.GetDevice()
|
|
||||||
close(getDeviceDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-getDeviceDone:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
close(tun.unblockClose)
|
|
||||||
wg.Wait()
|
|
||||||
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
|
|
||||||
}
|
|
||||||
|
|
||||||
close(tun.unblockClose)
|
|
||||||
select {
|
|
||||||
case <-closeDone:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if u.address.Network.Contains(a) {
|
if u.address.Network.Contains(a) {
|
||||||
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
|||||||
u.addrCache.Store(addr.String(), isRouted)
|
u.addrCache.Store(addr.String(), isRouted)
|
||||||
if isRouted {
|
if isRouted {
|
||||||
// Extra log, as the error only shows up with ICE logging enabled
|
// Extra log, as the error only shows up with ICE logging enabled
|
||||||
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,7 +94,6 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
dnsAddresses []netip.AddrPort,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
stateFilePath string,
|
stateFilePath string,
|
||||||
cacheDir string,
|
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
@@ -104,7 +103,6 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
TempDir: cacheDir,
|
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil, "")
|
||||||
}
|
}
|
||||||
@@ -333,10 +331,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
c.statusRecorder.MarkSignalConnected()
|
c.statusRecorder.MarkSignalConnected()
|
||||||
|
|
||||||
relayURLs, token := parseRelayInfo(loginResp)
|
relayURLs, token := parseRelayInfo(loginResp)
|
||||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
|
||||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
|
||||||
relayURLs = override
|
|
||||||
}
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||||
@@ -344,7 +338,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
engineConfig.TempDir = mobileDependency.TempDir
|
|
||||||
|
|
||||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||||
c.statusRecorder.SetRelayMgr(relayManager)
|
c.statusRecorder.SetRelayMgr(relayManager)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -30,6 +31,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const readmeContent = `Netbird debug bundle
|
const readmeContent = `Netbird debug bundle
|
||||||
@@ -232,7 +234,6 @@ type BundleGenerator struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
syncResponse *mgmProto.SyncResponse
|
syncResponse *mgmProto.SyncResponse
|
||||||
logPath string
|
logPath string
|
||||||
tempDir string
|
|
||||||
cpuProfile []byte
|
cpuProfile []byte
|
||||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
clientMetrics MetricsExporter
|
clientMetrics MetricsExporter
|
||||||
@@ -255,7 +256,6 @@ type GeneratorDependencies struct {
|
|||||||
StatusRecorder *peer.Status
|
StatusRecorder *peer.Status
|
||||||
SyncResponse *mgmProto.SyncResponse
|
SyncResponse *mgmProto.SyncResponse
|
||||||
LogPath string
|
LogPath string
|
||||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
|
||||||
CPUProfile []byte
|
CPUProfile []byte
|
||||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
ClientMetrics MetricsExporter
|
ClientMetrics MetricsExporter
|
||||||
@@ -275,7 +275,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
statusRecorder: deps.StatusRecorder,
|
statusRecorder: deps.StatusRecorder,
|
||||||
syncResponse: deps.SyncResponse,
|
syncResponse: deps.SyncResponse,
|
||||||
logPath: deps.LogPath,
|
logPath: deps.LogPath,
|
||||||
tempDir: deps.TempDir,
|
|
||||||
cpuProfile: deps.CPUProfile,
|
cpuProfile: deps.CPUProfile,
|
||||||
refreshStatus: deps.RefreshStatus,
|
refreshStatus: deps.RefreshStatus,
|
||||||
clientMetrics: deps.ClientMetrics,
|
clientMetrics: deps.ClientMetrics,
|
||||||
@@ -288,7 +287,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
|
|
||||||
// Generate creates a debug bundle and returns the location.
|
// Generate creates a debug bundle and returns the location.
|
||||||
func (g *BundleGenerator) Generate() (resp string, err error) {
|
func (g *BundleGenerator) Generate() (resp string, err error) {
|
||||||
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
|
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("create zip file: %w", err)
|
return "", fmt.Errorf("create zip file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -374,8 +373,15 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add wg show output: %v", err)
|
log.Errorf("failed to add wg show output: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addPlatformLog(); err != nil {
|
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||||
log.Errorf("failed to add logs to debug bundle: %v", err)
|
if err := g.addLogfile(); err != nil {
|
||||||
|
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||||
|
if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
log.Errorf("failed to add systemd logs as fallback: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
log.Errorf("failed to add systemd logs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addUpdateLogs(); err != nil {
|
if err := g.addUpdateLogs(); err != nil {
|
||||||
|
|||||||
@@ -1,41 +0,0 @@
|
|||||||
//go:build android
|
|
||||||
|
|
||||||
package debug
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os/exec"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (g *BundleGenerator) addPlatformLog() error {
|
|
||||||
cmd := exec.Command("/system/bin/logcat", "-d")
|
|
||||||
stdout, err := cmd.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("logcat stdout pipe: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
return fmt.Errorf("start logcat: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var logReader io.Reader = stdout
|
|
||||||
if g.anonymize {
|
|
||||||
var pw *io.PipeWriter
|
|
||||||
logReader, pw = io.Pipe()
|
|
||||||
go anonymizeLog(stdout, pw, g.anonymizer)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
|
|
||||||
return fmt.Errorf("add logcat to zip: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cmd.Wait(); err != nil {
|
|
||||||
return fmt.Errorf("wait logcat: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("added logcat output to debug bundle")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package debug
|
|
||||||
|
|
||||||
import (
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (g *BundleGenerator) addPlatformLog() error {
|
|
||||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
|
||||||
if err := g.addLogfile(); err != nil {
|
|
||||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
|
||||||
if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -3,12 +3,10 @@ package debug
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -21,10 +19,8 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Skip("Skipping upload test on docker ci")
|
t.Skip("Skipping upload test on docker ci")
|
||||||
}
|
}
|
||||||
testDir := t.TempDir()
|
testDir := t.TempDir()
|
||||||
addr := reserveLoopbackPort(t)
|
testURL := "http://localhost:8080"
|
||||||
testURL := "http://" + addr
|
|
||||||
t.Setenv("SERVER_URL", testURL)
|
t.Setenv("SERVER_URL", testURL)
|
||||||
t.Setenv("SERVER_ADDRESS", addr)
|
|
||||||
t.Setenv("STORE_DIR", testDir)
|
t.Setenv("STORE_DIR", testDir)
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@@ -37,7 +33,6 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Errorf("Failed to stop server: %v", err)
|
t.Errorf("Failed to stop server: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
waitForServer(t, addr)
|
|
||||||
|
|
||||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||||
fileContent := []byte("test file content")
|
fileContent := []byte("test file content")
|
||||||
@@ -52,30 +47,3 @@ func TestUpload(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, fileContent, createdFileContent)
|
require.Equal(t, fileContent, createdFileContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
|
||||||
// address, then releases it so the server under test can rebind. The close/
|
|
||||||
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
|
||||||
// it's essentially never contended in practice.
|
|
||||||
func reserveLoopbackPort(t *testing.T) string {
|
|
||||||
t.Helper()
|
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
addr := l.Addr().String()
|
|
||||||
require.NoError(t, l.Close())
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForServer(t *testing.T, addr string) {
|
|
||||||
t.Helper()
|
|
||||||
deadline := time.Now().Add(5 * time.Second)
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
|
||||||
if err == nil {
|
|
||||||
_ = c.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
}
|
|
||||||
t.Fatalf("server did not start listening on %s in time", addr)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultResolvConfPath = "/etc/resolv.conf"
|
defaultResolvConfPath = "/etc/resolv.conf"
|
||||||
nsswitchConfPath = "/etc/nsswitch.conf"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type resolvConf struct {
|
type resolvConf struct {
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
|
||||||
"net"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
c.dispatch(w, r, math.MaxInt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dispatch routes a DNS request through the chain, skipping handlers with
|
|
||||||
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
|
||||||
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
|
|||||||
|
|
||||||
// Try handlers in priority order
|
// Try handlers in priority order
|
||||||
for _, entry := range handlers {
|
for _, entry := range handlers {
|
||||||
if entry.Priority > maxPriority {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !c.isHandlerMatch(qname, entry) {
|
if !c.isHandlerMatch(qname, entry) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
|||||||
cw.response.Len(), meta, time.Since(startTime))
|
cw.response.Len(), meta, time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
|
||||||
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
|
||||||
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
|
||||||
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
|
||||||
// (bounded by the invoked handler's internal timeout).
|
|
||||||
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
|
||||||
if len(r.Question) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty question")
|
|
||||||
}
|
|
||||||
|
|
||||||
base := &internalResponseWriter{}
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
c.dispatch(base, r, maxPriority)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-ctx.Done():
|
|
||||||
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
|
||||||
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
|
||||||
strings.ToLower(r.Question[0].Name), maxPriority)
|
|
||||||
}
|
|
||||||
return base.response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
|
||||||
// priority ≤ maxPriority.
|
|
||||||
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
for _, h := range c.handlers {
|
|
||||||
if h.Pattern == "." && h.Priority <= maxPriority {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
switch {
|
switch {
|
||||||
case entry.Pattern == ".":
|
case entry.Pattern == ".":
|
||||||
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
|
||||||
type internalResponseWriter struct {
|
|
||||||
response *dns.Msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
|
||||||
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
|
||||||
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
|
||||||
|
|
||||||
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
|
||||||
// still surface their answer to ResolveInternal.
|
|
||||||
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(p); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
w.response = msg
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *internalResponseWriter) Close() error { return nil }
|
|
||||||
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
|
||||||
|
|
||||||
// TsigTimersOnly is part of dns.ResponseWriter.
|
|
||||||
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
|
||||||
// no-op: in-process queries carry no TSIG state.
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hijack is part of dns.ResponseWriter.
|
|
||||||
func (w *internalResponseWriter) Hijack() {
|
|
||||||
// no-op: in-process queries have no underlying connection to hand off.
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,15 +1,11 @@
|
|||||||
package dns_test
|
package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
|
||||||
// which handler ResolveInternal dispatches to.
|
|
||||||
type answeringHandler struct {
|
|
||||||
name string
|
|
||||||
ip string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(r)
|
|
||||||
resp.Answer = []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP(h.ip).To4(),
|
|
||||||
}}
|
|
||||||
_ = w.WriteMsg(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *answeringHandler) String() string { return h.name }
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
|
|
||||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
|
||||||
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
|
||||||
|
|
||||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
|
||||||
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp)
|
|
||||||
assert.Equal(t, 1, len(resp.Answer))
|
|
||||||
a, ok := resp.Answer[0].(*dns.A)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
|
||||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
|
||||||
assert.Error(t, err, "no handler at or below maxPriority should error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
|
||||||
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
|
||||||
type rawWriteHandler struct {
|
|
||||||
ip string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(r)
|
|
||||||
resp.Answer = []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP(h.ip).To4(),
|
|
||||||
}}
|
|
||||||
packed, err := resp.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, _ = w.Write(packed)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
require.Len(t, resp.Answer, 1)
|
|
||||||
a, ok := resp.Answer[0].(*dns.A)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
|
||||||
type hangingHandler struct {
|
|
||||||
block chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
<-h.block
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(r)
|
|
||||||
_ = w.WriteMsg(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hangingHandler) String() string { return "hangingHandler" }
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
h := &hangingHandler{block: make(chan struct{})}
|
|
||||||
defer close(h.block)
|
|
||||||
|
|
||||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
||||||
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
|
||||||
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
|
||||||
|
|
||||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
|
||||||
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
|
||||||
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
|
||||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
|
||||||
|
|
||||||
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
|
||||||
|
|
||||||
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
|
||||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
|
||||||
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
|
||||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
|
||||||
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, reason, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
log.Infof("System DNS manager discovered: %s", osManager)
|
||||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOSDNSManagerType() (osManagerType, string, error) {
|
func getOSDNSManagerType() (osManagerType, error) {
|
||||||
resolved := isSystemdResolvedRunning()
|
|
||||||
nss := isLibnssResolveUsed()
|
|
||||||
stub := checkStub()
|
|
||||||
|
|
||||||
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
|
||||||
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
|
||||||
// that go through nss-resolve, and in foreign mode they can loop back
|
|
||||||
// through resolved as an upstream.
|
|
||||||
if resolved && (nss || stub) {
|
|
||||||
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mgr, reason, rejected, err := scanResolvConfHeader()
|
|
||||||
if err != nil {
|
|
||||||
return 0, "", err
|
|
||||||
}
|
|
||||||
if reason != "" {
|
|
||||||
return mgr, reason, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
|
||||||
if len(rejected) > 0 {
|
|
||||||
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
|
||||||
}
|
|
||||||
return fileManager, fallback, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
|
||||||
// matching manager. If reason is empty the caller should pick file mode and
|
|
||||||
// use rejected for diagnostics.
|
|
||||||
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
|
||||||
file, err := os.Open(defaultResolvConfPath)
|
file, err := os.Open(defaultResolvConfPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if cerr := file.Close(); cerr != nil {
|
if err := file.Close(); err != nil {
|
||||||
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rejected []string
|
|
||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
text := scanner.Text()
|
text := scanner.Text()
|
||||||
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if text[0] != '#' {
|
if text[0] != '#' {
|
||||||
break
|
return fileManager, nil
|
||||||
}
|
}
|
||||||
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||||
return mgr, reason, nil, nil
|
return netbirdManager, nil
|
||||||
} else if rej != "" {
|
}
|
||||||
rejected = append(rejected, rej)
|
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||||
|
return networkManager, nil
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||||
|
if checkStub() {
|
||||||
|
return systemdManager, nil
|
||||||
|
} else {
|
||||||
|
return fileManager, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "resolvconf") {
|
||||||
|
if isSystemdResolveConfMode() {
|
||||||
|
return systemdManager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolvConfManager, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||||
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
return 0, fmt.Errorf("scan: %w", err)
|
||||||
}
|
}
|
||||||
return 0, "", rejected, nil
|
|
||||||
|
return fileManager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchResolvConfHeader inspects a single comment line. Returns either a
|
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||||
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
|
||||||
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
|
||||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
|
||||||
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "NetworkManager") {
|
|
||||||
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
|
||||||
return networkManager, "NetworkManager header + supported version on dbus", ""
|
|
||||||
}
|
|
||||||
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "resolvconf") {
|
|
||||||
if isSystemdResolveConfMode() {
|
|
||||||
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
|
||||||
}
|
|
||||||
return resolvConfManager, "resolvconf header detected", ""
|
|
||||||
}
|
|
||||||
return 0, "", ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
|
||||||
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
|
||||||
// into file mode while resolved is active.
|
|
||||||
func checkStub() bool {
|
func checkStub() bool {
|
||||||
rConf, err := parseDefaultResolvConf()
|
rConf, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
log.Warnf("failed to parse resolv conf: %s", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,36 +139,3 @@ func checkStub() bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
|
||||||
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
|
||||||
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
|
||||||
func isLibnssResolveUsed() bool {
|
|
||||||
bs, err := os.ReadFile(nsswitchConfPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return parseNsswitchResolveAhead(bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseNsswitchResolveAhead(data []byte) bool {
|
|
||||||
for _, line := range strings.Split(string(data), "\n") {
|
|
||||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
|
||||||
line = line[:i]
|
|
||||||
}
|
|
||||||
fields := strings.Fields(line)
|
|
||||||
if len(fields) < 2 || fields[0] != "hosts:" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, module := range fields[1:] {
|
|
||||||
switch module {
|
|
||||||
case "dns":
|
|
||||||
return false
|
|
||||||
case "resolve":
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,76 +0,0 @@
|
|||||||
//go:build (linux && !android) || freebsd
|
|
||||||
|
|
||||||
package dns
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestParseNsswitchResolveAhead(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "resolve before dns with action token",
|
|
||||||
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "dns before resolve",
|
|
||||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "debian default with only dns",
|
|
||||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "neither resolve nor dns",
|
|
||||||
in: "hosts: files myhostname\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no hosts line",
|
|
||||||
in: "passwd: files systemd\ngroup: files systemd\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty",
|
|
||||||
in: "",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "comments and blank lines ignored",
|
|
||||||
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trailing inline comment",
|
|
||||||
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hosts token must be the first field",
|
|
||||||
in: " hosts: resolve dns\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "other db line mentioning resolve is ignored",
|
|
||||||
in: "networks: resolve\nhosts: dns\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "only resolve, no dns",
|
|
||||||
in: "hosts: files resolve\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
|
||||||
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,83 +2,40 @@ package mgmt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sync/singleflight"
|
|
||||||
|
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const dnsTimeout = 5 * time.Second
|
||||||
dnsTimeout = 5 * time.Second
|
|
||||||
defaultTTL = 300 * time.Second
|
|
||||||
refreshBackoff = 30 * time.Second
|
|
||||||
|
|
||||||
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
// Resolver caches critical NetBird infrastructure domains
|
||||||
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
|
||||||
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
|
||||||
// system resolver.
|
|
||||||
type ChainResolver interface {
|
|
||||||
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
|
||||||
HasRootHandlerAtOrBelow(maxPriority int) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
|
||||||
// records and cachedAt are set at construction and treated as immutable;
|
|
||||||
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
|
||||||
// Resolver.mutex.
|
|
||||||
type cachedRecord struct {
|
|
||||||
records []dns.RR
|
|
||||||
cachedAt time.Time
|
|
||||||
lastFailedRefresh time.Time
|
|
||||||
consecFailures int
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolver caches critical NetBird infrastructure domains.
|
|
||||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
records map[dns.Question]*cachedRecord
|
records map[dns.Question][]dns.RR
|
||||||
mgmtDomain *domain.Domain
|
mgmtDomain *domain.Domain
|
||||||
serverDomains *dnsconfig.ServerDomains
|
serverDomains *dnsconfig.ServerDomains
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
chain ChainResolver
|
type ipsResponse struct {
|
||||||
chainMaxPriority int
|
ips []netip.Addr
|
||||||
refreshGroup singleflight.Group
|
err error
|
||||||
|
|
||||||
// refreshing tracks questions whose refresh is running via the OS
|
|
||||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
|
||||||
// the OS resolver routed the recursive query back to us (loop). Only
|
|
||||||
// the OS path arms this so chain-path refreshes don't produce false
|
|
||||||
// positives. The atomic bool is CAS-flipped once per refresh to
|
|
||||||
// throttle the warning log.
|
|
||||||
refreshing map[dns.Question]*atomic.Bool
|
|
||||||
|
|
||||||
cacheTTL time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
records: make(map[dns.Question]*cachedRecord),
|
records: make(map[dns.Question][]dns.RR),
|
||||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
|
||||||
cacheTTL: resolveCacheTTL(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
|
|||||||
return "MgmtCacheResolver"
|
return "MgmtCacheResolver"
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
// ServeDNS implements dns.Handler interface.
|
||||||
// maxPriority caps which handlers may answer refresh queries (typically
|
|
||||||
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
|
||||||
// mgmt/route/local handlers are skipped).
|
|
||||||
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
m.chain = chain
|
|
||||||
m.chainMaxPriority = maxPriority
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
|
||||||
// immediately and refreshed asynchronously (stale-while-revalidate).
|
|
||||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
cached, found := m.records[question]
|
records, found := m.records[question]
|
||||||
inflight := m.refreshing[question]
|
|
||||||
var shouldRefresh bool
|
|
||||||
if found {
|
|
||||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
|
||||||
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
|
||||||
shouldRefresh = stale && !inBackoff
|
|
||||||
}
|
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
|
||||||
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
|
||||||
question.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip scheduling a refresh goroutine if one is already inflight for
|
|
||||||
// this question; singleflight would dedup anyway but skipping avoids
|
|
||||||
// a parked goroutine per stale hit under bursty load.
|
|
||||||
if shouldRefresh && inflight == nil {
|
|
||||||
m.scheduleRefresh(question, cached)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &dns.Msg{}
|
resp := &dns.Msg{}
|
||||||
resp.SetReply(r)
|
resp.SetReply(r)
|
||||||
resp.Authoritative = false
|
resp.Authoritative = false
|
||||||
resp.RecursionAvailable = true
|
resp.RecursionAvailable = true
|
||||||
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
|
||||||
|
resp.Answer = append(resp.Answer, records...)
|
||||||
|
|
||||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||||
|
|
||||||
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
// AddDomain manually adds a domain to cache by resolving it.
|
||||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
|
||||||
// entry for that qtype.
|
|
||||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||||
|
if err != nil {
|
||||||
if errA != nil && errAAAA != nil {
|
return err
|
||||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
var aRecords, aaaaRecords []dns.RR
|
||||||
if err := errors.Join(errA, errAAAA); err != nil {
|
for _, ip := range ips {
|
||||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
if ip.Is4() {
|
||||||
|
rr := &dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: dnsName,
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
A: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aRecords = append(aRecords, rr)
|
||||||
|
} else if ip.Is6() {
|
||||||
|
rr := &dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: dnsName,
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
AAAA: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aaaaRecords = append(aaaaRecords, rr)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
if len(aRecords) > 0 {
|
||||||
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
aQuestion := dns.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
m.records[aQuestion] = aRecords
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
if len(aaaaRecords) > 0 {
|
||||||
|
aaaaQuestion := dns.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
m.records[aaaaQuestion] = aaaaRecords
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||||
// untouched on error. Caller holds m.mutex.
|
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
resultChan := make(chan *ipsResponse, 1)
|
||||||
switch {
|
|
||||||
case len(records) > 0:
|
|
||||||
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
|
||||||
case err == nil:
|
|
||||||
delete(m.records, q)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
go func() {
|
||||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
resultChan <- &ipsResponse{
|
||||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
err: err,
|
||||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
ips: ips,
|
||||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
|
||||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
|
||||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
|
||||||
return nil, m.refreshQuestion(question, expected)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// refreshQuestion replaces the cached records on success, or marks the entry
|
|
||||||
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
|
||||||
// a resolver loop by spotting a query for this same question arriving on us.
|
|
||||||
// expected pins the cache entry observed at schedule time; mutations only apply
|
|
||||||
// if m.records[question] still points at it.
|
|
||||||
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
|
||||||
if err != nil {
|
|
||||||
m.markRefreshFailed(question, expected)
|
|
||||||
return fmt.Errorf("parse domain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
records, err := m.lookupRecords(ctx, d, question)
|
|
||||||
if err != nil {
|
|
||||||
fails := m.markRefreshFailed(question, expected)
|
|
||||||
logf := log.Warnf
|
|
||||||
if fails == 0 || fails > 1 {
|
|
||||||
logf = log.Debugf
|
|
||||||
}
|
}
|
||||||
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
}()
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
|
||||||
return err
|
var resp *ipsResponse
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||||
|
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||||
|
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case resp = <-resultChan:
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
if resp.err != nil {
|
||||||
if len(records) == 0 {
|
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||||
m.mutex.Lock()
|
|
||||||
if m.records[question] == expected {
|
|
||||||
delete(m.records, question)
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
return resp.ips, nil
|
||||||
now := time.Now()
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.records[question] != expected {
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) markRefreshing(question dns.Question) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
m.refreshing[question] = &atomic.Bool{}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) clearRefreshing(question dns.Question) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
delete(m.refreshing, question)
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
|
||||||
// count so callers can downgrade subsequent failure logs to debug.
|
|
||||||
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
c, ok := m.records[question]
|
|
||||||
if !ok || c != expected {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
c.lastFailedRefresh = time.Now()
|
|
||||||
c.consecFailures++
|
|
||||||
return c.consecFailures
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
|
||||||
// callers tell records, NODATA (nil err, no records), and failure apart.
|
|
||||||
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
|
||||||
m.mutex.RLock()
|
|
||||||
chain := m.chain
|
|
||||||
maxPriority := m.chainMaxPriority
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
|
||||||
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
|
||||||
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
|
||||||
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
|
||||||
// not the system resolver, so net.DefaultResolver will not loop back.
|
|
||||||
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
|
||||||
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
|
||||||
// arms the loop detector for the duration of its call so that ServeDNS can
|
|
||||||
// spot the OS resolver routing the recursive query back to us.
|
|
||||||
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
|
||||||
m.mutex.RLock()
|
|
||||||
chain := m.chain
|
|
||||||
maxPriority := m.chainMaxPriority
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
|
||||||
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: drop once every supported OS registers a fallback resolver.
|
|
||||||
m.markRefreshing(q)
|
|
||||||
defer m.clearRefreshing(q)
|
|
||||||
|
|
||||||
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
|
||||||
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
|
||||||
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
|
||||||
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
|
||||||
msg := &dns.Msg{}
|
|
||||||
msg.SetQuestion(dnsName, qtype)
|
|
||||||
msg.RecursionDesired = true
|
|
||||||
|
|
||||||
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("chain resolve: %w", err)
|
|
||||||
}
|
|
||||||
if resp == nil {
|
|
||||||
return nil, fmt.Errorf("chain resolve returned nil response")
|
|
||||||
}
|
|
||||||
if resp.Rcode != dns.RcodeSuccess {
|
|
||||||
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
|
||||||
}
|
|
||||||
|
|
||||||
ttl := uint32(m.cacheTTL.Seconds())
|
|
||||||
owners := cnameOwners(dnsName, resp.Answer)
|
|
||||||
var filtered []dns.RR
|
|
||||||
for _, rr := range resp.Answer {
|
|
||||||
h := rr.Header()
|
|
||||||
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
|
||||||
filtered = append(filtered, cp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return filtered, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
|
||||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
|
||||||
// returns (nil, nil).
|
|
||||||
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
|
||||||
network := resutil.NetworkForQtype(qtype)
|
|
||||||
if network == "" {
|
|
||||||
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
|
||||||
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
|
||||||
if result.Rcode == dns.RcodeSuccess {
|
|
||||||
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Err != nil {
|
|
||||||
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
|
||||||
}
|
|
||||||
|
|
||||||
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
|
||||||
// so downstream resolvers don't cache an answer for longer than we will.
|
|
||||||
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
|
||||||
remaining := m.cacheTTL - time.Since(cachedAt)
|
|
||||||
if remaining <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return uint32((remaining + time.Second - 1) / time.Second)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||||
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
aQuestion := dns.Question{
|
||||||
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
Name: dnsName,
|
||||||
delete(m.records, qA)
|
Qtype: dns.TypeA,
|
||||||
delete(m.records, qAAAA)
|
Qclass: dns.ClassINET,
|
||||||
delete(m.refreshing, qA)
|
}
|
||||||
delete(m.refreshing, qAAAA)
|
delete(m.records, aQuestion)
|
||||||
|
|
||||||
|
aaaaQuestion := dns.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
delete(m.records, aaaaQuestion)
|
||||||
|
|
||||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||||
return nil
|
return nil
|
||||||
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
|||||||
|
|
||||||
return domains
|
return domains
|
||||||
}
|
}
|
||||||
|
|
||||||
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
|
||||||
// A/AAAA records return nil.
|
|
||||||
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
|
||||||
switch r := rr.(type) {
|
|
||||||
case *dns.A:
|
|
||||||
cp := *r
|
|
||||||
cp.Hdr.Name = owner
|
|
||||||
cp.Hdr.Ttl = ttl
|
|
||||||
cp.A = slices.Clone(r.A)
|
|
||||||
return &cp
|
|
||||||
case *dns.AAAA:
|
|
||||||
cp := *r
|
|
||||||
cp.Hdr.Name = owner
|
|
||||||
cp.Hdr.Ttl = ttl
|
|
||||||
cp.AAAA = slices.Clone(r.AAAA)
|
|
||||||
return &cp
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
|
||||||
// stamping ttl so the response shares no memory with the cached slice.
|
|
||||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
|
||||||
out := make([]dns.RR, 0, len(records))
|
|
||||||
for _, rr := range records {
|
|
||||||
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
|
||||||
out = append(out, cp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
|
||||||
// in answer, iterating until fixed point so out-of-order chains resolve.
|
|
||||||
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
|
||||||
owners := map[string]bool{dnsName: true}
|
|
||||||
for {
|
|
||||||
added := false
|
|
||||||
for _, rr := range answer {
|
|
||||||
cname, ok := rr.(*dns.CNAME)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
|
||||||
if !owners[name] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target := strings.ToLower(dns.Fqdn(cname.Target))
|
|
||||||
if !owners[target] {
|
|
||||||
owners[target] = true
|
|
||||||
added = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !added {
|
|
||||||
return owners
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
|
||||||
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
|
||||||
func resolveCacheTTL() time.Duration {
|
|
||||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
|
||||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return defaultTTL
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,408 +0,0 @@
|
|||||||
package mgmt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type fakeChain struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
calls map[string]int
|
|
||||||
answers map[string][]dns.RR
|
|
||||||
err error
|
|
||||||
hasRoot bool
|
|
||||||
onLookup func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFakeChain() *fakeChain {
|
|
||||||
return &fakeChain{
|
|
||||||
calls: map[string]int{},
|
|
||||||
answers: map[string][]dns.RR{},
|
|
||||||
hasRoot: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
return f.hasRoot
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
|
||||||
f.mu.Lock()
|
|
||||||
q := msg.Question[0]
|
|
||||||
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
|
||||||
f.calls[key]++
|
|
||||||
answers := f.answers[key]
|
|
||||||
err := f.err
|
|
||||||
onLookup := f.onLookup
|
|
||||||
f.mu.Unlock()
|
|
||||||
|
|
||||||
if onLookup != nil {
|
|
||||||
onLookup()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(msg)
|
|
||||||
resp.Answer = answers
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
key := name + "|" + dns.TypeToString[qtype]
|
|
||||||
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
|
||||||
switch qtype {
|
|
||||||
case dns.TypeA:
|
|
||||||
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
|
||||||
case dns.TypeAAAA:
|
|
||||||
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitFor polls the predicate until it returns true or the deadline passes.
|
|
||||||
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
|
|
||||||
t.Helper()
|
|
||||||
deadline := time.Now().Add(d)
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
if fn() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Millisecond)
|
|
||||||
}
|
|
||||||
t.Fatalf("condition not met within %s", d)
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
|
||||||
t.Helper()
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
msg.SetQuestion(name, dns.TypeA)
|
|
||||||
w := &test.MockResponseWriter{}
|
|
||||||
r.ServeDNS(w, msg)
|
|
||||||
return w.GetLastResponse()
|
|
||||||
}
|
|
||||||
|
|
||||||
func firstA(t *testing.T, resp *dns.Msg) string {
|
|
||||||
t.Helper()
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
|
||||||
a, ok := resp.Answer[0].(*dns.A)
|
|
||||||
require.True(t, ok, "expected A record")
|
|
||||||
return a.A.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
|
||||||
// Same cached entry age, different cacheTTL values: the shorter TTL must
|
|
||||||
// trigger a background refresh, the longer one must not. Proves that the
|
|
||||||
// per-Resolver cacheTTL field actually drives the stale decision.
|
|
||||||
cachedAt := time.Now().Add(-100 * time.Millisecond)
|
|
||||||
|
|
||||||
newRec := func() *cachedRecord {
|
|
||||||
return &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: cachedAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
|
|
||||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
r.cacheTTL = 10 * time.Millisecond
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
r.records[q] = newRec()
|
|
||||||
|
|
||||||
resp := queryA(t, r, q.Name)
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
|
||||||
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
return chain.callCount(q.Name, dns.TypeA) >= 1
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
r.cacheTTL = time.Hour
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
r.records[q] = newRec()
|
|
||||||
|
|
||||||
resp := queryA(t, r, q.Name)
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(), // fresh
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
|
||||||
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
|
||||||
}
|
|
||||||
|
|
||||||
// First query: serves stale immediately.
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
|
||||||
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
|
||||||
})
|
|
||||||
|
|
||||||
// Next query should now return the refreshed IP.
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
|
|
||||||
var inflight atomic.Int32
|
|
||||||
var maxInflight atomic.Int32
|
|
||||||
chain.onLookup = func() {
|
|
||||||
cur := inflight.Add(1)
|
|
||||||
defer inflight.Add(-1)
|
|
||||||
for {
|
|
||||||
prev := maxInflight.Load()
|
|
||||||
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
|
||||||
}
|
|
||||||
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 50; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
waitFor(t, 2*time.Second, func() bool {
|
|
||||||
return inflight.Load() == 0
|
|
||||||
})
|
|
||||||
|
|
||||||
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
|
||||||
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
|
||||||
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.err = errors.New("boom")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
|
||||||
}
|
|
||||||
|
|
||||||
// First stale hit triggers a refresh attempt that fails.
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
|
||||||
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
|
||||||
})
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
r.mutex.RLock()
|
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
c, ok := r.records[q]
|
|
||||||
return ok && !c.lastFailedRefresh.IsZero()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
}
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.hasRoot = false
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
// With hasRoot=false the chain must not be consulted. Use a short
|
|
||||||
// deadline so the OS fallback returns quickly without waiting on a
|
|
||||||
// real network call in CI.
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
|
||||||
|
|
||||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
|
||||||
"chain must not be used when no root handler is registered at the bound priority")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
|
||||||
// ServeDNS being invoked for a question while a refresh for that question
|
|
||||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
|
||||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
|
||||||
r := NewResolver()
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate an inflight refresh.
|
|
||||||
r.markRefreshing(q)
|
|
||||||
defer r.clearRefreshing(q)
|
|
||||||
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
|
||||||
|
|
||||||
r.mutex.RLock()
|
|
||||||
inflight := r.refreshing[q]
|
|
||||||
r.mutex.RUnlock()
|
|
||||||
require.NotNil(t, inflight)
|
|
||||||
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
r.markRefreshing(q)
|
|
||||||
defer r.clearRefreshing(q)
|
|
||||||
|
|
||||||
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
|
||||||
// (CompareAndSwap from false -> true returns true only on the first call).
|
|
||||||
for range 5 {
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mutex.RLock()
|
|
||||||
inflight := r.refreshing[q]
|
|
||||||
r.mutex.RUnlock()
|
|
||||||
assert.True(t, inflight.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
|
|
||||||
r.mutex.RLock()
|
|
||||||
_, ok := r.refreshing[q]
|
|
||||||
r.mutex.RUnlock()
|
|
||||||
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
|
||||||
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
|
||||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
|
||||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
|
||||||
}
|
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
|
|||||||
assert.False(t, resolver.MatchSubdomains())
|
assert.False(t, resolver.MatchSubdomains())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveCacheTTL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
value string
|
|
||||||
want time.Duration
|
|
||||||
}{
|
|
||||||
{"unset falls back to default", "", defaultTTL},
|
|
||||||
{"valid duration", "45s", 45 * time.Second},
|
|
||||||
{"valid minutes", "2m", 2 * time.Minute},
|
|
||||||
{"malformed falls back to default", "not-a-duration", defaultTTL},
|
|
||||||
{"zero falls back to default", "0s", defaultTTL},
|
|
||||||
{"negative falls back to default", "-5s", defaultTTL},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Setenv(envMgmtCacheTTL, tc.value)
|
|
||||||
got := resolveCacheTTL()
|
|
||||||
assert.Equal(t, tc.want, got, "parsed TTL should match")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
|
||||||
t.Setenv(envMgmtCacheTTL, "7s")
|
|
||||||
r := NewResolver()
|
|
||||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ResponseTTL(t *testing.T) {
|
|
||||||
now := time.Now()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
cacheTTL time.Duration
|
|
||||||
cachedAt time.Time
|
|
||||||
wantMin uint32
|
|
||||||
wantMax uint32
|
|
||||||
}{
|
|
||||||
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
|
|
||||||
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
|
|
||||||
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
|
|
||||||
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
r := &Resolver{cacheTTL: tc.cacheTTL}
|
|
||||||
got := r.responseTTL(tc.cachedAt)
|
|
||||||
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
|
|
||||||
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -212,7 +212,6 @@ func newDefaultServer(
|
|||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
mgmtCacheResolver := mgmt.NewResolver()
|
mgmtCacheResolver := mgmt.NewResolver()
|
||||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -141,7 +140,6 @@ type EngineConfig struct {
|
|||||||
ProfileConfig *profilemanager.Config
|
ProfileConfig *profilemanager.Config
|
||||||
|
|
||||||
LogPath string
|
LogPath string
|
||||||
TempDir string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// EngineServices holds the external service dependencies required by the Engine.
|
// EngineServices holds the external service dependencies required by the Engine.
|
||||||
@@ -571,7 +569,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.connMgr.Start(e.ctx)
|
e.connMgr.Start(e.ctx)
|
||||||
|
|
||||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||||
e.srWatcher.Start(peer.IsForceRelayed())
|
e.srWatcher.Start()
|
||||||
|
|
||||||
e.receiveSignalEvents()
|
e.receiveSignalEvents()
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
@@ -605,8 +603,6 @@ func (e *Engine) createFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
firewalld.SetParentContext(e.ctx)
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -944,12 +940,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
|||||||
return fmt.Errorf("update relay token: %w", err)
|
return fmt.Errorf("update relay token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
urls := update.Urls
|
e.relayManager.UpdateServerURLs(update.Urls)
|
||||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
|
||||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
|
||||||
urls = override
|
|
||||||
}
|
|
||||||
e.relayManager.UpdateServerURLs(urls)
|
|
||||||
|
|
||||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||||
@@ -1104,7 +1095,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
|||||||
StatusRecorder: e.statusRecorder,
|
StatusRecorder: e.statusRecorder,
|
||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: e.config.LogPath,
|
LogPath: e.config.LogPath,
|
||||||
TempDir: e.config.TempDir,
|
|
||||||
ClientMetrics: e.clientMetrics,
|
ClientMetrics: e.clientMetrics,
|
||||||
RefreshStatus: func() {
|
RefreshStatus: func() {
|
||||||
e.RunHealthProbes(true)
|
e.RunHealthProbes(true)
|
||||||
|
|||||||
@@ -22,8 +22,4 @@ type MobileDependency struct {
|
|||||||
DnsManager dns.IosDnsManager
|
DnsManager dns.IosDnsManager
|
||||||
FileDescriptor int32
|
FileDescriptor int32
|
||||||
StateFilePath string
|
StateFilePath string
|
||||||
|
|
||||||
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
|
|
||||||
// On Android, this should be set to the app's cache directory.
|
|
||||||
TempDir string
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -185,20 +185,17 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
|
|
||||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||||
|
|
||||||
forceRelay := IsForceRelayed()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
if !forceRelay {
|
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
if err != nil {
|
||||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
return err
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn.workerICE = workerICE
|
|
||||||
}
|
}
|
||||||
|
conn.workerICE = workerICE
|
||||||
|
|
||||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
||||||
|
|
||||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||||
if !forceRelay {
|
if !isForceRelayed() {
|
||||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,9 +251,7 @@ func (conn *Conn) Close(signalToRemote bool) {
|
|||||||
conn.wgWatcherCancel()
|
conn.wgWatcherCancel()
|
||||||
}
|
}
|
||||||
conn.workerRelay.CloseConn()
|
conn.workerRelay.CloseConn()
|
||||||
if conn.workerICE != nil {
|
conn.workerICE.Close()
|
||||||
conn.workerICE.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
err := conn.wgProxyRelay.CloseConn()
|
err := conn.wgProxyRelay.CloseConn()
|
||||||
@@ -299,9 +294,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
|
|||||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||||
conn.dumpState.RemoteCandidate()
|
conn.dumpState.RemoteCandidate()
|
||||||
if conn.workerICE != nil {
|
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||||
@@ -719,35 +712,33 @@ func (conn *Conn) evalStatus() ConnStatus {
|
|||||||
return StatusConnecting
|
return StatusConnecting
|
||||||
}
|
}
|
||||||
|
|
||||||
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
|
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||||
//
|
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||||
// The result is a tri-state:
|
|
||||||
// - ConnStatusConnected: all available transports are up
|
|
||||||
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
|
|
||||||
// - ConnStatusDisconnected: no working transport
|
|
||||||
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if status == guard.ConnStatusDisconnected {
|
if !connected {
|
||||||
conn.logTraceConnState()
|
conn.logTraceConnState()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
iceWorkerCreated := conn.workerICE != nil
|
// For JS platform: only relay connection is supported
|
||||||
|
if runtime.GOOS == "js" {
|
||||||
var iceInProgress bool
|
return conn.statusRelay.Get() == worker.StatusConnected
|
||||||
if iceWorkerCreated {
|
|
||||||
iceInProgress = conn.workerICE.InProgress()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return evalConnStatus(connStatusInputs{
|
// For non-JS platforms: check ICE connection status
|
||||||
forceRelay: IsForceRelayed(),
|
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||||
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
|
return false
|
||||||
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
|
}
|
||||||
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
|
|
||||||
iceWorkerCreated: iceWorkerCreated,
|
// If relay is supported with peer, it must also be connected
|
||||||
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
|
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||||
iceInProgress: iceInProgress,
|
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||||
})
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||||
@@ -935,43 +926,3 @@ func isController(config ConnConfig) bool {
|
|||||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||||
return remoteRosenpassPubKey != nil
|
return remoteRosenpassPubKey != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
|
|
||||||
// "Relay up and needed" — the peer uses relay and the transport is connected.
|
|
||||||
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
|
|
||||||
|
|
||||||
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
|
|
||||||
if in.forceRelay {
|
|
||||||
return boolToConnStatus(relayUsedAndUp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remote peer doesn't support ICE, or we haven't created the worker yet:
|
|
||||||
// relay is the only possible transport.
|
|
||||||
if !in.remoteSupportsICE || !in.iceWorkerCreated {
|
|
||||||
return boolToConnStatus(relayUsedAndUp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ICE counts as "up" when the status is anything other than Disconnected, OR
|
|
||||||
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
|
|
||||||
iceUp := in.iceStatusConnecting || in.iceInProgress
|
|
||||||
|
|
||||||
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
|
|
||||||
relayOK := !in.peerUsesRelay || in.relayConnected
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case iceUp && relayOK:
|
|
||||||
return guard.ConnStatusConnected
|
|
||||||
case relayUsedAndUp:
|
|
||||||
// Relay is up but ICE is down — partially connected.
|
|
||||||
return guard.ConnStatusPartiallyConnected
|
|
||||||
default:
|
|
||||||
return guard.ConnStatusDisconnected
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func boolToConnStatus(connected bool) guard.ConnStatus {
|
|
||||||
if connected {
|
|
||||||
return guard.ConnStatusConnected
|
|
||||||
}
|
|
||||||
return guard.ConnStatusDisconnected
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,20 +13,6 @@ const (
|
|||||||
StatusConnected
|
StatusConnected
|
||||||
)
|
)
|
||||||
|
|
||||||
// connStatusInputs is the primitive-valued snapshot of the state that drives the
|
|
||||||
// tri-state connection classification. Extracted so the decision logic can be unit-tested
|
|
||||||
// without constructing full Worker/Handshaker objects.
|
|
||||||
type connStatusInputs struct {
|
|
||||||
forceRelay bool // NB_FORCE_RELAY or JS/WASM
|
|
||||||
peerUsesRelay bool // remote peer advertises relay support AND local has relay
|
|
||||||
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
|
|
||||||
remoteSupportsICE bool // remote peer sent ICE credentials
|
|
||||||
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
|
|
||||||
iceStatusConnecting bool // statusICE is anything other than Disconnected
|
|
||||||
iceInProgress bool // a negotiation is currently in flight
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// ConnStatus describe the status of a peer's connection
|
// ConnStatus describe the status of a peer's connection
|
||||||
type ConnStatus int32
|
type ConnStatus int32
|
||||||
|
|
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEvalConnStatus_ForceRelay(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in connStatusInputs
|
|
||||||
want guard.ConnStatus
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "force relay, peer uses relay, relay up",
|
|
||||||
in: connStatusInputs{
|
|
||||||
forceRelay: true,
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "force relay, peer uses relay, relay down",
|
|
||||||
in: connStatusInputs{
|
|
||||||
forceRelay: true,
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: false,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "force relay, peer does NOT use relay - disconnected forever",
|
|
||||||
in: connStatusInputs{
|
|
||||||
forceRelay: true,
|
|
||||||
peerUsesRelay: false,
|
|
||||||
relayConnected: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
if got := evalConnStatus(tc.in); got != tc.want {
|
|
||||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in connStatusInputs
|
|
||||||
want guard.ConnStatus
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "remote does not support ICE, peer uses relay, relay up",
|
|
||||||
in: connStatusInputs{
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: true,
|
|
||||||
remoteSupportsICE: false,
|
|
||||||
iceWorkerCreated: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "remote does not support ICE, peer uses relay, relay down",
|
|
||||||
in: connStatusInputs{
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: false,
|
|
||||||
remoteSupportsICE: false,
|
|
||||||
iceWorkerCreated: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE worker not yet created, relay up",
|
|
||||||
in: connStatusInputs{
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: true,
|
|
||||||
remoteSupportsICE: true,
|
|
||||||
iceWorkerCreated: false,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "remote does not support ICE, peer does not use relay",
|
|
||||||
in: connStatusInputs{
|
|
||||||
peerUsesRelay: false,
|
|
||||||
relayConnected: false,
|
|
||||||
remoteSupportsICE: false,
|
|
||||||
iceWorkerCreated: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
if got := evalConnStatus(tc.in); got != tc.want {
|
|
||||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
|
|
||||||
base := connStatusInputs{
|
|
||||||
remoteSupportsICE: true,
|
|
||||||
iceWorkerCreated: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
mutator func(*connStatusInputs)
|
|
||||||
want guard.ConnStatus
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ICE connected, relay connected, peer uses relay",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = true
|
|
||||||
in.relayConnected = true
|
|
||||||
in.iceStatusConnecting = true
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE connected, peer does NOT use relay",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = false
|
|
||||||
in.relayConnected = false
|
|
||||||
in.iceStatusConnecting = true
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE InProgress only, peer does NOT use relay",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = false
|
|
||||||
in.iceStatusConnecting = false
|
|
||||||
in.iceInProgress = true
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE down, relay up, peer uses relay -> partial",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = true
|
|
||||||
in.relayConnected = true
|
|
||||||
in.iceStatusConnecting = false
|
|
||||||
in.iceInProgress = false
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusPartiallyConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE down, peer does NOT use relay -> disconnected",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = false
|
|
||||||
in.relayConnected = false
|
|
||||||
in.iceStatusConnecting = false
|
|
||||||
in.iceInProgress = false
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = true
|
|
||||||
in.relayConnected = false
|
|
||||||
in.iceStatusConnecting = true
|
|
||||||
},
|
|
||||||
// relayOK = false (peer uses relay but it's down), iceUp = true
|
|
||||||
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
|
|
||||||
// falls into default: Disconnected.
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE down, relay up but peer does not use relay -> disconnected",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = false
|
|
||||||
in.relayConnected = true // not actually used since peer doesn't rely on it
|
|
||||||
in.iceStatusConnecting = false
|
|
||||||
in.iceInProgress = false
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
in := base
|
|
||||||
tc.mutator(&in)
|
|
||||||
if got := evalConnStatus(in); got != tc.want {
|
|
||||||
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -7,38 +7,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||||
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func IsForceRelayed() bool {
|
func isForceRelayed() bool {
|
||||||
if runtime.GOOS == "js" {
|
if runtime.GOOS == "js" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
// OverrideRelayURLs returns the relay server URL list set in
|
|
||||||
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
|
|
||||||
// the override is active. When the env var is unset, the boolean is false
|
|
||||||
// and the caller should keep the list received from the management server.
|
|
||||||
// Intended for lab/debug scenarios where a peer must pin to a specific home
|
|
||||||
// relay regardless of what management offers.
|
|
||||||
func OverrideRelayURLs() ([]string, bool) {
|
|
||||||
raw := os.Getenv(EnvKeyNBHomeRelayServers)
|
|
||||||
if raw == "" {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
parts := strings.Split(raw, ",")
|
|
||||||
urls := make([]string, 0, len(parts))
|
|
||||||
for _, p := range parts {
|
|
||||||
p = strings.TrimSpace(p)
|
|
||||||
if p != "" {
|
|
||||||
urls = append(urls, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(urls) == 0 {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
return urls, true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,19 +8,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnStatus represents the connection state as seen by the guard.
|
type isConnectedFunc func() bool
|
||||||
type ConnStatus int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ConnStatusDisconnected means neither ICE nor Relay is connected.
|
|
||||||
ConnStatusDisconnected ConnStatus = iota
|
|
||||||
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
|
|
||||||
ConnStatusPartiallyConnected
|
|
||||||
// ConnStatusConnected means all required connections are established.
|
|
||||||
ConnStatusConnected
|
|
||||||
)
|
|
||||||
|
|
||||||
type connStatusFunc func() ConnStatus
|
|
||||||
|
|
||||||
// Guard is responsible for the reconnection logic.
|
// Guard is responsible for the reconnection logic.
|
||||||
// It will trigger to send an offer to the peer then has connection issues.
|
// It will trigger to send an offer to the peer then has connection issues.
|
||||||
@@ -32,14 +20,14 @@ type connStatusFunc func() ConnStatus
|
|||||||
// - ICE candidate changes
|
// - ICE candidate changes
|
||||||
type Guard struct {
|
type Guard struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
isConnectedOnAllWay connStatusFunc
|
isConnectedOnAllWay isConnectedFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
srWatcher *SRWatcher
|
srWatcher *SRWatcher
|
||||||
relayedConnDisconnected chan struct{}
|
relayedConnDisconnected chan struct{}
|
||||||
iCEConnDisconnected chan struct{}
|
iCEConnDisconnected chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||||
return &Guard{
|
return &Guard{
|
||||||
log: log,
|
log: log,
|
||||||
isConnectedOnAllWay: isConnectedFn,
|
isConnectedOnAllWay: isConnectedFn,
|
||||||
@@ -69,17 +57,8 @@ func (g *Guard) SetICEConnDisconnected() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
|
// reconnectLoopWithRetry periodically check the connection status.
|
||||||
//
|
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||||
// Behavior depends on the connection state reported by isConnectedOnAllWay:
|
|
||||||
// - Connected: no action, the peer is fully reachable.
|
|
||||||
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
|
|
||||||
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
|
|
||||||
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
|
|
||||||
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
|
|
||||||
//
|
|
||||||
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
|
|
||||||
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
|
|
||||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||||
srReconnectedChan := g.srWatcher.NewListener()
|
srReconnectedChan := g.srWatcher.NewListener()
|
||||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||||
@@ -89,47 +68,36 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
|||||||
|
|
||||||
tickerChannel := ticker.C
|
tickerChannel := ticker.C
|
||||||
|
|
||||||
iceState := &iceRetryState{log: g.log}
|
|
||||||
defer iceState.reset()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-tickerChannel:
|
case t := <-tickerChannel:
|
||||||
switch g.isConnectedOnAllWay() {
|
if t.IsZero() {
|
||||||
case ConnStatusConnected:
|
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||||
// all good, nothing to do
|
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||||
case ConnStatusDisconnected:
|
tickerChannel = make(<-chan time.Time)
|
||||||
callback()
|
continue
|
||||||
case ConnStatusPartiallyConnected:
|
|
||||||
if iceState.shouldRetry() {
|
|
||||||
callback()
|
|
||||||
} else {
|
|
||||||
iceState.enterHourlyMode()
|
|
||||||
ticker.Stop()
|
|
||||||
tickerChannel = iceState.hourlyC()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !g.isConnectedOnAllWay() {
|
||||||
|
callback()
|
||||||
|
}
|
||||||
case <-g.relayedConnDisconnected:
|
case <-g.relayedConnDisconnected:
|
||||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.newReconnectTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
tickerChannel = ticker.C
|
tickerChannel = ticker.C
|
||||||
iceState.reset()
|
|
||||||
|
|
||||||
case <-g.iCEConnDisconnected:
|
case <-g.iCEConnDisconnected:
|
||||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.newReconnectTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
tickerChannel = ticker.C
|
tickerChannel = ticker.C
|
||||||
iceState.reset()
|
|
||||||
|
|
||||||
case <-srReconnectedChan:
|
case <-srReconnectedChan:
|
||||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.newReconnectTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
tickerChannel = ticker.C
|
tickerChannel = ticker.C
|
||||||
iceState.reset()
|
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
g.log.Debugf("context is done, stop reconnect loop")
|
g.log.Debugf("context is done, stop reconnect loop")
|
||||||
@@ -152,7 +120,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
|
|||||||
return backoff.NewTicker(bo)
|
return backoff.NewTicker(bo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
|
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
InitialInterval: 800 * time.Millisecond,
|
InitialInterval: 800 * time.Millisecond,
|
||||||
RandomizationFactor: 0.1,
|
RandomizationFactor: 0.1,
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
package guard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
|
|
||||||
maxICERetries = 3
|
|
||||||
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
|
|
||||||
iceRetryInterval = 1 * time.Hour
|
|
||||||
)
|
|
||||||
|
|
||||||
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
|
|
||||||
// After maxICERetries attempts it switches to a periodic hourly retry.
|
|
||||||
type iceRetryState struct {
|
|
||||||
log *log.Entry
|
|
||||||
retries int
|
|
||||||
hourly *time.Ticker
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *iceRetryState) reset() {
|
|
||||||
s.retries = 0
|
|
||||||
if s.hourly != nil {
|
|
||||||
s.hourly.Stop()
|
|
||||||
s.hourly = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldRetry reports whether the caller should send another ICE offer on this tick.
|
|
||||||
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
|
|
||||||
// to the hourly ticker via enterHourlyMode + hourlyC.
|
|
||||||
func (s *iceRetryState) shouldRetry() bool {
|
|
||||||
if s.hourly != nil {
|
|
||||||
s.log.Debugf("hourly ICE retry attempt")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
s.retries++
|
|
||||||
if s.retries <= maxICERetries {
|
|
||||||
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
|
|
||||||
func (s *iceRetryState) enterHourlyMode() {
|
|
||||||
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
|
|
||||||
s.hourly = time.NewTicker(iceRetryInterval)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *iceRetryState) hourlyC() <-chan time.Time {
|
|
||||||
if s.hourly == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return s.hourly.C
|
|
||||||
}
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
package guard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestRetryState() *iceRetryState {
|
|
||||||
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
|
|
||||||
for i := 1; i <= maxICERetries; i++ {
|
|
||||||
if !s.shouldRetry() {
|
|
||||||
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
|
|
||||||
for i := 0; i < maxICERetries; i++ {
|
|
||||||
_ = s.shouldRetry()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.shouldRetry() {
|
|
||||||
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
|
|
||||||
if s.hourlyC() != nil {
|
|
||||||
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
for i := 0; i < maxICERetries+1; i++ {
|
|
||||||
_ = s.shouldRetry()
|
|
||||||
}
|
|
||||||
|
|
||||||
s.enterHourlyMode()
|
|
||||||
defer s.reset()
|
|
||||||
|
|
||||||
if s.hourlyC() == nil {
|
|
||||||
t.Fatalf("hourlyC returned nil after enterHourlyMode")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
s.enterHourlyMode()
|
|
||||||
defer s.reset()
|
|
||||||
|
|
||||||
if !s.shouldRetry() {
|
|
||||||
t.Fatalf("shouldRetry returned false in hourly mode, want true")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Subsequent calls also return true — we keep retrying on each hourly tick.
|
|
||||||
if !s.shouldRetry() {
|
|
||||||
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
for i := 0; i < maxICERetries+1; i++ {
|
|
||||||
_ = s.shouldRetry()
|
|
||||||
}
|
|
||||||
s.enterHourlyMode()
|
|
||||||
|
|
||||||
s.reset()
|
|
||||||
|
|
||||||
if s.hourlyC() != nil {
|
|
||||||
t.Fatalf("hourlyC returned non-nil channel after reset")
|
|
||||||
}
|
|
||||||
if s.retries != 0 {
|
|
||||||
t.Fatalf("retries = %d after reset, want 0", s.retries)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 1; i <= maxICERetries; i++ {
|
|
||||||
if !s.shouldRetry() {
|
|
||||||
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
s.reset()
|
|
||||||
s.reset() // second call must not panic or re-stop a nil ticker
|
|
||||||
|
|
||||||
if s.hourlyC() != nil {
|
|
||||||
t.Fatalf("hourlyC non-nil after double reset")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
|
|||||||
return srw
|
return srw
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SRWatcher) Start(disableICEMonitor bool) {
|
func (w *SRWatcher) Start() {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
@@ -50,10 +50,8 @@ func (w *SRWatcher) Start(disableICEMonitor bool) {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
w.cancelIceMonitor = cancel
|
w.cancelIceMonitor = cancel
|
||||||
|
|
||||||
if !disableICEMonitor {
|
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
|
||||||
}
|
|
||||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -44,10 +43,6 @@ type OfferAnswer struct {
|
|||||||
SessionID *ICESessionID
|
SessionID *ICESessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OfferAnswer) hasICECredentials() bool {
|
|
||||||
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
type Handshaker struct {
|
type Handshaker struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
@@ -64,10 +59,6 @@ type Handshaker struct {
|
|||||||
relayListener *AsyncOfferListener
|
relayListener *AsyncOfferListener
|
||||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||||
|
|
||||||
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
|
|
||||||
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
|
|
||||||
remoteICESupported atomic.Bool
|
|
||||||
|
|
||||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||||
remoteOffersCh chan OfferAnswer
|
remoteOffersCh chan OfferAnswer
|
||||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||||
@@ -75,7 +66,7 @@ type Handshaker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||||
h := &Handshaker{
|
return &Handshaker{
|
||||||
log: log,
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
signaler: signaler,
|
signaler: signaler,
|
||||||
@@ -85,13 +76,6 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
|||||||
remoteOffersCh: make(chan OfferAnswer),
|
remoteOffersCh: make(chan OfferAnswer),
|
||||||
remoteAnswerCh: make(chan OfferAnswer),
|
remoteAnswerCh: make(chan OfferAnswer),
|
||||||
}
|
}
|
||||||
// assume remote supports ICE until we learn otherwise from received offers
|
|
||||||
h.remoteICESupported.Store(ice != nil)
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handshaker) RemoteICESupported() bool {
|
|
||||||
return h.remoteICESupported.Load()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||||
@@ -106,20 +90,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
|
|
||||||
// Record signaling received for reconnection attempts
|
// Record signaling received for reconnection attempts
|
||||||
if h.metricsStages != nil {
|
if h.metricsStages != nil {
|
||||||
h.metricsStages.RecordSignalingReceived()
|
h.metricsStages.RecordSignalingReceived()
|
||||||
}
|
}
|
||||||
|
|
||||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
|
||||||
|
|
||||||
if h.relayListener != nil {
|
if h.relayListener != nil {
|
||||||
h.relayListener.Notify(&remoteOfferAnswer)
|
h.relayListener.Notify(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.iceListener != nil && h.RemoteICESupported() {
|
if h.iceListener != nil {
|
||||||
h.iceListener(&remoteOfferAnswer)
|
h.iceListener(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,20 +110,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
|
|
||||||
// Record signaling received for reconnection attempts
|
// Record signaling received for reconnection attempts
|
||||||
if h.metricsStages != nil {
|
if h.metricsStages != nil {
|
||||||
h.metricsStages.RecordSignalingReceived()
|
h.metricsStages.RecordSignalingReceived()
|
||||||
}
|
}
|
||||||
|
|
||||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
|
||||||
|
|
||||||
if h.relayListener != nil {
|
if h.relayListener != nil {
|
||||||
h.relayListener.Notify(&remoteOfferAnswer)
|
h.relayListener.Notify(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.iceListener != nil && h.RemoteICESupported() {
|
if h.iceListener != nil {
|
||||||
h.iceListener(&remoteOfferAnswer)
|
h.iceListener(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -203,18 +183,15 @@ func (h *Handshaker) sendAnswer() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||||
|
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||||
|
sid := h.ice.SessionID()
|
||||||
answer := OfferAnswer{
|
answer := OfferAnswer{
|
||||||
|
IceCredentials: IceCredentials{uFrag, pwd},
|
||||||
WgListenPort: h.config.LocalWgPort,
|
WgListenPort: h.config.LocalWgPort,
|
||||||
Version: version.NetbirdVersion(),
|
Version: version.NetbirdVersion(),
|
||||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||||
}
|
SessionID: &sid,
|
||||||
|
|
||||||
if h.ice != nil && h.RemoteICESupported() {
|
|
||||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
|
||||||
sid := h.ice.SessionID()
|
|
||||||
answer.IceCredentials = IceCredentials{uFrag, pwd}
|
|
||||||
answer.SessionID = &sid
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||||
@@ -223,18 +200,3 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
|||||||
|
|
||||||
return answer
|
return answer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
|
|
||||||
hasICE := offer.hasICECredentials()
|
|
||||||
prev := h.remoteICESupported.Swap(hasICE)
|
|
||||||
if prev != hasICE {
|
|
||||||
if hasICE {
|
|
||||||
h.log.Infof("remote peer started sending ICE credentials")
|
|
||||||
} else {
|
|
||||||
h.log.Infof("remote peer stopped sending ICE credentials")
|
|
||||||
if h.ice != nil {
|
|
||||||
h.ice.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -46,13 +46,9 @@ func (s *Signaler) Ready() bool {
|
|||||||
|
|
||||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||||
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
||||||
var sessionIDBytes []byte
|
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
|
||||||
if offerAnswer.SessionID != nil {
|
if err != nil {
|
||||||
var err error
|
log.Warnf("failed to get session ID bytes: %v", err)
|
||||||
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to get session ID bytes: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
msg, err := signal.MarshalCredential(
|
msg, err := signal.MarshalCredential(
|
||||||
s.wgPrivateKey,
|
s.wgPrivateKey,
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
|
|
||||||
|
|
||||||
package systemops
|
|
||||||
|
|
||||||
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
|
|
||||||
// always fall through to the ref-counter exclusion-route path; these stubs
|
|
||||||
// exist only so systemops_unix.go compiles.
|
|
||||||
func (r *SysOps) setupAdvancedRouting() error { return nil }
|
|
||||||
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
|
|
||||||
func (r *SysOps) flushPlatformExtras() error { return nil }
|
|
||||||
@@ -1,241 +0,0 @@
|
|||||||
//go:build darwin && !ios
|
|
||||||
|
|
||||||
package systemops
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/net/route"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// scopedRouteBudget bounds retries for the scoped default route. Installing or
|
|
||||||
// deleting it matters enough that we're willing to spend longer waiting for the
|
|
||||||
// kernel reply than for per-prefix exclusion routes.
|
|
||||||
const scopedRouteBudget = 5 * time.Second
|
|
||||||
|
|
||||||
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
|
|
||||||
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
|
|
||||||
// resolve gateway'd destinations while the VPN's split default owns the
|
|
||||||
// unscoped table.
|
|
||||||
//
|
|
||||||
// Timing note: this runs during routeManager.Init, which happens before the
|
|
||||||
// VPN interface is created and before any peer routes propagate. The initial
|
|
||||||
// mgmt / signal / relay TCP dials always fire before this runs, so those
|
|
||||||
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
|
|
||||||
// lookup, which at that point correctly picks the physical default. Those
|
|
||||||
// already-established TCP flows keep their originally-selected interface for
|
|
||||||
// their lifetime on Darwin because the kernel caches the egress route
|
|
||||||
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
|
|
||||||
// afterwards does not migrate them since the original en0 default stays in
|
|
||||||
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
|
|
||||||
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
|
|
||||||
func (r *SysOps) setupAdvancedRouting() error {
|
|
||||||
// Drop any previously-cached egress interface before reinstalling. On a
|
|
||||||
// refresh, a family that no longer resolves would otherwise keep the stale
|
|
||||||
// binding, causing new sockets to scope to an interface without a matching
|
|
||||||
// scoped default.
|
|
||||||
nbnet.ClearBoundInterfaces()
|
|
||||||
|
|
||||||
if err := r.flushScopedDefaults(); err != nil {
|
|
||||||
log.Warnf("flush residual scoped defaults: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
installed := 0
|
|
||||||
|
|
||||||
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
|
|
||||||
ok, err := r.installScopedDefaultFor(unspec)
|
|
||||||
if err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
installed++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if installed == 0 && merr != nil {
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
if merr != nil {
|
|
||||||
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// installScopedDefaultFor resolves the physical default nexthop for the given
|
|
||||||
// address family, installs a scoped default via it, and caches the iface for
|
|
||||||
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
|
|
||||||
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
|
||||||
nexthop, err := GetNextHop(unspec)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, vars.ErrRouteNotFound) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
|
|
||||||
}
|
|
||||||
if nexthop.Intf == nil {
|
|
||||||
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
|
||||||
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
af := unix.AF_INET
|
|
||||||
if unspec.Is6() {
|
|
||||||
af = unix.AF_INET6
|
|
||||||
}
|
|
||||||
nbnet.SetBoundInterface(af, nexthop.Intf)
|
|
||||||
via := "point-to-point"
|
|
||||||
if nexthop.IP.IsValid() {
|
|
||||||
via = nexthop.IP.String()
|
|
||||||
}
|
|
||||||
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SysOps) cleanupAdvancedRouting() error {
|
|
||||||
nbnet.ClearBoundInterfaces()
|
|
||||||
return r.flushScopedDefaults()
|
|
||||||
}
|
|
||||||
|
|
||||||
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
|
|
||||||
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
|
|
||||||
// removed on the next boot regardless of whether a profile is brought up.
|
|
||||||
func (r *SysOps) flushPlatformExtras() error {
|
|
||||||
return r.flushScopedDefaults()
|
|
||||||
}
|
|
||||||
|
|
||||||
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
|
|
||||||
// Safe to call at startup to clear residual entries from a prior session.
|
|
||||||
func (r *SysOps) flushScopedDefaults() error {
|
|
||||||
rib, err := retryFetchRIB()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("fetch routing table: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse routing table: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
removed := 0
|
|
||||||
|
|
||||||
for _, msg := range msgs {
|
|
||||||
rtMsg, ok := msg.(*route.RouteMessage)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if rtMsg.Flags&routeProtoFlag == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
info, err := MsgToRoute(rtMsg)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("skip scoped flush: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.deleteScopedRoute(rtMsg); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
|
|
||||||
info.Dst, rtMsg.Index, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
removed++
|
|
||||||
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
|
|
||||||
}
|
|
||||||
|
|
||||||
if removed > 0 {
|
|
||||||
log.Infof("flushed %d residual scoped default route(s)", removed)
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
|
|
||||||
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
|
|
||||||
// Preserve identifying flags from the stored route (including RTF_GATEWAY
|
|
||||||
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
|
|
||||||
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
|
|
||||||
del := &route.RouteMessage{
|
|
||||||
Type: unix.RTM_DELETE,
|
|
||||||
Flags: rtMsg.Flags & keep,
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Seq: r.getSeq(),
|
|
||||||
Index: rtMsg.Index,
|
|
||||||
Addrs: rtMsg.Addrs,
|
|
||||||
}
|
|
||||||
return r.writeRouteMessage(del, scopedRouteBudget)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
|
|
||||||
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
|
|
||||||
|
|
||||||
msg := &route.RouteMessage{
|
|
||||||
Type: action,
|
|
||||||
Flags: flags,
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
ID: uintptr(os.Getpid()),
|
|
||||||
Seq: r.getSeq(),
|
|
||||||
Index: nexthop.Intf.Index,
|
|
||||||
}
|
|
||||||
|
|
||||||
const numAddrs = unix.RTAX_NETMASK + 1
|
|
||||||
addrs := make([]route.Addr, numAddrs)
|
|
||||||
|
|
||||||
dst, err := addrToRouteAddr(unspec)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("build destination: %w", err)
|
|
||||||
}
|
|
||||||
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("build netmask: %w", err)
|
|
||||||
}
|
|
||||||
addrs[unix.RTAX_DST] = dst
|
|
||||||
addrs[unix.RTAX_NETMASK] = mask
|
|
||||||
|
|
||||||
if nexthop.IP.IsValid() {
|
|
||||||
msg.Flags |= unix.RTF_GATEWAY
|
|
||||||
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("build gateway: %w", err)
|
|
||||||
}
|
|
||||||
addrs[unix.RTAX_GATEWAY] = gw
|
|
||||||
} else {
|
|
||||||
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
|
|
||||||
Index: nexthop.Intf.Index,
|
|
||||||
Name: nexthop.Intf.Name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg.Addrs = addrs
|
|
||||||
|
|
||||||
return r.writeRouteMessage(msg, scopedRouteBudget)
|
|
||||||
}
|
|
||||||
|
|
||||||
func afOf(a netip.Addr) string {
|
|
||||||
if a.Is4() {
|
|
||||||
return "IPv4"
|
|
||||||
}
|
|
||||||
return "IPv6"
|
|
||||||
}
|
|
||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
"github.com/netbirdio/netbird/client/net/hooks"
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,6 +31,8 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
|||||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||||
|
|
||||||
|
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||||
|
|
||||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
@@ -396,16 +397,12 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||||
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
|
|
||||||
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
|
|
||||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||||
if nbnet.AdvancedRouting() {
|
localRoutes, err := hasSeparateRouting()
|
||||||
return false, netip.Prefix{}
|
|
||||||
}
|
|
||||||
|
|
||||||
localRoutes, err := GetRoutesFromTable()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get routes: %v", err)
|
if !errors.Is(err, ErrRoutingIsSeparate) {
|
||||||
|
log.Errorf("Failed to get routes: %v", err)
|
||||||
|
}
|
||||||
return false, netip.Prefix{}
|
return false, netip.Prefix{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
return []netip.Prefix{}, nil
|
return []netip.Prefix{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
|
return []netip.Prefix{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||||
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||||
return []DetailedRoute{}, nil
|
return []DetailedRoute{}, nil
|
||||||
|
|||||||
@@ -894,6 +894,13 @@ func getAddressFamily(prefix netip.Prefix) int {
|
|||||||
return netlink.FAMILY_V6
|
return netlink.FAMILY_V6
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
|
if !nbnet.AdvancedRouting() {
|
||||||
|
return GetRoutesFromTable()
|
||||||
|
}
|
||||||
|
return nil, ErrRoutingIsSeparate
|
||||||
|
}
|
||||||
|
|
||||||
func isOpErr(err error) bool {
|
func isOpErr(err error) bool {
|
||||||
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
|
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
|
||||||
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
|
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
|
||||||
|
|||||||
@@ -48,6 +48,10 @@ func EnableIPForwarding() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
|
return GetRoutesFromTable()
|
||||||
|
}
|
||||||
|
|
||||||
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
|
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
|
||||||
func GetIPRules() ([]IPRule, error) {
|
func GetIPRules() ([]IPRule, error) {
|
||||||
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)
|
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)
|
||||||
|
|||||||
@@ -25,9 +25,6 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
||||||
|
|
||||||
// routeBudget bounds retries for per-prefix exclusion route programming.
|
|
||||||
routeBudget = 1 * time.Second
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var routeProtoFlag int
|
var routeProtoFlag int
|
||||||
@@ -44,42 +41,26 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
if advancedRouting {
|
|
||||||
return r.setupAdvancedRouting()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Using legacy routing setup with ref counters")
|
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
if advancedRouting {
|
|
||||||
return r.cleanupAdvancedRouting()
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
||||||
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
|
|
||||||
// crashed prior session can't leave crud in the table.
|
|
||||||
func (r *SysOps) FlushMarkedRoutes() error {
|
func (r *SysOps) FlushMarkedRoutes() error {
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if err := r.flushPlatformExtras(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
rib, err := retryFetchRIB()
|
rib, err := retryFetchRIB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
|
return fmt.Errorf("fetch routing table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
|
return fmt.Errorf("parse routing table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
flushedCount := 0
|
flushedCount := 0
|
||||||
|
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
@@ -136,12 +117,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
|
|||||||
return fmt.Errorf("invalid prefix: %s", prefix)
|
return fmt.Errorf("invalid prefix: %s", prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
expBackOff := backoff.NewExponentialBackOff()
|
||||||
if err != nil {
|
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||||
return fmt.Errorf("build route message: %w", err)
|
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||||
}
|
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||||
|
|
||||||
if err := r.writeRouteMessage(msg, routeBudget); err != nil {
|
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
|
||||||
a := "add"
|
a := "add"
|
||||||
if action == unix.RTM_DELETE {
|
if action == unix.RTM_DELETE {
|
||||||
a = "remove"
|
a = "remove"
|
||||||
@@ -151,91 +132,50 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeRouteMessage sends a route message over AF_ROUTE and waits for the
|
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
|
||||||
// kernel's matching reply, retrying transient failures until budget elapses.
|
operation := func() error {
|
||||||
// Callers do not need to manage sockets or seq numbers themselves.
|
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||||
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
|
|
||||||
expBackOff := backoff.NewExponentialBackOff()
|
|
||||||
expBackOff.InitialInterval = 50 * time.Millisecond
|
|
||||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
|
||||||
expBackOff.MaxElapsedTime = budget
|
|
||||||
|
|
||||||
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
|
|
||||||
}
|
|
||||||
|
|
||||||
func routeMessageRoundtrip(msg *route.RouteMessage) error {
|
|
||||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open routing socket: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
|
||||||
log.Warnf("close routing socket: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
tv := unix.Timeval{Sec: 1}
|
|
||||||
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
|
|
||||||
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// AF_ROUTE is a broadcast channel: every route socket on the host sees
|
|
||||||
// every RTM_* event. With concurrent route programming the default
|
|
||||||
// per-socket queue overflows and our own reply gets dropped.
|
|
||||||
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
|
|
||||||
log.Debugf("set SO_RCVBUF on route socket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bytes, err := msg.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = unix.Write(fd, bytes); err != nil {
|
|
||||||
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
|
||||||
return fmt.Errorf("write: %w", err)
|
|
||||||
}
|
|
||||||
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
|
||||||
}
|
|
||||||
return readRouteResponse(fd, msg.Type, msg.Seq)
|
|
||||||
}
|
|
||||||
|
|
||||||
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
|
|
||||||
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
|
|
||||||
// broadcast channel: interface up/down, third-party route changes and neighbor
|
|
||||||
// discovery events can all land between our write and read, so we must filter.
|
|
||||||
func readRouteResponse(fd, wantType, wantSeq int) error {
|
|
||||||
pid := int32(os.Getpid())
|
|
||||||
resp := make([]byte, 2048)
|
|
||||||
deadline := time.Now().Add(time.Second)
|
|
||||||
for {
|
|
||||||
if time.Now().After(deadline) {
|
|
||||||
// Transient: under concurrent pressure the kernel can drop our reply
|
|
||||||
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
|
|
||||||
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
|
|
||||||
}
|
|
||||||
n, err := unix.Read(fd, resp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
|
return fmt.Errorf("open routing socket: %w", err)
|
||||||
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
|
}
|
||||||
continue
|
defer func() {
|
||||||
|
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||||
|
log.Warnf("failed to close routing socket: %v", err)
|
||||||
}
|
}
|
||||||
return backoff.Permanent(fmt.Errorf("read: %w", err))
|
}()
|
||||||
|
|
||||||
|
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
||||||
|
if err != nil {
|
||||||
|
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
|
||||||
}
|
}
|
||||||
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
|
||||||
continue
|
msgBytes, err := msg.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
|
||||||
}
|
}
|
||||||
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
|
|
||||||
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
|
if _, err = unix.Write(fd, msgBytes); err != nil {
|
||||||
// uniquely identifies our own reply among broadcast traffic.
|
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||||
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
|
return fmt.Errorf("write: %w", err)
|
||||||
continue
|
}
|
||||||
|
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||||
}
|
}
|
||||||
if hdr.Errno != 0 {
|
|
||||||
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
|
respBuf := make([]byte, 2048)
|
||||||
|
n, err := unix.Read(fd, respBuf)
|
||||||
|
if err != nil {
|
||||||
|
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if n > 0 {
|
||||||
|
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return operation
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||||
@@ -243,7 +183,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
|||||||
Type: action,
|
Type: action,
|
||||||
Flags: unix.RTF_UP | routeProtoFlag,
|
Flags: unix.RTF_UP | routeProtoFlag,
|
||||||
Version: unix.RTM_VERSION,
|
Version: unix.RTM_VERSION,
|
||||||
ID: uintptr(os.Getpid()),
|
|
||||||
Seq: r.getSeq(),
|
Seq: r.getSeq(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -282,6 +221,19 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) parseRouteResponse(buf []byte) error {
|
||||||
|
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||||
|
if rtMsg.Errno != 0 {
|
||||||
|
return fmt.Errorf("parse: %d", rtMsg.Errno)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
|
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
|
||||||
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
|
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
|
||||||
if addr.Is4() {
|
if addr.Is4() {
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
func (d *Dialer) init() {
|
|
||||||
d.Dialer.Control = applyBoundIfToSocket
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux && !windows && !darwin
|
//go:build !linux && !windows
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
|
|||||||
24
client/net/env_android.go
Normal file
24
client/net/env_android.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
// Init initializes the network environment for Android
|
||||||
|
func Init() {
|
||||||
|
// No initialization needed on Android
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
||||||
|
// Always returns true on Android since we cannot handle routes dynamically.
|
||||||
|
func AdvancedRouting() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVPNInterfaceName is a no-op on Android
|
||||||
|
func SetVPNInterfaceName(name string) {
|
||||||
|
// No-op on Android - not needed for Android VPN service
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVPNInterfaceName returns empty string on Android
|
||||||
|
func GetVPNInterfaceName() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux && !windows && !darwin
|
//go:build !linux && !windows && !android
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
//go:build ios || android
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
// Init initializes the network environment for mobile platforms.
|
|
||||||
func Init() {
|
|
||||||
// no-op on mobile: routing scope is owned by the VPN extension.
|
|
||||||
}
|
|
||||||
|
|
||||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
|
||||||
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
|
|
||||||
// owns the routing scope.
|
|
||||||
func AdvancedRouting() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetVPNInterfaceName is a no-op on mobile.
|
|
||||||
func SetVPNInterfaceName(string) {
|
|
||||||
// no-op on mobile: the VPN extension manages the interface.
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetVPNInterfaceName returns an empty string on mobile.
|
|
||||||
func GetVPNInterfaceName() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build (darwin && !ios) || windows
|
//go:build windows
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
@@ -24,22 +24,17 @@ func Init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func checkAdvancedRoutingSupport() bool {
|
func checkAdvancedRoutingSupport() bool {
|
||||||
legacyRouting := false
|
var err error
|
||||||
|
var legacyRouting bool
|
||||||
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
||||||
parsed, err := strconv.ParseBool(val)
|
legacyRouting, err = strconv.ParseBool(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
|
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
|
||||||
} else {
|
|
||||||
legacyRouting = parsed
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if legacyRouting {
|
if legacyRouting || netstack.IsEnabled() {
|
||||||
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
|
log.Info("advanced routing has been requested to be disabled")
|
||||||
return false
|
|
||||||
}
|
|
||||||
if netstack.IsEnabled() {
|
|
||||||
log.Info("advanced routing disabled: netstack mode is enabled")
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
func (l *ListenerConfig) init() {
|
|
||||||
l.ListenConfig.Control = applyBoundIfToSocket
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux && !windows && !darwin
|
//go:build !linux && !windows
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
|
|||||||
@@ -1,160 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
|
|
||||||
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
|
|
||||||
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
|
|
||||||
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
|
|
||||||
// dispatched by socket domain (AF_INET only) not by inp_vflag.
|
|
||||||
|
|
||||||
// boundIface holds the physical interface chosen at routing setup time. Sockets
|
|
||||||
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
|
|
||||||
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
|
|
||||||
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
|
|
||||||
// following the VPN's split default.
|
|
||||||
var (
|
|
||||||
boundIfaceMu sync.RWMutex
|
|
||||||
boundIface4 *net.Interface
|
|
||||||
boundIface6 *net.Interface
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetBoundInterface records the egress interface for an address family. Called
|
|
||||||
// by the routemanager after a scoped default route has been installed.
|
|
||||||
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
|
|
||||||
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
|
|
||||||
func SetBoundInterface(af int, iface *net.Interface) {
|
|
||||||
if iface == nil {
|
|
||||||
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
boundIfaceMu.Lock()
|
|
||||||
defer boundIfaceMu.Unlock()
|
|
||||||
switch af {
|
|
||||||
case unix.AF_INET:
|
|
||||||
boundIface4 = iface
|
|
||||||
case unix.AF_INET6:
|
|
||||||
boundIface6 = iface
|
|
||||||
default:
|
|
||||||
log.Warnf("SetBoundInterface: unsupported address family %d", af)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
|
|
||||||
// routemanager during cleanup.
|
|
||||||
func ClearBoundInterfaces() {
|
|
||||||
boundIfaceMu.Lock()
|
|
||||||
defer boundIfaceMu.Unlock()
|
|
||||||
boundIface4 = nil
|
|
||||||
boundIface6 = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// boundInterfaceFor returns the cached egress interface for a socket's address
|
|
||||||
// family, falling back to the other family if the preferred slot is empty.
|
|
||||||
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
|
|
||||||
// either setsockopt scopes the socket; preferring same-family still matters
|
|
||||||
// when v4 and v6 defaults egress different NICs.
|
|
||||||
func boundInterfaceFor(network, address string) *net.Interface {
|
|
||||||
if iface := zoneInterface(address); iface != nil {
|
|
||||||
return iface
|
|
||||||
}
|
|
||||||
|
|
||||||
boundIfaceMu.RLock()
|
|
||||||
defer boundIfaceMu.RUnlock()
|
|
||||||
|
|
||||||
primary, secondary := boundIface4, boundIface6
|
|
||||||
if isV6Network(network) {
|
|
||||||
primary, secondary = boundIface6, boundIface4
|
|
||||||
}
|
|
||||||
if primary != nil {
|
|
||||||
return primary
|
|
||||||
}
|
|
||||||
return secondary
|
|
||||||
}
|
|
||||||
|
|
||||||
func isV6Network(network string) bool {
|
|
||||||
return strings.HasSuffix(network, "6")
|
|
||||||
}
|
|
||||||
|
|
||||||
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
|
|
||||||
func zoneInterface(address string) *net.Interface {
|
|
||||||
if address == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
addr, err := netip.ParseAddrPort(address)
|
|
||||||
if err != nil {
|
|
||||||
a, err := netip.ParseAddr(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
addr = netip.AddrPortFrom(a, 0)
|
|
||||||
}
|
|
||||||
zone := addr.Addr().Zone()
|
|
||||||
if zone == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if iface, err := net.InterfaceByName(zone); err == nil {
|
|
||||||
return iface
|
|
||||||
}
|
|
||||||
if idx, err := strconv.Atoi(zone); err == nil {
|
|
||||||
if iface, err := net.InterfaceByIndex(idx); err == nil {
|
|
||||||
return iface
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
|
|
||||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
|
|
||||||
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
|
|
||||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
|
|
||||||
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyBoundIfToSocket binds the socket to the cached physical egress interface
|
|
||||||
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
|
|
||||||
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
|
|
||||||
if !AdvancedRouting() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
iface := boundInterfaceFor(network, address)
|
|
||||||
if iface == nil {
|
|
||||||
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
isV6 := isV6Network(network)
|
|
||||||
var controlErr error
|
|
||||||
if err := c.Control(func(fd uintptr) {
|
|
||||||
if isV6 {
|
|
||||||
controlErr = setIPv6BoundIf(fd, iface)
|
|
||||||
} else {
|
|
||||||
controlErr = setIPv4BoundIf(fd, iface)
|
|
||||||
}
|
|
||||||
if controlErr == nil {
|
|
||||||
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
|
|
||||||
}
|
|
||||||
}); err != nil {
|
|
||||||
return fmt.Errorf("control: %w", err)
|
|
||||||
}
|
|
||||||
return controlErr
|
|
||||||
}
|
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -137,8 +138,10 @@ func restoreResidualState(ctx context.Context, statePath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// clean up any remaining routes independently of the state file
|
// clean up any remaining routes independently of the state file
|
||||||
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
if !nbnet.AdvancedRouting() {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
|||||||
@@ -187,23 +187,24 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
|||||||
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hostList := strings.Join(deduplicatedPatterns, ",")
|
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||||
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
|
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||||
config += " PasswordAuthentication yes\n"
|
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||||
config += " PubkeyAuthentication yes\n"
|
config += " PasswordAuthentication yes\n"
|
||||||
config += " BatchMode no\n"
|
config += " PubkeyAuthentication yes\n"
|
||||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
config += " BatchMode no\n"
|
||||||
config += " StrictHostKeyChecking no\n"
|
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||||
|
config += " StrictHostKeyChecking no\n"
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
config += " UserKnownHostsFile NUL\n"
|
config += " UserKnownHostsFile NUL\n"
|
||||||
} else {
|
} else {
|
||||||
config += " UserKnownHostsFile /dev/null\n"
|
config += " UserKnownHostsFile /dev/null\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
config += " CheckHostIP no\n"
|
config += " CheckHostIP no\n"
|
||||||
config += " LogLevel ERROR\n\n"
|
config += " LogLevel ERROR\n\n"
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -116,37 +116,6 @@ func TestManager_PeerLimit(t *testing.T) {
|
|||||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_MatchHostFormat(t *testing.T) {
|
|
||||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
|
||||||
|
|
||||||
manager := &Manager{
|
|
||||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
|
||||||
sshConfigFile: "99-netbird.conf",
|
|
||||||
}
|
|
||||||
|
|
||||||
peers := []PeerSSHInfo{
|
|
||||||
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
|
|
||||||
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = manager.SetupSSHClientConfig(peers)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
|
||||||
content, err := os.ReadFile(configPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
configStr := string(content)
|
|
||||||
|
|
||||||
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
|
|
||||||
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
|
|
||||||
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
|
|
||||||
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
|
|
||||||
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
|
|
||||||
"should use Match host with comma-separated patterns")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||||
// Set force environment variable
|
// Set force environment variable
|
||||||
t.Setenv(EnvForceSSHConfig, "true")
|
t.Setenv(EnvForceSSHConfig, "true")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package system
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -144,6 +145,59 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func networkAddresses() ([]NetworkAddress, error) {
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var netAddresses []NetworkAddress
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
if iface.Flags&net.FlagUp == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if iface.HardwareAddr.String() == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, address := range addrs {
|
||||||
|
ipNet, ok := address.(*net.IPNet)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipNet.IP.IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
netAddr := NetworkAddress{
|
||||||
|
NetIP: netip.MustParsePrefix(ipNet.String()),
|
||||||
|
Mac: iface.HardwareAddr.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if isDuplicated(netAddresses, netAddr) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
netAddresses = append(netAddresses, netAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return netAddresses, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||||
|
for _, duplicated := range addresses {
|
||||||
|
if duplicated.NetIP == addr.NetIP {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
||||||
log.Debugf("gathering system information with checks: %d", len(checks))
|
log.Debugf("gathering system information with checks: %d", len(checks))
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package system
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -44,66 +42,6 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
return gio
|
return gio
|
||||||
}
|
}
|
||||||
|
|
||||||
// networkAddresses returns the list of network addresses on iOS.
|
|
||||||
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
|
|
||||||
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
|
|
||||||
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
|
|
||||||
// check that other platforms use and filter out link-local addresses as they
|
|
||||||
// are not useful for posture checks.
|
|
||||||
func networkAddresses() ([]NetworkAddress, error) {
|
|
||||||
interfaces, err := net.Interfaces()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var netAddresses []NetworkAddress
|
|
||||||
for _, iface := range interfaces {
|
|
||||||
if iface.Flags&net.FlagUp == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
addrs, err := iface.Addrs()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, address := range addrs {
|
|
||||||
netAddr, ok := toNetworkAddress(address)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if isDuplicated(netAddresses, netAddr) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
netAddresses = append(netAddresses, netAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return netAddresses, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
|
|
||||||
ipNet, ok := address.(*net.IPNet)
|
|
||||||
if !ok {
|
|
||||||
return NetworkAddress{}, false
|
|
||||||
}
|
|
||||||
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
|
|
||||||
return NetworkAddress{}, false
|
|
||||||
}
|
|
||||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
|
||||||
if err != nil {
|
|
||||||
return NetworkAddress{}, false
|
|
||||||
}
|
|
||||||
return NetworkAddress{NetIP: prefix, Mac: ""}, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
|
||||||
for _, duplicated := range addresses {
|
|
||||||
if duplicated.NetIP == addr.NetIP {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||||
return []File{}, nil
|
return []File{}, nil
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
//go:build !ios
|
|
||||||
|
|
||||||
package system
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
func networkAddresses() ([]NetworkAddress, error) {
|
|
||||||
interfaces, err := net.Interfaces()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var netAddresses []NetworkAddress
|
|
||||||
for _, iface := range interfaces {
|
|
||||||
if iface.Flags&net.FlagUp == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if iface.HardwareAddr.String() == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
addrs, err := iface.Addrs()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
mac := iface.HardwareAddr.String()
|
|
||||||
for _, address := range addrs {
|
|
||||||
netAddr, ok := toNetworkAddress(address, mac)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if isDuplicated(netAddresses, netAddr) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
netAddresses = append(netAddresses, netAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return netAddresses, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
|
|
||||||
ipNet, ok := address.(*net.IPNet)
|
|
||||||
if !ok {
|
|
||||||
return NetworkAddress{}, false
|
|
||||||
}
|
|
||||||
if ipNet.IP.IsLoopback() {
|
|
||||||
return NetworkAddress{}, false
|
|
||||||
}
|
|
||||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
|
||||||
if err != nil {
|
|
||||||
return NetworkAddress{}, false
|
|
||||||
}
|
|
||||||
return NetworkAddress{NetIP: prefix, Mac: mac}, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
|
||||||
for _, duplicated := range addresses {
|
|
||||||
if duplicated.NetIP == addr.NetIP {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -119,8 +119,6 @@ server:
|
|||||||
|
|
||||||
# Reverse proxy settings (optional)
|
# Reverse proxy settings (optional)
|
||||||
# reverseProxy:
|
# reverseProxy:
|
||||||
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
|
# trustedHTTPProxies: []
|
||||||
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
|
# trustedHTTPProxiesCount: 0
|
||||||
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
|
# trustedPeers: []
|
||||||
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
|
|
||||||
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.
|
|
||||||
|
|||||||
@@ -457,18 +457,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
|||||||
|
|
||||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
|
|
||||||
// which is when Receive has actually returned. Without this, the Receive goroutine
|
|
||||||
// can outlive the test and call t.Logf after teardown, panicking.
|
|
||||||
receiveDone := make(chan struct{})
|
|
||||||
t.Cleanup(func() {
|
|
||||||
select {
|
|
||||||
case <-receiveDone:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Error("Receive goroutine did not exit after Close")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := client.Close()
|
err := client.Close()
|
||||||
assert.NoError(t, err, "failed to close flow")
|
assert.NoError(t, err, "failed to close flow")
|
||||||
@@ -480,7 +468,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
|||||||
receivedAfterReconnect := make(chan struct{})
|
receivedAfterReconnect := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(receiveDone)
|
|
||||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -323,5 +323,3 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184
|
|||||||
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
|
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
|
||||||
|
|
||||||
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
|
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
|
||||||
|
|
||||||
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0
|
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -400,6 +400,8 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
|
|||||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
||||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||||
|
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
|
||||||
|
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
||||||
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
||||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||||
@@ -447,8 +449,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
|||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
|
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
|
||||||
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
|
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
|
||||||
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
|
|
||||||
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
|
||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
|
||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
|
|||||||
// are stored with types that Dex can open.
|
// are stored with types that Dex can open.
|
||||||
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
|
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
|
||||||
switch connType {
|
switch connType {
|
||||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||||
return "oidc", applyOIDCDefaults(connType, config)
|
return "oidc", applyOIDCDefaults(connType, config)
|
||||||
default:
|
default:
|
||||||
return connType, config
|
return connType, config
|
||||||
@@ -218,8 +218,6 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
|
|||||||
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
|
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
|
||||||
case "okta", "pocketid":
|
case "okta", "pocketid":
|
||||||
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
|
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||||
case "adfs":
|
|
||||||
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return augmented
|
return augmented
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
switch cfg.Type {
|
switch cfg.Type {
|
||||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||||
dexType = "oidc"
|
dexType = "oidc"
|
||||||
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
|
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
|
||||||
case "google":
|
case "google":
|
||||||
@@ -220,8 +220,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
|||||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||||
case "pocketid":
|
case "pocketid":
|
||||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||||
case "adfs":
|
|
||||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
|
||||||
}
|
}
|
||||||
return encodeConnectorConfig(oidcConfig)
|
return encodeConnectorConfig(oidcConfig)
|
||||||
}
|
}
|
||||||
@@ -285,7 +283,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
|
|||||||
// inferOIDCProviderType infers the specific OIDC provider from connector ID
|
// inferOIDCProviderType infers the specific OIDC provider from connector ID
|
||||||
func inferOIDCProviderType(connectorID string) string {
|
func inferOIDCProviderType(connectorID string) string {
|
||||||
connectorIDLower := strings.ToLower(connectorID)
|
connectorIDLower := strings.ToLower(connectorID)
|
||||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} {
|
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
|
||||||
if strings.Contains(connectorIDLower, provider) {
|
if strings.Contains(connectorIDLower, provider) {
|
||||||
return provider
|
return provider
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -472,7 +472,7 @@ start_services_and_show_instructions() {
|
|||||||
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
||||||
echo "Registering CrowdSec bouncer..."
|
echo "Registering CrowdSec bouncer..."
|
||||||
local cs_retries=0
|
local cs_retries=0
|
||||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
|
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
|
||||||
cs_retries=$((cs_retries + 1))
|
cs_retries=$((cs_retries + 1))
|
||||||
if [[ $cs_retries -ge 30 ]]; then
|
if [[ $cs_retries -ge 30 ]]; then
|
||||||
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -15,9 +16,11 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"golang.org/x/mod/semver"
|
"golang.org/x/mod/semver"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
@@ -55,6 +58,13 @@ type Controller struct {
|
|||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
|
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||||
|
|
||||||
|
holder *types.Holder
|
||||||
|
|
||||||
|
expNewNetworkMap bool
|
||||||
|
expNewNetworkMapAIDs map[string]struct{}
|
||||||
|
|
||||||
|
compactedNetworkMap bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type bufferUpdate struct {
|
type bufferUpdate struct {
|
||||||
@@ -71,6 +81,29 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
||||||
|
newNetworkMapBuilder = false
|
||||||
|
}
|
||||||
|
|
||||||
|
compactedNetworkMap := true
|
||||||
|
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
|
||||||
|
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
|
||||||
|
if err != nil && len(compactedEnv) > 0 {
|
||||||
|
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
|
||||||
|
}
|
||||||
|
if err == nil && !parsedCompactedNmap {
|
||||||
|
log.WithContext(ctx).Info("disabling compacted mode")
|
||||||
|
compactedNetworkMap = false
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||||
|
expIDs := make(map[string]struct{}, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
expIDs[id] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
return &Controller{
|
return &Controller{
|
||||||
repo: newRepository(store),
|
repo: newRepository(store),
|
||||||
metrics: nMetrics,
|
metrics: nMetrics,
|
||||||
@@ -84,6 +117,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
|
|
||||||
proxyController: proxyController,
|
proxyController: proxyController,
|
||||||
EphemeralPeersManager: ephemeralPeersManager,
|
EphemeralPeersManager: ephemeralPeersManager,
|
||||||
|
|
||||||
|
holder: types.NewHolder(),
|
||||||
|
expNewNetworkMap: newNetworkMapBuilder,
|
||||||
|
expNewNetworkMapAIDs: expIDs,
|
||||||
|
|
||||||
|
compactedNetworkMap: compactedNetworkMap,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,9 +153,17 @@ func (c *Controller) CountStreams() int {
|
|||||||
|
|
||||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
var (
|
||||||
if err != nil {
|
account *types.Account
|
||||||
return fmt.Errorf("failed to get account: %v", err)
|
err error
|
||||||
|
)
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||||
|
} else {
|
||||||
|
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get account: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
globalStart := time.Now()
|
globalStart := time.Now()
|
||||||
@@ -150,6 +197,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||||
|
}
|
||||||
|
|
||||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||||
@@ -192,7 +243,16 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case c.experimentalNetworkMap(accountID):
|
||||||
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
|
case c.compactedNetworkMap:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
default:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
}
|
||||||
|
|
||||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||||
|
|
||||||
@@ -258,6 +318,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
|||||||
// UpdatePeers updates all peers that belong to an account.
|
// UpdatePeers updates all peers that belong to an account.
|
||||||
// Should be called when changes have to be synced to peers.
|
// Should be called when changes have to be synced to peers.
|
||||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
|
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||||
|
return fmt.Errorf("recalculate network map cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -307,7 +371,16 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case c.experimentalNetworkMap(accountId):
|
||||||
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
|
case c.compactedNetworkMap:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
default:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
@@ -378,9 +451,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return peer, emptyMap, nil, 0, nil
|
return peer, emptyMap, nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
var (
|
||||||
if err != nil {
|
account *types.Account
|
||||||
return nil, nil, nil, 0, err
|
err error
|
||||||
|
)
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||||
|
} else {
|
||||||
|
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account.InjectProxyPolicies(ctx)
|
account.InjectProxyPolicies(ctx)
|
||||||
@@ -412,10 +493,20 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
var networkMap *types.NetworkMap
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
if c.experimentalNetworkMap(accountID) {
|
||||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
|
} else {
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
if c.compactedNetworkMap {
|
||||||
|
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
} else {
|
||||||
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
@@ -427,6 +518,108 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||||
|
c.enrichAccountFromHolder(account)
|
||||||
|
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) getPeerNetworkMapExp(
|
||||||
|
ctx context.Context,
|
||||||
|
accountId string,
|
||||||
|
peerId string,
|
||||||
|
validatedPeers map[string]struct{},
|
||||||
|
peersCustomZone nbdns.CustomZone,
|
||||||
|
accountZones []*zones.Zone,
|
||||||
|
metrics *telemetry.AccountManagerMetrics,
|
||||||
|
) *types.NetworkMap {
|
||||||
|
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||||
|
if account == nil {
|
||||||
|
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||||
|
return &types.NetworkMap{
|
||||||
|
Network: &types.Network{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||||
|
c.enrichAccountFromHolder(account)
|
||||||
|
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||||
|
c.enrichAccountFromHolder(account)
|
||||||
|
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||||
|
account := c.getAccountFromHolder(accountId)
|
||||||
|
if account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account.UpdatePeerInNetworkMapCache(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||||
|
account.RecalculateNetworkMapCache(validatedPeers)
|
||||||
|
c.updateAccountInHolder(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||||
|
if c.experimentalNetworkMap(accountId) {
|
||||||
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.recalculateNetworkMapCache(account, validatedPeers)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
||||||
|
_, ok := c.expNewNetworkMapAIDs[accountId]
|
||||||
|
return c.expNewNetworkMap || ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
||||||
|
a := c.holder.GetAccount(account.Id)
|
||||||
|
if a == nil {
|
||||||
|
c.holder.AddAccount(account)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account.NetworkMapCache = a.NetworkMapCache
|
||||||
|
if account.NetworkMapCache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.holder.AddAccount(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
||||||
|
return c.holder.GetAccount(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
|
||||||
|
a := c.holder.GetAccount(accountID)
|
||||||
|
if a != nil {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
||||||
|
c.holder.AddAccount(account)
|
||||||
|
}
|
||||||
|
|
||||||
// GetDNSDomain returns the configured dnsDomain
|
// GetDNSDomain returns the configured dnsDomain
|
||||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
||||||
if settings == nil {
|
if settings == nil {
|
||||||
@@ -563,7 +756,16 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
|
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get peers by ids: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||||
}
|
}
|
||||||
@@ -573,6 +775,14 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
|
|||||||
|
|
||||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
|
||||||
|
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
|
||||||
|
}
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -607,6 +817,19 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
|||||||
MessageType: network_map.MessageTypeNetworkMap,
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
})
|
})
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
@@ -649,11 +872,21 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.InjectProxyPolicies(ctx)
|
var networkMap *types.NetworkMap
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
} else {
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
if c.compactedNetworkMap {
|
||||||
|
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||||
|
} else {
|
||||||
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||||
|
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||||
|
|
||||||
DnsForwarderPort = nbdns.ForwarderServerPort
|
DnsForwarderPort = nbdns.ForwarderServerPort
|
||||||
OldForwarderPort = nbdns.ForwarderClientPort
|
OldForwarderPort = nbdns.ForwarderClientPort
|
||||||
DnsForwarderPortMinVersion = "v0.59.0"
|
DnsForwarderPortMinVersion = "v0.59.0"
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
|
||||||
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
@@ -82,24 +83,26 @@ type CapabilityProvider interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
accountManager account.Manager
|
accountManager account.Manager
|
||||||
permissionsManager permissions.Manager
|
permissionsManager permissions.Manager
|
||||||
proxyController proxy.Controller
|
proxyController proxy.Controller
|
||||||
capabilities CapabilityProvider
|
networkMapController network_map.Controller
|
||||||
clusterDeriver ClusterDeriver
|
capabilities CapabilityProvider
|
||||||
exposeReaper *exposeReaper
|
clusterDeriver ClusterDeriver
|
||||||
|
exposeReaper *exposeReaper
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new service manager.
|
// NewManager creates a new service manager.
|
||||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
|
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver, networkMapController network_map.Controller) *Manager {
|
||||||
mgr := &Manager{
|
mgr := &Manager{
|
||||||
store: store,
|
store: store,
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
proxyController: proxyController,
|
proxyController: proxyController,
|
||||||
capabilities: capabilities,
|
networkMapController: networkMapController,
|
||||||
clusterDeriver: clusterDeriver,
|
capabilities: capabilities,
|
||||||
|
clusterDeriver: clusterDeriver,
|
||||||
}
|
}
|
||||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||||
return mgr
|
return mgr
|
||||||
@@ -151,13 +154,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
|||||||
for _, target := range s.Targets {
|
for _, target := range s.Targets {
|
||||||
switch target.TargetType {
|
switch target.TargetType {
|
||||||
case service.TargetTypePeer:
|
case service.TargetTypePeer:
|
||||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
target.Host = m.getPeerTargetHost(ctx, accountID, target)
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
|
|
||||||
target.Host = unknownHostPlaceholder
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target.Host = peer.IP.String()
|
|
||||||
case service.TargetTypeHost:
|
case service.TargetTypeHost:
|
||||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -184,6 +181,26 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) getPeerTargetHost(ctx context.Context, accountID string, target *service.Target) string {
|
||||||
|
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, target.ServiceID, err)
|
||||||
|
return unknownHostPlaceholder
|
||||||
|
}
|
||||||
|
|
||||||
|
if target.Protocol == "https" {
|
||||||
|
settings, err := m.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get account settings for service %s: %v", target.ServiceID, err)
|
||||||
|
return unknownHostPlaceholder
|
||||||
|
}
|
||||||
|
dnsDomain := m.networkMapController.GetDNSDomain(settings)
|
||||||
|
return peer.FQDN(dnsDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer.IP.String()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ import (
|
|||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
@@ -110,7 +109,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -118,15 +117,6 @@ func (s *BaseServer) APIHandler() http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
|
||||||
return Create(s, func() *middleware.APIRateLimiter {
|
|
||||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
|
||||||
limiter := middleware.NewAPIRateLimiter(cfg)
|
|
||||||
limiter.SetEnabled(enabled)
|
|
||||||
return limiter
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||||
return Create(s, func() *grpc.Server {
|
return Create(s, func() *grpc.Server {
|
||||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) ServiceManager() service.Manager {
|
func (s *BaseServer) ServiceManager() service.Manager {
|
||||||
return Create(s, func() service.Manager {
|
return Create(s, func() service.Manager {
|
||||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
|
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager(), s.NetworkMapController())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -408,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||||
networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||||
}
|
}
|
||||||
@@ -1171,6 +1171,11 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
|||||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||||
|
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||||
|
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||||
}
|
}
|
||||||
@@ -1226,6 +1231,11 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||||
|
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||||
|
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||||
}
|
}
|
||||||
@@ -1264,6 +1274,11 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||||
|
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||||
|
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||||
}
|
}
|
||||||
@@ -1317,6 +1332,11 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||||
|
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||||
|
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||||
}
|
}
|
||||||
@@ -1377,6 +1397,11 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||||
|
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||||
|
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||||
}
|
}
|
||||||
@@ -1608,6 +1633,75 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
|||||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||||
|
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
account := &types.Account{
|
||||||
|
Peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
||||||
|
},
|
||||||
|
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
||||||
|
Routes: map[route.ID]*route.Route{
|
||||||
|
"route-1": {
|
||||||
|
ID: "route-1",
|
||||||
|
Network: prefix,
|
||||||
|
NetID: "network-1",
|
||||||
|
Description: "network-1",
|
||||||
|
Peer: "peer-1",
|
||||||
|
NetworkType: 0,
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
"route-2": {
|
||||||
|
ID: "route-2",
|
||||||
|
Network: prefix2,
|
||||||
|
NetID: "network-2",
|
||||||
|
Description: "network-2",
|
||||||
|
Peer: "peer-2",
|
||||||
|
NetworkType: 0,
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
"route-3": {
|
||||||
|
ID: "route-3",
|
||||||
|
Network: prefix,
|
||||||
|
NetID: "network-1",
|
||||||
|
Description: "network-1",
|
||||||
|
Peer: "peer-2",
|
||||||
|
NetworkType: 0,
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
||||||
|
|
||||||
|
assert.Len(t, routes, 2)
|
||||||
|
routeIDs := make(map[route.ID]struct{}, 2)
|
||||||
|
for _, r := range routes {
|
||||||
|
routeIDs[r.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||||
|
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||||
|
|
||||||
|
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
||||||
|
|
||||||
|
assert.Len(t, emptyRoutes, 0)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_Copy(t *testing.T) {
|
func TestAccount_Copy(t *testing.T) {
|
||||||
account := &types.Account{
|
account := &types.Account{
|
||||||
Id: "account1",
|
Id: "account1",
|
||||||
@@ -1730,7 +1824,9 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
AccountID: "account1",
|
AccountID: "account1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||||
}
|
}
|
||||||
|
account.InitOnce()
|
||||||
err := hasNilField(account)
|
err := hasNilField(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -2215,29 +2311,6 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
|
|
||||||
t.Cleanup(cleanUp)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
|
||||||
|
|
||||||
// Verify the already-expired peer is excluded at the store level
|
|
||||||
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
for _, peer := range peers {
|
|
||||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
|
|
||||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only the non-expired peer with expiration enabled should be returned
|
|
||||||
require.Len(t, peers, 1)
|
|
||||||
assert.Equal(t, "notexpired01", peers[0].ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccount_GetInactivePeers(t *testing.T) {
|
func TestAccount_GetInactivePeers(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
@@ -3157,13 +3230,6 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
|||||||
return manager, updateManager, account, peer1, peer2, peer3
|
return manager, updateManager, account, peer1, peer2, peer3
|
||||||
}
|
}
|
||||||
|
|
||||||
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
|
|
||||||
// wrappers wait for an expected update message. Sized for slow CI runners
|
|
||||||
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
|
|
||||||
// seconds. Only runs down on failure; passing tests return immediately
|
|
||||||
// when the channel delivers.
|
|
||||||
const peerUpdateTimeout = 5 * time.Second
|
|
||||||
|
|
||||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
select {
|
select {
|
||||||
@@ -3182,7 +3248,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
|
|||||||
if msg == nil {
|
if msg == nil {
|
||||||
t.Errorf("Received nil update message, expected valid message")
|
t.Errorf("Received nil update message, expected valid message")
|
||||||
}
|
}
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(500 * time.Millisecond):
|
||||||
t.Error("Timed out waiting for update message")
|
t.Error("Timed out waiting for update message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
@@ -63,11 +66,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
apiPrefix = "/api"
|
apiPrefix = "/api"
|
||||||
|
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
|
||||||
|
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
|
||||||
|
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||||
|
|
||||||
// Register bypass paths for unauthenticated endpoints
|
// Register bypass paths for unauthenticated endpoints
|
||||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||||
@@ -88,10 +94,34 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rateLimiter == nil {
|
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||||
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
|
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||||
rateLimiter = middleware.NewAPIRateLimiter(nil)
|
rpm := 6
|
||||||
rateLimiter.SetEnabled(false)
|
if v := os.Getenv(rateLimitingRPMKey); v != "" {
|
||||||
|
value, err := strconv.Atoi(v)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
|
||||||
|
} else {
|
||||||
|
rpm = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
burst := 500
|
||||||
|
if v := os.Getenv(rateLimitingBurstKey); v != "" {
|
||||||
|
value, err := strconv.Atoi(v)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
|
||||||
|
} else {
|
||||||
|
burst = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rateLimitingConfig = &middleware.RateLimiterConfig{
|
||||||
|
RequestsPerMinute: float64(rpm),
|
||||||
|
Burst: burst,
|
||||||
|
CleanupInterval: 6 * time.Hour,
|
||||||
|
LimiterTTL: 24 * time.Hour,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
authMiddleware := middleware.NewAuthMiddleware(
|
authMiddleware := middleware.NewAuthMiddleware(
|
||||||
@@ -99,7 +129,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
accountManager.GetAccountIDFromUserAuth,
|
accountManager.GetAccountIDFromUserAuth,
|
||||||
accountManager.SyncUserJWTGroups,
|
accountManager.SyncUserJWTGroups,
|
||||||
accountManager.GetUserFromUserAuth,
|
accountManager.GetUserFromUserAuth,
|
||||||
rateLimiter,
|
rateLimitingConfig,
|
||||||
appMetrics.GetMeter(),
|
appMetrics.GetMeter(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||||
|
|
||||||
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
@@ -43,9 +42,14 @@ func NewAuthMiddleware(
|
|||||||
ensureAccount EnsureAccountFunc,
|
ensureAccount EnsureAccountFunc,
|
||||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||||
rateLimiter *APIRateLimiter,
|
rateLimiterConfig *RateLimiterConfig,
|
||||||
meter metric.Meter,
|
meter metric.Meter,
|
||||||
) *AuthMiddleware {
|
) *AuthMiddleware {
|
||||||
|
var rateLimiter *APIRateLimiter
|
||||||
|
if rateLimiterConfig != nil {
|
||||||
|
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||||
|
}
|
||||||
|
|
||||||
var patUsageTracker *PATUsageTracker
|
var patUsageTracker *PATUsageTracker
|
||||||
if meter != nil {
|
if meter != nil {
|
||||||
var err error
|
var err error
|
||||||
@@ -83,14 +87,17 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
|
|
||||||
switch authType {
|
switch authType {
|
||||||
case "bearer":
|
case "bearer":
|
||||||
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||||
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.ServeHTTP(w, r)
|
|
||||||
|
h.ServeHTTP(w, request)
|
||||||
case "token":
|
case "token":
|
||||||
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
request, err := m.checkPATFromRequest(r, authHeader)
|
||||||
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||||
// Check if it's a status error, otherwise default to Unauthorized
|
// Check if it's a status error, otherwise default to Unauthorized
|
||||||
if _, ok := status.FromError(err); !ok {
|
if _, ok := status.FromError(err); !ok {
|
||||||
@@ -99,7 +106,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, request)
|
||||||
default:
|
default:
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||||
return
|
return
|
||||||
@@ -108,19 +115,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckJWTFromRequest checks if the JWT is valid
|
// CheckJWTFromRequest checks if the JWT is valid
|
||||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||||
|
|
||||||
// If an error occurs, call the error handler and return an error
|
// If an error occurs, call the error handler and return an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error extracting token: %w", err)
|
return r, fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return r, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||||
@@ -136,7 +143,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return r, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if userAuth.AccountId != accountId {
|
if userAuth.AccountId != accountId {
|
||||||
@@ -146,7 +153,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
|
|
||||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return r, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||||
@@ -157,41 +164,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||||
return err
|
return r, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// propagates ctx change to upstream middleware
|
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckPATFromRequest checks if the PAT is valid
|
// CheckPATFromRequest checks if the PAT is valid
|
||||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error extracting token: %w", err)
|
return r, fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.patUsageTracker != nil {
|
if m.patUsageTracker != nil {
|
||||||
m.patUsageTracker.IncrementUsage(token)
|
m.patUsageTracker.IncrementUsage(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
|
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||||
return status.Errorf(status.TooManyRequests, "too many requests")
|
if !m.rateLimiter.Allow(token) {
|
||||||
|
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid Token: %w", err)
|
return r, fmt.Errorf("invalid Token: %w", err)
|
||||||
}
|
}
|
||||||
if time.Now().After(pat.GetExpirationDate()) {
|
if time.Now().After(pat.GetExpirationDate()) {
|
||||||
return fmt.Errorf("token expired")
|
return r, fmt.Errorf("token expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return r, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userAuth := auth.UserAuth{
|
userAuth := auth.UserAuth{
|
||||||
@@ -209,9 +216,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// propagates ctx change to upstream middleware
|
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isTerraformRequest(r *http.Request) bool {
|
func isTerraformRequest(r *http.Request) bool {
|
||||||
|
|||||||
@@ -196,8 +196,6 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||||
}
|
}
|
||||||
|
|
||||||
disabledLimiter := NewAPIRateLimiter(nil)
|
|
||||||
disabledLimiter.SetEnabled(false)
|
|
||||||
authMiddleware := NewAuthMiddleware(
|
authMiddleware := NewAuthMiddleware(
|
||||||
mockAuth,
|
mockAuth,
|
||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
@@ -209,7 +207,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
disabledLimiter,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -268,7 +266,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
rateLimitConfig,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -320,7 +318,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
rateLimitConfig,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -363,7 +361,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
rateLimitConfig,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -407,7 +405,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
rateLimitConfig,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -471,7 +469,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
rateLimitConfig,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -530,7 +528,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
rateLimitConfig,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -585,7 +583,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
rateLimitConfig,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -672,8 +670,6 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||||
}
|
}
|
||||||
|
|
||||||
disabledLimiter := NewAPIRateLimiter(nil)
|
|
||||||
disabledLimiter.SetEnabled(false)
|
|
||||||
authMiddleware := NewAuthMiddleware(
|
authMiddleware := NewAuthMiddleware(
|
||||||
mockAuth,
|
mockAuth,
|
||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
@@ -685,7 +681,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
disabledLimiter,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,27 +4,14 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
|
|
||||||
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
|
|
||||||
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
|
|
||||||
|
|
||||||
defaultAPIRPM = 6
|
|
||||||
defaultAPIBurst = 500
|
|
||||||
)
|
|
||||||
|
|
||||||
// RateLimiterConfig holds configuration for the API rate limiter
|
// RateLimiterConfig holds configuration for the API rate limiter
|
||||||
type RateLimiterConfig struct {
|
type RateLimiterConfig struct {
|
||||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||||
@@ -47,43 +34,6 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
|
|
||||||
rpm := defaultAPIRPM
|
|
||||||
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
|
|
||||||
value, err := strconv.Atoi(v)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
|
|
||||||
} else {
|
|
||||||
rpm = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if rpm <= 0 {
|
|
||||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
|
|
||||||
rpm = defaultAPIRPM
|
|
||||||
}
|
|
||||||
|
|
||||||
burst := defaultAPIBurst
|
|
||||||
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
|
|
||||||
value, err := strconv.Atoi(v)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
|
|
||||||
} else {
|
|
||||||
burst = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if burst <= 0 {
|
|
||||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
|
|
||||||
burst = defaultAPIBurst
|
|
||||||
}
|
|
||||||
|
|
||||||
return &RateLimiterConfig{
|
|
||||||
RequestsPerMinute: float64(rpm),
|
|
||||||
Burst: burst,
|
|
||||||
CleanupInterval: 6 * time.Hour,
|
|
||||||
LimiterTTL: 24 * time.Hour,
|
|
||||||
}, os.Getenv(RateLimitingEnabledEnv) == "true"
|
|
||||||
}
|
|
||||||
|
|
||||||
// limiterEntry holds a rate limiter and its last access time
|
// limiterEntry holds a rate limiter and its last access time
|
||||||
type limiterEntry struct {
|
type limiterEntry struct {
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
@@ -96,7 +46,6 @@ type APIRateLimiter struct {
|
|||||||
limiters map[string]*limiterEntry
|
limiters map[string]*limiterEntry
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
stopChan chan struct{}
|
stopChan chan struct{}
|
||||||
enabled atomic.Bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||||
@@ -110,53 +59,14 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
|||||||
limiters: make(map[string]*limiterEntry),
|
limiters: make(map[string]*limiterEntry),
|
||||||
stopChan: make(chan struct{}),
|
stopChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
rl.enabled.Store(true)
|
|
||||||
|
|
||||||
go rl.cleanupLoop()
|
go rl.cleanupLoop()
|
||||||
|
|
||||||
return rl
|
return rl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
|
|
||||||
rl.enabled.Store(enabled)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rl *APIRateLimiter) Enabled() bool {
|
|
||||||
return rl.enabled.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
|
|
||||||
if config == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
|
|
||||||
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
|
|
||||||
newBurst := config.Burst
|
|
||||||
|
|
||||||
rl.mu.Lock()
|
|
||||||
rl.config.RequestsPerMinute = config.RequestsPerMinute
|
|
||||||
rl.config.Burst = newBurst
|
|
||||||
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
|
|
||||||
for _, entry := range rl.limiters {
|
|
||||||
snapshot = append(snapshot, entry.limiter)
|
|
||||||
}
|
|
||||||
rl.mu.Unlock()
|
|
||||||
|
|
||||||
for _, l := range snapshot {
|
|
||||||
l.SetLimit(newRPS)
|
|
||||||
l.SetBurst(newBurst)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allow checks if a request for the given key (token) is allowed
|
// Allow checks if a request for the given key (token) is allowed
|
||||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||||
if !rl.enabled.Load() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
limiter := rl.getLimiter(key)
|
limiter := rl.getLimiter(key)
|
||||||
return limiter.Allow()
|
return limiter.Allow()
|
||||||
}
|
}
|
||||||
@@ -164,9 +74,6 @@ func (rl *APIRateLimiter) Allow(key string) bool {
|
|||||||
// Wait blocks until the rate limiter allows another request for the given key
|
// Wait blocks until the rate limiter allows another request for the given key
|
||||||
// Returns an error if the context is canceled
|
// Returns an error if the context is canceled
|
||||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||||
if !rl.enabled.Load() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
limiter := rl.getLimiter(key)
|
limiter := rl.getLimiter(key)
|
||||||
return limiter.Wait(ctx)
|
return limiter.Wait(ctx)
|
||||||
}
|
}
|
||||||
@@ -246,10 +153,6 @@ func (rl *APIRateLimiter) Reset(key string) {
|
|||||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !rl.enabled.Load() {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
clientIP := getClientIP(r)
|
clientIP := getClientIP(r)
|
||||||
if !rl.Allow(clientIP) {
|
if !rl.Allow(clientIP) {
|
||||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -158,172 +156,3 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
|
|||||||
// Should be allowed again
|
// Should be allowed again
|
||||||
assert.True(t, rl.Allow("test-key"))
|
assert.True(t, rl.Allow("test-key"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
|
|
||||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
||||||
RequestsPerMinute: 60,
|
|
||||||
Burst: 1,
|
|
||||||
CleanupInterval: time.Minute,
|
|
||||||
LimiterTTL: time.Minute,
|
|
||||||
})
|
|
||||||
defer rl.Stop()
|
|
||||||
|
|
||||||
assert.True(t, rl.Allow("key"))
|
|
||||||
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
|
|
||||||
|
|
||||||
rl.SetEnabled(false)
|
|
||||||
assert.False(t, rl.Enabled())
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
|
|
||||||
}
|
|
||||||
|
|
||||||
rl.SetEnabled(true)
|
|
||||||
assert.True(t, rl.Enabled())
|
|
||||||
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
|
|
||||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
||||||
RequestsPerMinute: 60,
|
|
||||||
Burst: 2,
|
|
||||||
CleanupInterval: time.Minute,
|
|
||||||
LimiterTTL: time.Minute,
|
|
||||||
})
|
|
||||||
defer rl.Stop()
|
|
||||||
|
|
||||||
assert.True(t, rl.Allow("k1"))
|
|
||||||
assert.True(t, rl.Allow("k1"))
|
|
||||||
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
|
|
||||||
|
|
||||||
rl.UpdateConfig(&RateLimiterConfig{
|
|
||||||
RequestsPerMinute: 60,
|
|
||||||
Burst: 10,
|
|
||||||
CleanupInterval: time.Minute,
|
|
||||||
LimiterTTL: time.Minute,
|
|
||||||
})
|
|
||||||
|
|
||||||
// New burst applies to existing keys in place; bucket refills up to new burst over time,
|
|
||||||
// but importantly newly-added keys use the updated config immediately.
|
|
||||||
assert.True(t, rl.Allow("k2"))
|
|
||||||
for i := 0; i < 9; i++ {
|
|
||||||
assert.True(t, rl.Allow("k2"))
|
|
||||||
}
|
|
||||||
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
|
|
||||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
||||||
RequestsPerMinute: 60,
|
|
||||||
Burst: 1,
|
|
||||||
CleanupInterval: time.Minute,
|
|
||||||
LimiterTTL: time.Minute,
|
|
||||||
})
|
|
||||||
defer rl.Stop()
|
|
||||||
|
|
||||||
rl.UpdateConfig(nil) // must not panic or zero the config
|
|
||||||
|
|
||||||
assert.True(t, rl.Allow("k"))
|
|
||||||
assert.False(t, rl.Allow("k"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
|
|
||||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
||||||
RequestsPerMinute: 60,
|
|
||||||
Burst: 1,
|
|
||||||
CleanupInterval: time.Minute,
|
|
||||||
LimiterTTL: time.Minute,
|
|
||||||
})
|
|
||||||
defer rl.Stop()
|
|
||||||
|
|
||||||
assert.True(t, rl.Allow("k"))
|
|
||||||
assert.False(t, rl.Allow("k"))
|
|
||||||
|
|
||||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
|
||||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
|
||||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
|
||||||
|
|
||||||
rl.Reset("k")
|
|
||||||
assert.True(t, rl.Allow("k"))
|
|
||||||
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
|
|
||||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
||||||
RequestsPerMinute: 600,
|
|
||||||
Burst: 10,
|
|
||||||
CleanupInterval: time.Minute,
|
|
||||||
LimiterTTL: time.Minute,
|
|
||||||
})
|
|
||||||
defer rl.Stop()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
stop := make(chan struct{})
|
|
||||||
|
|
||||||
for i := 0; i < 8; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(id int) {
|
|
||||||
defer wg.Done()
|
|
||||||
key := fmt.Sprintf("k%d", id)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stop:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
rl.Allow(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for i := 0; i < 200; i++ {
|
|
||||||
select {
|
|
||||||
case <-stop:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
rl.UpdateConfig(&RateLimiterConfig{
|
|
||||||
RequestsPerMinute: float64(30 + (i % 90)),
|
|
||||||
Burst: 1 + (i % 20),
|
|
||||||
CleanupInterval: time.Minute,
|
|
||||||
LimiterTTL: time.Minute,
|
|
||||||
})
|
|
||||||
rl.SetEnabled(i%2 == 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
close(stop)
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiterConfigFromEnv(t *testing.T) {
|
|
||||||
t.Setenv(RateLimitingEnabledEnv, "true")
|
|
||||||
t.Setenv(RateLimitingRPMEnv, "42")
|
|
||||||
t.Setenv(RateLimitingBurstEnv, "7")
|
|
||||||
|
|
||||||
cfg, enabled := RateLimiterConfigFromEnv()
|
|
||||||
assert.True(t, enabled)
|
|
||||||
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
|
|
||||||
assert.Equal(t, 7, cfg.Burst)
|
|
||||||
|
|
||||||
t.Setenv(RateLimitingEnabledEnv, "false")
|
|
||||||
_, enabled = RateLimiterConfigFromEnv()
|
|
||||||
assert.False(t, enabled)
|
|
||||||
|
|
||||||
t.Setenv(RateLimitingEnabledEnv, "")
|
|
||||||
t.Setenv(RateLimitingRPMEnv, "")
|
|
||||||
t.Setenv(RateLimitingBurstEnv, "")
|
|
||||||
cfg, enabled = RateLimiterConfigFromEnv()
|
|
||||||
assert.False(t, enabled)
|
|
||||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
|
|
||||||
assert.Equal(t, defaultAPIBurst, cfg.Burst)
|
|
||||||
|
|
||||||
t.Setenv(RateLimitingRPMEnv, "0")
|
|
||||||
t.Setenv(RateLimitingBurstEnv, "-5")
|
|
||||||
cfg, _ = RateLimiterConfigFromEnv()
|
|
||||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
|
|
||||||
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create proxy controller: %v", err)
|
t.Fatalf("Failed to create proxy controller: %v", err)
|
||||||
}
|
}
|
||||||
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager)
|
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager, networkMapController)
|
||||||
proxyServiceServer.SetServiceManager(serviceManager)
|
proxyServiceServer.SetServiceManager(serviceManager)
|
||||||
am.SetServiceManager(serviceManager)
|
am.SetServiceManager(serviceManager)
|
||||||
|
|
||||||
@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -244,7 +244,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create proxy controller: %v", err)
|
t.Fatalf("Failed to create proxy controller: %v", err)
|
||||||
}
|
}
|
||||||
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager)
|
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager, networkMapController)
|
||||||
proxyServiceServer.SetServiceManager(serviceManager)
|
proxyServiceServer.SetServiceManager(serviceManager)
|
||||||
am.SetServiceManager(serviceManager)
|
am.SetServiceManager(serviceManager)
|
||||||
|
|
||||||
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -274,7 +274,7 @@ func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateIdentityProviderID generates a unique ID for an identity provider.
|
// generateIdentityProviderID generates a unique ID for an identity provider.
|
||||||
// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft, adfs),
|
// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft),
|
||||||
// the ID is prefixed with the type name. Generic OIDC providers get no prefix.
|
// the ID is prefixed with the type name. Generic OIDC providers get no prefix.
|
||||||
func generateIdentityProviderID(idpType types.IdentityProviderType) string {
|
func generateIdentityProviderID(idpType types.IdentityProviderType) string {
|
||||||
id := xid.New().String()
|
id := xid.New().String()
|
||||||
@@ -296,8 +296,6 @@ func generateIdentityProviderID(idpType types.IdentityProviderType) string {
|
|||||||
return "authentik-" + id
|
return "authentik-" + id
|
||||||
case types.IdentityProviderTypeKeycloak:
|
case types.IdentityProviderTypeKeycloak:
|
||||||
return "keycloak-" + id
|
return "keycloak-" + id
|
||||||
case types.IdentityProviderTypeADFS:
|
|
||||||
return "adfs-" + id
|
|
||||||
default:
|
default:
|
||||||
// Generic OIDC - no prefix
|
// Generic OIDC - no prefix
|
||||||
return id
|
return id
|
||||||
|
|||||||
@@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// expired peers come separately.
|
// expired peers come separately.
|
||||||
if len(networkMap.GetOfflinePeers()) != 2 {
|
if len(networkMap.GetOfflinePeers()) != 1 {
|
||||||
t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer")
|
t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="
|
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="
|
||||||
|
|||||||
@@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1405,10 +1405,6 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
|
|||||||
|
|
||||||
var peers []*nbpeer.Peer
|
var peers []*nbpeer.Peer
|
||||||
for _, peer := range peersWithExpiry {
|
for _, peer := range peersWithExpiry {
|
||||||
if peer.Status.LoginExpired {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||||
if expired {
|
if expired {
|
||||||
peers = append(peers, peer)
|
peers = append(peers, peer)
|
||||||
|
|||||||
@@ -179,6 +179,11 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
|||||||
testGetNetworkMapGeneral(t)
|
testGetNetworkMapGeneral(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
|
||||||
|
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||||
|
testGetNetworkMapGeneral(t)
|
||||||
|
}
|
||||||
|
|
||||||
func testGetNetworkMapGeneral(t *testing.T) {
|
func testGetNetworkMapGeneral(t *testing.T) {
|
||||||
manager, _, err := createManager(t)
|
manager, _, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1011,6 +1016,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateAccountPeers_Experimental(t *testing.T) {
|
||||||
|
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||||
|
testUpdateAccountPeers(t)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateAccountPeers(t *testing.T) {
|
func TestUpdateAccountPeers(t *testing.T) {
|
||||||
testUpdateAccountPeers(t)
|
testUpdateAccountPeers(t)
|
||||||
}
|
}
|
||||||
@@ -1590,6 +1600,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_LoginPeer(t *testing.T) {
|
func Test_LoginPeer(t *testing.T) {
|
||||||
|
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
}
|
}
|
||||||
@@ -1896,7 +1907,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -1918,7 +1929,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -1983,7 +1994,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -2001,7 +2012,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -2047,7 +2058,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -2065,7 +2076,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -2102,7 +2113,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -2120,7 +2131,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(peerUpdateTimeout):
|
case <-time.After(time.Second):
|
||||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user