mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 03:06:38 +00:00
Compare commits
200 Commits
v0.65.2
...
github-iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b8e40f78d | ||
|
|
5da05ecca6 | ||
|
|
801de8c68d | ||
|
|
a822a33240 | ||
|
|
57b23c5b25 | ||
|
|
1165058fad | ||
|
|
703353d354 | ||
|
|
2fb50aef6b | ||
|
|
eb3aa96257 | ||
|
|
064ec1c832 | ||
|
|
75e408f51c | ||
|
|
5a89e6621b | ||
|
|
06dfa9d4a5 | ||
|
|
45d9ee52c0 | ||
|
|
3098f48b25 | ||
|
|
7f023ce801 | ||
|
|
e361126515 | ||
|
|
95213f7157 | ||
|
|
2e0e3a3601 | ||
|
|
8ae8f2098f | ||
|
|
a39787d679 | ||
|
|
53b04e512a | ||
|
|
633dde8d1f | ||
|
|
7e4542adde | ||
|
|
d4c61ed38b | ||
|
|
6b540d145c | ||
|
|
08f624507d | ||
|
|
95bc01e48f | ||
|
|
0d86de47df | ||
|
|
e804a705b7 | ||
|
|
46fc8c9f65 | ||
|
|
d7ad908962 | ||
|
|
c5623307cc | ||
|
|
7f666b8022 | ||
|
|
0a30b9b275 | ||
|
|
4eed459f27 | ||
|
|
13539543af | ||
|
|
7483fec048 | ||
|
|
5259e5df51 | ||
|
|
ebd78e0122 | ||
|
|
cf86b9a528 | ||
|
|
ee588e1536 | ||
|
|
2a8aacc5c9 | ||
|
|
15709bc666 | ||
|
|
789b4113fe | ||
|
|
d2cdc0efec | ||
|
|
ee343d5d77 | ||
|
|
099c493b18 | ||
|
|
c1d1229ae0 | ||
|
|
94a36cb53e | ||
|
|
c7ba931466 | ||
|
|
413d95b740 | ||
|
|
332c624c55 | ||
|
|
dc160aff36 | ||
|
|
96806bf55f | ||
|
|
d33cd4c95b | ||
|
|
e2c2f64be7 | ||
|
|
cb73b94ffb | ||
|
|
1d920d700c | ||
|
|
bb85eee40a | ||
|
|
aba5d6f0d2 | ||
|
|
0588d2dbe1 | ||
|
|
14b3b77bda | ||
|
|
6da34e483c | ||
|
|
0efef671d7 | ||
|
|
435203b13b | ||
|
|
decb5dd3af | ||
|
|
28fbf96b2a | ||
|
|
9d1a37c644 | ||
|
|
5bf2372c4d | ||
|
|
c2c6396a04 | ||
|
|
aaf813fc0c | ||
|
|
d97fe84296 | ||
|
|
81f45dab21 | ||
|
|
d670e7382a | ||
|
|
cd8c686339 | ||
|
|
f5c41e3018 | ||
|
|
2477f99d89 | ||
|
|
940f530ac2 | ||
|
|
4d3e2f8ad3 | ||
|
|
5ae986e1c4 | ||
|
|
e5914e4e8b | ||
|
|
c238f5425f | ||
|
|
3c3097ea74 | ||
|
|
405c3f4003 | ||
|
|
6553ce4cea | ||
|
|
a62d472bc4 | ||
|
|
434ac7f0f5 | ||
|
|
7bbe71c3ac | ||
|
|
04dcaadabf | ||
|
|
c522506849 | ||
|
|
0765352c99 | ||
|
|
13807f1b3d | ||
|
|
c919ea149e | ||
|
|
be6fd119d8 | ||
|
|
7abf730d77 | ||
|
|
ec96c5ecaf | ||
|
|
7e1cce4b9f | ||
|
|
7be8752a00 | ||
|
|
145d82f322 | ||
|
|
a8b9570700 | ||
|
|
6ff6d84646 | ||
|
|
9aaa05e8ea | ||
|
|
0af5a0441f | ||
|
|
0fc63ea0ba | ||
|
|
0b329f7881 | ||
|
|
5b85edb753 | ||
|
|
17cfa5fe1e | ||
|
|
2313494e0e | ||
|
|
fd9d430334 | ||
|
|
91f0d5cefd | ||
|
|
82762280ee | ||
|
|
b550a2face | ||
|
|
ab77508950 | ||
|
|
b9462f5c6b | ||
|
|
5ffaa5cdd6 | ||
|
|
a1858a9cb7 | ||
|
|
212b34f639 | ||
|
|
af8eaa23e2 | ||
|
|
f0eed50678 | ||
|
|
19d94c6158 | ||
|
|
628eb56073 | ||
|
|
a590c38d8b | ||
|
|
4e149c9222 | ||
|
|
59f5b34280 | ||
|
|
dff06d0898 | ||
|
|
80a8816b1d | ||
|
|
387e374e4b | ||
|
|
3e6baea405 | ||
|
|
fe9b844511 | ||
|
|
2e1aa497d2 | ||
|
|
529c0314f8 | ||
|
|
d86875aeac | ||
|
|
f80fe506d5 | ||
|
|
967c6f3cd3 | ||
|
|
e50e124e70 | ||
|
|
c545689448 | ||
|
|
8f389fef19 | ||
|
|
d3d6a327e0 | ||
|
|
b5489d4986 | ||
|
|
7a23c57cf8 | ||
|
|
11f891220e | ||
|
|
5585adce18 | ||
|
|
f884299823 | ||
|
|
15aa6bae1b | ||
|
|
11eb725ac8 | ||
|
|
30c02ab78c | ||
|
|
3acd86e346 | ||
|
|
5c20f13c48 | ||
|
|
e6587b071d | ||
|
|
85451ab4cd | ||
|
|
a7f3ba03eb | ||
|
|
4f0a3a77ad | ||
|
|
44655ca9b5 | ||
|
|
e601278117 | ||
|
|
8e7b016be2 | ||
|
|
9e01ea7aae | ||
|
|
cfc7ec8bb9 | ||
|
|
b3bbc0e5c6 | ||
|
|
d7c8e37ff4 | ||
|
|
05b66e73bc | ||
|
|
01ceedac89 | ||
|
|
403babd433 | ||
|
|
47133031e5 | ||
|
|
82da606886 | ||
|
|
bbe5ae2145 | ||
|
|
0b21498b39 | ||
|
|
0ca59535f1 | ||
|
|
59c77d0658 | ||
|
|
333e045099 | ||
|
|
c2c4d9d336 | ||
|
|
9a6a72e88e | ||
|
|
afe6d9fca4 | ||
|
|
ef82905526 | ||
|
|
d18747e846 | ||
|
|
f341d69314 | ||
|
|
327142837c | ||
|
|
f8c0321aee | ||
|
|
89115ff76a | ||
|
|
63c83aa8d2 | ||
|
|
37f025c966 | ||
|
|
4a54f0d670 | ||
|
|
98890a29e3 | ||
|
|
9d123ec059 | ||
|
|
5d171f181a | ||
|
|
22f878b3b7 | ||
|
|
44ef1a18dd | ||
|
|
2b98dc4e52 | ||
|
|
2a26cb4567 | ||
|
|
5ca1b64328 | ||
|
|
36752a8cbb | ||
|
|
f117fc7509 | ||
|
|
fc6b93ae59 | ||
|
|
564fa4ab04 | ||
|
|
a6db88fbd2 | ||
|
|
4b5294e596 | ||
|
|
a322dce42a | ||
|
|
d1ead2265b | ||
|
|
bbca74476e | ||
|
|
318cf59d66 |
14
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
14
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Community Support
|
||||||
|
url: https://forum.netbird.io/
|
||||||
|
about: Community support forum
|
||||||
|
- name: Cloud Support
|
||||||
|
url: https://docs.netbird.io/help/report-bug-issues
|
||||||
|
about: Contact us for support
|
||||||
|
- name: Client/Connection Troubleshooting
|
||||||
|
url: https://docs.netbird.io/help/troubleshooting-client
|
||||||
|
about: See our client troubleshooting guide for help addressing common issues
|
||||||
|
- name: Self-host Troubleshooting
|
||||||
|
url: https://docs.netbird.io/selfhosted/troubleshooting
|
||||||
|
about: See our self-host troubleshooting guide for help addressing common issues
|
||||||
26
.github/issue-resolution/prompts/issue-resolution-system.txt
vendored
Normal file
26
.github/issue-resolution/prompts/issue-resolution-system.txt
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
You are a GitHub issue resolution classifier.
|
||||||
|
|
||||||
|
Your job is to decide whether an open GitHub issue is:
|
||||||
|
- AUTO_CLOSE
|
||||||
|
- MANUAL_REVIEW
|
||||||
|
- KEEP_OPEN
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
|
||||||
|
- a merged linked PR that clearly resolves the issue, or
|
||||||
|
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
|
||||||
|
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
|
||||||
|
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
|
||||||
|
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
|
||||||
|
5. Do not invent evidence.
|
||||||
|
6. Output valid JSON only.
|
||||||
|
|
||||||
|
Maintainer-authoritative roles:
|
||||||
|
- MEMBER
|
||||||
|
- OWNER
|
||||||
|
- COLLABORATOR
|
||||||
|
|
||||||
|
Important:
|
||||||
|
- Later comments outweigh earlier ones.
|
||||||
|
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
|
||||||
|
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.
|
||||||
78
.github/issue-resolution/schemas/issue-resolution-output.json
vendored
Normal file
78
.github/issue-resolution/schemas/issue-resolution-output.json
vendored
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"decision",
|
||||||
|
"reason_code",
|
||||||
|
"confidence",
|
||||||
|
"hard_signals",
|
||||||
|
"contradictions",
|
||||||
|
"summary",
|
||||||
|
"close_comment",
|
||||||
|
"manual_review_note"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"decision": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
|
||||||
|
},
|
||||||
|
"reason_code": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"resolved_by_merged_pr",
|
||||||
|
"maintainer_confirmed_resolved",
|
||||||
|
"duplicate_confirmed",
|
||||||
|
"superseded_confirmed",
|
||||||
|
"likely_fixed_but_unconfirmed",
|
||||||
|
"still_open",
|
||||||
|
"unclear"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"confidence": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 1
|
||||||
|
},
|
||||||
|
"hard_signals": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["type", "url"],
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"merged_pr",
|
||||||
|
"maintainer_comment",
|
||||||
|
"duplicate_reference",
|
||||||
|
"superseded_reference"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"url": { "type": "string" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"contradictions": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["type", "url"],
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"reporter_still_broken",
|
||||||
|
"later_unresolved_comment",
|
||||||
|
"ambiguous_pr_link",
|
||||||
|
"other"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"url": { "type": "string" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"summary": { "type": "string" },
|
||||||
|
"close_comment": { "type": "string" },
|
||||||
|
"manual_review_note": { "type": "string" }
|
||||||
|
}
|
||||||
|
}
|
||||||
152
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
Normal file
152
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
import fs from "node:fs/promises";
|
||||||
|
|
||||||
|
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
|
||||||
|
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
|
||||||
|
|
||||||
|
const headers = {
|
||||||
|
Authorization: `Bearer ${process.env.GH_TOKEN}`,
|
||||||
|
Accept: "application/vnd.github+json",
|
||||||
|
"X-GitHub-Api-Version": "2022-11-28",
|
||||||
|
};
|
||||||
|
|
||||||
|
async function rest(url, method = "GET", body) {
|
||||||
|
const res = await fetch(url, {
|
||||||
|
method,
|
||||||
|
headers,
|
||||||
|
body: body ? JSON.stringify(body) : undefined
|
||||||
|
});
|
||||||
|
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
|
||||||
|
return res.status === 204 ? null : res.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
async function graphql(query, variables) {
|
||||||
|
const res = await fetch("https://api.github.com/graphql", {
|
||||||
|
method: "POST",
|
||||||
|
headers,
|
||||||
|
body: JSON.stringify({ query, variables })
|
||||||
|
});
|
||||||
|
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
|
||||||
|
const json = await res.json();
|
||||||
|
if (json.errors) throw new Error(JSON.stringify(json.errors));
|
||||||
|
return json.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function addLabel(owner, repo, issueNumber, labels) {
|
||||||
|
return rest(
|
||||||
|
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
|
||||||
|
"POST",
|
||||||
|
{ labels }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function addComment(owner, repo, issueNumber, body) {
|
||||||
|
return rest(
|
||||||
|
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
|
||||||
|
"POST",
|
||||||
|
{ body }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function closeIssue(owner, repo, issueNumber) {
|
||||||
|
return rest(
|
||||||
|
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
|
||||||
|
"PATCH",
|
||||||
|
{ state: "closed", state_reason: "completed" }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function getIssueNodeId(owner, repo, issueNumber) {
|
||||||
|
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
|
||||||
|
return issue.node_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function addToProject(issueNodeId) {
|
||||||
|
const mutation = `
|
||||||
|
mutation($projectId: ID!, $contentId: ID!) {
|
||||||
|
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
|
||||||
|
item { id }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`;
|
||||||
|
|
||||||
|
const data = await graphql(mutation, {
|
||||||
|
projectId: process.env.PROJECT_ID,
|
||||||
|
contentId: issueNodeId
|
||||||
|
});
|
||||||
|
|
||||||
|
return data.addProjectV2ItemById.item.id;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function setTextField(itemId, fieldId, value) {
|
||||||
|
const mutation = `
|
||||||
|
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
|
||||||
|
updateProjectV2ItemFieldValue(input: {
|
||||||
|
projectId: $projectId,
|
||||||
|
itemId: $itemId,
|
||||||
|
fieldId: $fieldId,
|
||||||
|
value: { text: $value }
|
||||||
|
}) {
|
||||||
|
projectV2Item { id }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`;
|
||||||
|
|
||||||
|
return graphql(mutation, {
|
||||||
|
projectId: process.env.PROJECT_ID,
|
||||||
|
itemId,
|
||||||
|
fieldId,
|
||||||
|
value
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const d of decisions) {
|
||||||
|
const [owner, repo] = d.repository.split("/");
|
||||||
|
|
||||||
|
if (d.final_decision === "AUTO_CLOSE") {
|
||||||
|
if (dryRun) continue;
|
||||||
|
|
||||||
|
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
|
||||||
|
await addComment(
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
d.issue_number,
|
||||||
|
d.model.close_comment ||
|
||||||
|
"This appears resolved based on linked evidence, so we’re closing it automatically. Reply if this still reproduces and we’ll reopen."
|
||||||
|
);
|
||||||
|
await closeIssue(owner, repo, d.issue_number);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (d.final_decision === "MANUAL_REVIEW") {
|
||||||
|
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
|
||||||
|
|
||||||
|
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
|
||||||
|
const itemId = await addToProject(issueNodeId);
|
||||||
|
|
||||||
|
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
|
||||||
|
await setTextField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, String(d.model.confidence));
|
||||||
|
}
|
||||||
|
if (process.env.PROJECT_REASON_FIELD_ID) {
|
||||||
|
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
|
||||||
|
}
|
||||||
|
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
|
||||||
|
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
|
||||||
|
}
|
||||||
|
if (process.env.PROJECT_LINKED_PR_FIELD_ID) {
|
||||||
|
const linked = (d.model.hard_signals || []).map(x => x.url).join(", ");
|
||||||
|
if (linked) {
|
||||||
|
await setTextField(itemId, process.env.PROJECT_LINKED_PR_FIELD_ID, linked);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (process.env.PROJECT_REPO_FIELD_ID) {
|
||||||
|
await setTextField(itemId, process.env.PROJECT_REPO_FIELD_ID, d.repository);
|
||||||
|
}
|
||||||
|
|
||||||
|
await addComment(
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
d.issue_number,
|
||||||
|
d.model.manual_review_note ||
|
||||||
|
"This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
125
.github/issue-resolution/scripts/classify-candidates.mjs
vendored
Normal file
125
.github/issue-resolution/scripts/classify-candidates.mjs
vendored
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import fs from "node:fs/promises";
|
||||||
|
|
||||||
|
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
|
||||||
|
|
||||||
|
function isMaintainerRole(role) {
|
||||||
|
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
|
||||||
|
}
|
||||||
|
|
||||||
|
function preScore(candidate) {
|
||||||
|
let score = 0;
|
||||||
|
const hardSignals = [];
|
||||||
|
const contradictions = [];
|
||||||
|
|
||||||
|
for (const t of candidate.timeline) {
|
||||||
|
const sourceIssue = t.source?.issue;
|
||||||
|
|
||||||
|
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
|
||||||
|
hardSignals.push({
|
||||||
|
type: "merged_pr",
|
||||||
|
url: sourceIssue.html_url
|
||||||
|
});
|
||||||
|
score += 40; // provisional until PR merged state is verified
|
||||||
|
}
|
||||||
|
|
||||||
|
if (["referenced", "connected"].includes(t.event)) {
|
||||||
|
score += 10;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const c of candidate.comments) {
|
||||||
|
const body = c.body.toLowerCase();
|
||||||
|
|
||||||
|
if (
|
||||||
|
isMaintainerRole(c.author_association) &&
|
||||||
|
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
|
||||||
|
) {
|
||||||
|
score += 25;
|
||||||
|
hardSignals.push({
|
||||||
|
type: "maintainer_comment",
|
||||||
|
url: c.html_url
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
|
||||||
|
score -= 50;
|
||||||
|
contradictions.push({
|
||||||
|
type: "later_unresolved_comment",
|
||||||
|
url: c.html_url
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { score, hardSignals, contradictions };
|
||||||
|
}
|
||||||
|
|
||||||
|
async function callGitHubModel(issuePacket) {
|
||||||
|
// Replace this stub with the GitHub Models inference call used by your org.
|
||||||
|
// The workflow already has models: read permission.
|
||||||
|
return {
|
||||||
|
decision: "MANUAL_REVIEW",
|
||||||
|
reason_code: "likely_fixed_but_unconfirmed",
|
||||||
|
confidence: 0.74,
|
||||||
|
hard_signals: [],
|
||||||
|
contradictions: [],
|
||||||
|
summary: "Potential resolution candidate; evidence is not strong enough to close automatically.",
|
||||||
|
close_comment: "This appears resolved, so we’re closing it automatically. Reply if this is still reproducible.",
|
||||||
|
manual_review_note: "Potential resolution candidate. Please review evidence before closing."
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function enforcePolicy(modelOut, pre) {
|
||||||
|
const approvedReasons = new Set([
|
||||||
|
"resolved_by_merged_pr",
|
||||||
|
"maintainer_confirmed_resolved",
|
||||||
|
"duplicate_confirmed",
|
||||||
|
"superseded_confirmed"
|
||||||
|
]);
|
||||||
|
|
||||||
|
const hasHardSignal =
|
||||||
|
(modelOut.hard_signals || []).some(s =>
|
||||||
|
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
|
||||||
|
) || pre.hardSignals.length > 0;
|
||||||
|
|
||||||
|
const hasContradiction =
|
||||||
|
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
|
||||||
|
|
||||||
|
if (
|
||||||
|
modelOut.decision === "AUTO_CLOSE" &&
|
||||||
|
modelOut.confidence >= 0.97 &&
|
||||||
|
approvedReasons.has(modelOut.reason_code) &&
|
||||||
|
hasHardSignal &&
|
||||||
|
!hasContradiction
|
||||||
|
) {
|
||||||
|
return "AUTO_CLOSE";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
modelOut.decision === "MANUAL_REVIEW" ||
|
||||||
|
modelOut.confidence >= 0.60 ||
|
||||||
|
pre.score >= 25
|
||||||
|
) {
|
||||||
|
return "MANUAL_REVIEW";
|
||||||
|
}
|
||||||
|
|
||||||
|
return "KEEP_OPEN";
|
||||||
|
}
|
||||||
|
|
||||||
|
const decisions = [];
|
||||||
|
for (const candidate of candidates) {
|
||||||
|
const pre = preScore(candidate);
|
||||||
|
const modelOut = await callGitHubModel(candidate);
|
||||||
|
const finalDecision = enforcePolicy(modelOut, pre);
|
||||||
|
|
||||||
|
decisions.push({
|
||||||
|
repository: candidate.repository,
|
||||||
|
issue_number: candidate.issue.number,
|
||||||
|
issue_url: candidate.issue.html_url,
|
||||||
|
title: candidate.issue.title,
|
||||||
|
pre_score: pre.score,
|
||||||
|
final_decision: finalDecision,
|
||||||
|
model: modelOut
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));
|
||||||
@@ -31,7 +31,7 @@ jobs:
|
|||||||
while IFS= read -r dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
|
||||||
if [ -n "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
# Check if any importer is NOT in management/signal/relay
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
|||||||
37
.github/workflows/golang-test-linux.yml
vendored
37
.github/workflows/golang-test-linux.yml
vendored
@@ -409,12 +409,19 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: docker login for root user
|
||||||
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
|
env:
|
||||||
|
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||||
|
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||||
|
|
||||||
- name: download mysql image
|
- name: download mysql image
|
||||||
if: matrix.store == 'mysql'
|
if: matrix.store == 'mysql'
|
||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
@@ -497,15 +504,18 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
- name: download mysql image
|
- name: docker login for root user
|
||||||
if: matrix.store == 'mysql'
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
env:
|
||||||
|
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||||
|
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
@@ -586,15 +596,18 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
- name: download mysql image
|
- name: docker login for root user
|
||||||
if: matrix.store == 'mysql'
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
env:
|
||||||
|
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||||
|
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
9
.github/workflows/golang-test-windows.yml
vendored
9
.github/workflows/golang-test-windows.yml
vendored
@@ -63,10 +63,15 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
|
- name: Generate test script
|
||||||
|
run: |
|
||||||
|
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||||
|
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||||
|
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||||
|
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd"
|
||||||
- name: test output
|
- name: test output
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
run: Get-Content test-out.txt
|
run: Get-Content test-out.txt
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
||||||
skip: go.mod,go.sum,**/proxy/web/**
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
50
.github/workflows/issue-resolution-triage.yml
vendored
Normal file
50
.github/workflows/issue-resolution-triage.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
name: issue-resolution-triage
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
dry_run:
|
||||||
|
description: "If true, do not close issues"
|
||||||
|
required: false
|
||||||
|
default: "true"
|
||||||
|
max_issues:
|
||||||
|
description: "How many issues to process"
|
||||||
|
required: false
|
||||||
|
default: "100"
|
||||||
|
schedule:
|
||||||
|
- cron: "17 2 * * *"
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
pull-requests: read
|
||||||
|
models: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
triage:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
DRY_RUN: ${{ inputs.dry_run || 'true' }}
|
||||||
|
MAX_ISSUES: ${{ inputs.max_issues || '100' }}
|
||||||
|
REPO: ${{ github.repository }}
|
||||||
|
PROJECT_ID: ${{ vars.ISSUE_REVIEW_PROJECT_ID }}
|
||||||
|
PROJECT_STATUS_FIELD_ID: ${{ vars.PROJECT_STATUS_FIELD_ID }}
|
||||||
|
PROJECT_CONFIDENCE_FIELD_ID: ${{ vars.PROJECT_CONFIDENCE_FIELD_ID }}
|
||||||
|
PROJECT_REASON_FIELD_ID: ${{ vars.PROJECT_REASON_FIELD_ID }}
|
||||||
|
PROJECT_EVIDENCE_FIELD_ID: ${{ vars.PROJECT_EVIDENCE_FIELD_ID }}
|
||||||
|
PROJECT_LINKED_PR_FIELD_ID: ${{ vars.PROJECT_LINKED_PR_FIELD_ID }}
|
||||||
|
PROJECT_REPO_FIELD_ID: ${{ vars.PROJECT_REPO_FIELD_ID }}
|
||||||
|
PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: ${{ vars.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: "20"
|
||||||
|
|
||||||
|
- run: npm ci
|
||||||
|
- run: node scripts/fetch-candidates.mjs
|
||||||
|
- run: node scripts/classify-candidates.mjs
|
||||||
|
- run: node scripts/apply-decisions.mjs
|
||||||
51
.github/workflows/pr-title-check.yml
vendored
Normal file
51
.github/workflows/pr-title-check.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
name: PR Title Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, edited, synchronize, reopened]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-title:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Validate PR title prefix
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const title = context.payload.pull_request.title;
|
||||||
|
const allowedTags = [
|
||||||
|
'management',
|
||||||
|
'client',
|
||||||
|
'signal',
|
||||||
|
'proxy',
|
||||||
|
'relay',
|
||||||
|
'misc',
|
||||||
|
'infrastructure',
|
||||||
|
'self-hosted',
|
||||||
|
'doc',
|
||||||
|
];
|
||||||
|
|
||||||
|
const pattern = /^\[([^\]]+)\]\s+.+/;
|
||||||
|
const match = title.match(pattern);
|
||||||
|
|
||||||
|
if (!match) {
|
||||||
|
core.setFailed(
|
||||||
|
`PR title must start with a tag in brackets.\n` +
|
||||||
|
`Example: [client] fix something\n` +
|
||||||
|
`Allowed tags: ${allowedTags.join(', ')}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
|
||||||
|
|
||||||
|
const invalid = tags.filter(t => !allowedTags.includes(t));
|
||||||
|
if (invalid.length > 0) {
|
||||||
|
core.setFailed(
|
||||||
|
`Invalid tag(s): ${invalid.join(', ')}\n` +
|
||||||
|
`Allowed tags: ${allowedTags.join(', ')}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`Valid PR title tags: [${tags.join(', ')}]`);
|
||||||
62
.github/workflows/proto-version-check.yml
vendored
Normal file
62
.github/workflows/proto-version-check.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
name: Proto Version Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- "**/*.pb.go"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-proto-versions:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Check for proto tool version changes
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
pull_number: context.issue.number,
|
||||||
|
per_page: 100,
|
||||||
|
});
|
||||||
|
|
||||||
|
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
||||||
|
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
||||||
|
if (missingPatch.length > 0) {
|
||||||
|
core.setFailed(
|
||||||
|
`Cannot inspect patch data for:\n` +
|
||||||
|
missingPatch.map(f => `- ${f}`).join('\n') +
|
||||||
|
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||||
|
const violations = [];
|
||||||
|
|
||||||
|
for (const file of pbFiles) {
|
||||||
|
const changed = file.patch
|
||||||
|
.split('\n')
|
||||||
|
.filter(line => versionPattern.test(line));
|
||||||
|
if (changed.length > 0) {
|
||||||
|
violations.push({
|
||||||
|
file: file.filename,
|
||||||
|
lines: changed,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (violations.length > 0) {
|
||||||
|
const details = violations.map(v =>
|
||||||
|
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
||||||
|
).join('\n\n');
|
||||||
|
|
||||||
|
core.setFailed(
|
||||||
|
`Proto version strings changed in generated files.\n` +
|
||||||
|
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
|
||||||
|
`Regenerate with the matching tool versions.\n\n` +
|
||||||
|
details
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('No proto version string changes detected');
|
||||||
90
.github/workflows/release.yml
vendored
90
.github/workflows/release.yml
vendored
@@ -9,8 +9,8 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.1"
|
SIGN_PIPE_VER: "v0.1.2"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.14.3"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|
||||||
@@ -169,6 +169,14 @@ jobs:
|
|||||||
- name: Install OS build dependencies
|
- name: Install OS build dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||||
|
|
||||||
|
- name: Decode GPG signing key
|
||||||
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||||
|
env:
|
||||||
|
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||||
|
run: |
|
||||||
|
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
|
||||||
|
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Install goversioninfo
|
- name: Install goversioninfo
|
||||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||||
- name: Generate windows syso amd64
|
- name: Generate windows syso amd64
|
||||||
@@ -186,18 +194,54 @@ jobs:
|
|||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
- name: Tag and push PR images (amd64 only)
|
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||||
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
|
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||||
|
- name: Verify RPM signatures
|
||||||
run: |
|
run: |
|
||||||
PR_TAG="pr-${{ github.event.pull_request.number }}"
|
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||||
|
dnf install -y -q rpm-sign curl >/dev/null 2>&1
|
||||||
|
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
|
||||||
|
rpm --import /tmp/rpm-pub.key
|
||||||
|
echo "=== Verifying RPM signatures ==="
|
||||||
|
for rpm_file in /dist/*amd64*.rpm; do
|
||||||
|
[ -f "$rpm_file" ] || continue
|
||||||
|
echo "--- $(basename $rpm_file) ---"
|
||||||
|
rpm -K "$rpm_file"
|
||||||
|
done
|
||||||
|
'
|
||||||
|
- name: Clean up GPG key
|
||||||
|
if: always()
|
||||||
|
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||||
|
- name: Tag and push images (amd64 only)
|
||||||
|
if: |
|
||||||
|
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
||||||
|
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
||||||
|
run: |
|
||||||
|
resolve_tags() {
|
||||||
|
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||||
|
echo "pr-${{ github.event.pull_request.number }}"
|
||||||
|
else
|
||||||
|
echo "main sha-$(git rev-parse --short HEAD)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
tag_and_push() {
|
||||||
|
local src="$1" img_name tag dst
|
||||||
|
img_name="${src%%:*}"
|
||||||
|
for tag in $(resolve_tags); do
|
||||||
|
dst="${img_name}:${tag}"
|
||||||
|
echo "Tagging ${src} -> ${dst}"
|
||||||
|
docker tag "$src" "$dst"
|
||||||
|
docker push "$dst"
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
export -f tag_and_push resolve_tags
|
||||||
|
|
||||||
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||||
grep '^ghcr.io/' | while read -r SRC; do
|
grep '^ghcr.io/' | while read -r SRC; do
|
||||||
IMG_NAME="${SRC%%:*}"
|
tag_and_push "$SRC"
|
||||||
DST="${IMG_NAME}:${PR_TAG}"
|
|
||||||
echo "Tagging ${SRC} -> ${DST}"
|
|
||||||
docker tag "$SRC" "$DST"
|
|
||||||
docker push "$DST"
|
|
||||||
done
|
done
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
@@ -265,6 +309,14 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||||
|
|
||||||
|
- name: Decode GPG signing key
|
||||||
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||||
|
env:
|
||||||
|
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||||
|
run: |
|
||||||
|
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
|
||||||
|
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Install LLVM-MinGW for ARM64 cross-compilation
|
- name: Install LLVM-MinGW for ARM64 cross-compilation
|
||||||
run: |
|
run: |
|
||||||
cd /tmp
|
cd /tmp
|
||||||
@@ -289,6 +341,24 @@ jobs:
|
|||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
|
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||||
|
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||||
|
- name: Verify RPM signatures
|
||||||
|
run: |
|
||||||
|
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||||
|
dnf install -y -q rpm-sign curl >/dev/null 2>&1
|
||||||
|
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
|
||||||
|
rpm --import /tmp/rpm-pub.key
|
||||||
|
echo "=== Verifying RPM signatures ==="
|
||||||
|
for rpm_file in /dist/*.rpm; do
|
||||||
|
[ -f "$rpm_file" ] || continue
|
||||||
|
echo "--- $(basename $rpm_file) ---"
|
||||||
|
rpm -K "$rpm_file"
|
||||||
|
done
|
||||||
|
'
|
||||||
|
- name: Clean up GPG key
|
||||||
|
if: always()
|
||||||
|
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -61,8 +61,8 @@ jobs:
|
|||||||
|
|
||||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||||
|
|
||||||
if [ ${SIZE} -gt 57671680 ]; then
|
if [ ${SIZE} -gt 58720256 ]; then
|
||||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@@ -154,6 +154,26 @@ builds:
|
|||||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-idp-migrate
|
||||||
|
dir: tools/idp-migrate
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=1
|
||||||
|
- >-
|
||||||
|
{{- if eq .Runtime.Goos "linux" }}
|
||||||
|
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||||
|
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
binary: netbird-idp-migrate
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
universal_binaries:
|
universal_binaries:
|
||||||
- id: netbird
|
- id: netbird
|
||||||
|
|
||||||
@@ -166,18 +186,22 @@ archives:
|
|||||||
- netbird-wasm
|
- netbird-wasm
|
||||||
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||||
format: binary
|
format: binary
|
||||||
|
- id: netbird-idp-migrate
|
||||||
|
builds:
|
||||||
|
- netbird-idp-migrate
|
||||||
|
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client.
|
description: Netbird client.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
id: netbird-deb
|
license: BSD-3-Clause
|
||||||
|
id: netbird_deb
|
||||||
bindir: /usr/bin
|
bindir: /usr/bin
|
||||||
builds:
|
builds:
|
||||||
- netbird
|
- netbird
|
||||||
formats:
|
formats:
|
||||||
- deb
|
- deb
|
||||||
|
|
||||||
scripts:
|
scripts:
|
||||||
postinstall: "release_files/post_install.sh"
|
postinstall: "release_files/post_install.sh"
|
||||||
preremove: "release_files/pre_remove.sh"
|
preremove: "release_files/pre_remove.sh"
|
||||||
@@ -185,16 +209,19 @@ nfpms:
|
|||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client.
|
description: Netbird client.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
id: netbird-rpm
|
license: BSD-3-Clause
|
||||||
|
id: netbird_rpm
|
||||||
bindir: /usr/bin
|
bindir: /usr/bin
|
||||||
builds:
|
builds:
|
||||||
- netbird
|
- netbird
|
||||||
formats:
|
formats:
|
||||||
- rpm
|
- rpm
|
||||||
|
|
||||||
scripts:
|
scripts:
|
||||||
postinstall: "release_files/post_install.sh"
|
postinstall: "release_files/post_install.sh"
|
||||||
preremove: "release_files/pre_remove.sh"
|
preremove: "release_files/pre_remove.sh"
|
||||||
|
rpm:
|
||||||
|
signature:
|
||||||
|
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||||
dockers:
|
dockers:
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-amd64
|
- netbirdio/netbird:{{ .Version }}-amd64
|
||||||
@@ -876,7 +903,7 @@ brews:
|
|||||||
uploads:
|
uploads:
|
||||||
- name: debian
|
- name: debian
|
||||||
ids:
|
ids:
|
||||||
- netbird-deb
|
- netbird_deb
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
@@ -884,7 +911,7 @@ uploads:
|
|||||||
|
|
||||||
- name: yum
|
- name: yum
|
||||||
ids:
|
ids:
|
||||||
- netbird-rpm
|
- netbird_rpm
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ nfpms:
|
|||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client UI.
|
description: Netbird client UI.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
id: netbird-ui-deb
|
id: netbird_ui_deb
|
||||||
package_name: netbird-ui
|
package_name: netbird-ui
|
||||||
builds:
|
builds:
|
||||||
- netbird-ui
|
- netbird-ui
|
||||||
@@ -80,7 +80,7 @@ nfpms:
|
|||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client UI.
|
description: Netbird client UI.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
id: netbird-ui-rpm
|
id: netbird_ui_rpm
|
||||||
package_name: netbird-ui
|
package_name: netbird-ui
|
||||||
builds:
|
builds:
|
||||||
- netbird-ui
|
- netbird-ui
|
||||||
@@ -95,11 +95,14 @@ nfpms:
|
|||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
|
rpm:
|
||||||
|
signature:
|
||||||
|
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||||
|
|
||||||
uploads:
|
uploads:
|
||||||
- name: debian
|
- name: debian
|
||||||
ids:
|
ids:
|
||||||
- netbird-ui-deb
|
- netbird_ui_deb
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
@@ -107,7 +110,7 @@ uploads:
|
|||||||
|
|
||||||
- name: yum
|
- name: yum
|
||||||
ids:
|
ids:
|
||||||
- netbird-ui-rpm
|
- netbird_ui_rpm
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
## Contributor License Agreement
|
## Contributor License Agreement
|
||||||
|
|
||||||
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
||||||
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
|
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
|
||||||
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
||||||
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
||||||
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
||||||
|
|||||||
2
Makefile
2
Makefile
@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
|||||||
$(GOLANGCI_LINT):
|
$(GOLANGCI_LINT):
|
||||||
@echo "Installing golangci-lint..."
|
@echo "Installing golangci-lint..."
|
||||||
@mkdir -p ./bin
|
@mkdir -p ./bin
|
||||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
||||||
|
|
||||||
# Lint only changed files (fast, for pre-push)
|
# Lint only changed files (fast, for pre-push)
|
||||||
lint: $(GOLANGCI_LINT)
|
lint: $(GOLANGCI_LINT)
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how
|
|||||||
### Community projects
|
### Community projects
|
||||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||||
|
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||||
|
|
||||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||||
|
|
||||||
FROM alpine:3.23.2
|
FROM alpine:3.23.3
|
||||||
# iproute2: busybox doesn't display ip rules properly
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
bash \
|
bash \
|
||||||
@@ -17,8 +17,7 @@ ENV \
|
|||||||
NETBIRD_BIN="/usr/local/bin/netbird" \
|
NETBIRD_BIN="/usr/local/bin/netbird" \
|
||||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
|
|
||||||
|
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,7 @@ ENV \
|
|||||||
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
||||||
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
||||||
NB_DISABLE_DNS="true" \
|
NB_DISABLE_DNS="true" \
|
||||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
|
||||||
|
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -26,6 +28,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
types "github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -68,7 +71,30 @@ type Client struct {
|
|||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
|
||||||
|
stateMu sync.RWMutex
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
|
config *profilemanager.Config
|
||||||
|
cacheDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
|
||||||
|
c.stateMu.Lock()
|
||||||
|
defer c.stateMu.Unlock()
|
||||||
|
c.config = cfg
|
||||||
|
c.cacheDir = cacheDir
|
||||||
|
c.connectClient = cc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
|
||||||
|
c.stateMu.RLock()
|
||||||
|
defer c.stateMu.RUnlock()
|
||||||
|
return c.config, c.cacheDir, c.connectClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getConnectClient() *internal.ConnectClient {
|
||||||
|
c.stateMu.RLock()
|
||||||
|
defer c.stateMu.RUnlock()
|
||||||
|
return c.connectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
cacheDir := platformFiles.CacheDir()
|
||||||
|
|
||||||
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
c.setState(cfg, cacheDir, connectClient)
|
||||||
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
cacheDir := platformFiles.CacheDir()
|
||||||
|
|
||||||
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
c.setState(cfg, cacheDir, connectClient)
|
||||||
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -173,11 +203,12 @@ func (c *Client) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) RenewTun(fd int) error {
|
func (c *Client) RenewTun(fd int) error {
|
||||||
if c.connectClient == nil {
|
cc := c.getConnectClient()
|
||||||
|
if cc == nil {
|
||||||
return fmt.Errorf("engine not running")
|
return fmt.Errorf("engine not running")
|
||||||
}
|
}
|
||||||
|
|
||||||
e := c.connectClient.Engine()
|
e := cc.Engine()
|
||||||
if e == nil {
|
if e == nil {
|
||||||
return fmt.Errorf("engine not initialized")
|
return fmt.Errorf("engine not initialized")
|
||||||
}
|
}
|
||||||
@@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error {
|
|||||||
return e.RenewTun(fd)
|
return e.RenewTun(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
|
||||||
|
// It works both with and without a running engine.
|
||||||
|
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
|
||||||
|
cfg, cacheDir, cc := c.stateSnapshot()
|
||||||
|
|
||||||
|
// If the engine hasn't been started, load config from disk
|
||||||
|
if cfg == nil {
|
||||||
|
var err error
|
||||||
|
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
|
ConfigPath: platformFiles.ConfigurationFilePath(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
cacheDir = platformFiles.CacheDir()
|
||||||
|
}
|
||||||
|
|
||||||
|
deps := debug.GeneratorDependencies{
|
||||||
|
InternalConfig: cfg,
|
||||||
|
StatusRecorder: c.recorder,
|
||||||
|
TempDir: cacheDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
if cc != nil {
|
||||||
|
resp, err := cc.GetLatestSyncResponse()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("get latest sync response: %v", err)
|
||||||
|
}
|
||||||
|
deps.SyncResponse = resp
|
||||||
|
|
||||||
|
if e := cc.Engine(); e != nil {
|
||||||
|
if cm := e.GetClientMetrics(); cm != nil {
|
||||||
|
deps.ClientMetrics = cm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bundleGenerator := debug.NewBundleGenerator(
|
||||||
|
deps,
|
||||||
|
debug.BundleConfig{
|
||||||
|
Anonymize: anonymize,
|
||||||
|
IncludeSystemInfo: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
path, err := bundleGenerator.Generate()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("debug bundle uploaded with key %s", key)
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
// SetTraceLogLevel configure the logger to trace level
|
||||||
func (c *Client) SetTraceLogLevel() {
|
func (c *Client) SetTraceLogLevel() {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
@@ -205,7 +303,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
pi := PeerInfo{
|
pi := PeerInfo{
|
||||||
p.IP,
|
p.IP,
|
||||||
p.FQDN,
|
p.FQDN,
|
||||||
p.ConnStatus.String(),
|
int(p.ConnStatus),
|
||||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
@@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Networks() *NetworkArray {
|
func (c *Client) Networks() *NetworkArray {
|
||||||
if c.connectClient == nil {
|
cc := c.getConnectClient()
|
||||||
|
if cc == nil {
|
||||||
log.Error("not connected")
|
log.Error("not connected")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := c.connectClient.Engine()
|
engine := cc.Engine()
|
||||||
if engine == nil {
|
if engine == nil {
|
||||||
log.Error("could not get engine")
|
log.Error("could not get engine")
|
||||||
return nil
|
return nil
|
||||||
@@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||||
client := c.connectClient
|
client := c.getConnectClient()
|
||||||
if client == nil {
|
if client == nil {
|
||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,20 @@
|
|||||||
|
|
||||||
package android
|
package android
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
|
||||||
|
// Connection status constants exported via gomobile.
|
||||||
|
const (
|
||||||
|
ConnStatusIdle = int(peer.StatusIdle)
|
||||||
|
ConnStatusConnecting = int(peer.StatusConnecting)
|
||||||
|
ConnStatusConnected = int(peer.StatusConnected)
|
||||||
|
)
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
type PeerInfo struct {
|
type PeerInfo struct {
|
||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus string // Todo replace to enum
|
ConnStatus int
|
||||||
Routes PeerRoutes
|
Routes PeerRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,4 +7,5 @@ package android
|
|||||||
type PlatformFiles interface {
|
type PlatformFiles interface {
|
||||||
ConfigurationFilePath() string
|
ConfigurationFilePath() string
|
||||||
StateFilePath() string
|
StateFilePath() string
|
||||||
|
CacheDir() string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -181,11 +181,12 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
if stateWasDown {
|
if stateWasDown {
|
||||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||||
}
|
} else {
|
||||||
cmd.Println("netbird up")
|
cmd.Println("netbird up")
|
||||||
time.Sleep(time.Second * 10)
|
time.Sleep(time.Second * 10)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||||
if !initialLevelTrace {
|
if !initialLevelTrace {
|
||||||
@@ -198,10 +199,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
cmd.Println("Log level set to trace.")
|
cmd.Println("Log level set to trace.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
needsRestoreUp := false
|
||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||||
}
|
} else {
|
||||||
|
needsRestoreUp = !stateWasDown
|
||||||
cmd.Println("netbird down")
|
cmd.Println("netbird down")
|
||||||
|
}
|
||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
@@ -209,13 +213,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
|
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||||
}
|
} else {
|
||||||
|
needsRestoreUp = false
|
||||||
cmd.Println("netbird up")
|
cmd.Println("netbird up")
|
||||||
|
}
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
@@ -261,19 +267,29 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if needsRestoreUp {
|
||||||
|
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||||
|
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
|
||||||
|
} else {
|
||||||
|
cmd.Println("netbird up (restored)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if stateWasDown {
|
if stateWasDown {
|
||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||||
}
|
} else {
|
||||||
cmd.Println("netbird down")
|
cmd.Println("netbird down")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !initialLevelTrace {
|
if !initialLevelTrace {
|
||||||
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
||||||
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
|
||||||
}
|
} else {
|
||||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||||
|
|
||||||
|
|||||||
287
client/cmd/expose.go
Normal file
287
client/cmd/expose.go
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/expose"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||||
|
|
||||||
|
var (
|
||||||
|
exposePin string
|
||||||
|
exposePassword string
|
||||||
|
exposeUserGroups []string
|
||||||
|
exposeDomain string
|
||||||
|
exposeNamePrefix string
|
||||||
|
exposeProtocol string
|
||||||
|
exposeExternalPort uint16
|
||||||
|
)
|
||||||
|
|
||||||
|
var exposeCmd = &cobra.Command{
|
||||||
|
Use: "expose <port>",
|
||||||
|
Short: "Expose a local port via the NetBird reverse proxy",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
Example: ` netbird expose --with-password safe-pass 8080
|
||||||
|
netbird expose --protocol tcp 5432
|
||||||
|
netbird expose --protocol tcp --with-external-port 5433 5432
|
||||||
|
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
|
||||||
|
RunE: exposeFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
|
||||||
|
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
|
||||||
|
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
|
||||||
|
func isClusterProtocol(protocol string) bool {
|
||||||
|
switch strings.ToLower(protocol) {
|
||||||
|
case "tcp", "udp", "tls":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
|
||||||
|
// where domain display doesn't apply. TLS uses SNI so it has a domain.
|
||||||
|
func isPortBasedProtocol(protocol string) bool {
|
||||||
|
switch strings.ToLower(protocol) {
|
||||||
|
case "tcp", "udp":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPort returns the port portion of a URL like "tcp://host:12345", or
|
||||||
|
// falls back to the given default formatted as a string.
|
||||||
|
func extractPort(serviceURL string, fallback uint16) string {
|
||||||
|
u := serviceURL
|
||||||
|
if idx := strings.Index(u, "://"); idx != -1 {
|
||||||
|
u = u[idx+3:]
|
||||||
|
}
|
||||||
|
if i := strings.LastIndex(u, ":"); i != -1 {
|
||||||
|
if p := u[i+1:]; p != "" {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strconv.FormatUint(uint64(fallback), 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveExternalPort returns the effective external port, defaulting to the target port.
|
||||||
|
func resolveExternalPort(targetPort uint64) uint16 {
|
||||||
|
if exposeExternalPort != 0 {
|
||||||
|
return exposeExternalPort
|
||||||
|
}
|
||||||
|
return uint16(targetPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
|
||||||
|
port, err := strconv.ParseUint(portStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid port number: %s", portStr)
|
||||||
|
}
|
||||||
|
if port == 0 || port > 65535 {
|
||||||
|
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isProtocolValid(exposeProtocol) {
|
||||||
|
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isClusterProtocol(exposeProtocol) {
|
||||||
|
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
|
||||||
|
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
|
||||||
|
}
|
||||||
|
} else if cmd.Flags().Changed("with-external-port") {
|
||||||
|
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
|
||||||
|
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flags().Changed("with-password") && exposePassword == "" {
|
||||||
|
return 0, fmt.Errorf("password cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
|
||||||
|
return 0, fmt.Errorf("user groups cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProtocolValid(exposeProtocol string) bool {
|
||||||
|
switch strings.ToLower(exposeProtocol) {
|
||||||
|
case "http", "https", "tcp", "udp", "tls":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func exposeFn(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
|
||||||
|
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||||
|
log.Errorf("failed initializing log %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Root().SilenceUsage = false
|
||||||
|
|
||||||
|
port, err := validateExposeFlags(cmd, args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Root().SilenceUsage = true
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
<-sigCh
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to daemon: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Debugf("failed to close daemon connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
|
protocol, err := toExposeProtocol(exposeProtocol)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &proto.ExposeServiceRequest{
|
||||||
|
Port: uint32(port),
|
||||||
|
Protocol: protocol,
|
||||||
|
Pin: exposePin,
|
||||||
|
Password: exposePassword,
|
||||||
|
UserGroups: exposeUserGroups,
|
||||||
|
Domain: exposeDomain,
|
||||||
|
NamePrefix: exposeNamePrefix,
|
||||||
|
}
|
||||||
|
if isClusterProtocol(exposeProtocol) {
|
||||||
|
req.ListenPort = uint32(resolveExternalPort(port))
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := client.ExposeService(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return waitForExposeEvents(cmd, ctx, stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||||
|
p, err := expose.ParseProtocolType(exposeProtocol)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid protocol: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p {
|
||||||
|
case expose.ProtocolHTTP:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||||
|
case expose.ProtocolHTTPS:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||||
|
case expose.ProtocolTCP:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_TCP, nil
|
||||||
|
case expose.ProtocolUDP:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_UDP, nil
|
||||||
|
case expose.ProtocolTLS:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_TLS, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unhandled protocol type: %d", p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||||
|
event, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected expose event: %T", event.Event)
|
||||||
|
}
|
||||||
|
printExposeReady(cmd, ready.Ready, port)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
|
||||||
|
cmd.Println("Service exposed successfully!")
|
||||||
|
cmd.Printf(" Name: %s\n", r.ServiceName)
|
||||||
|
if r.ServiceUrl != "" {
|
||||||
|
cmd.Printf(" URL: %s\n", r.ServiceUrl)
|
||||||
|
}
|
||||||
|
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
|
||||||
|
cmd.Printf(" Domain: %s\n", r.Domain)
|
||||||
|
}
|
||||||
|
cmd.Printf(" Protocol: %s\n", exposeProtocol)
|
||||||
|
cmd.Printf(" Internal: %d\n", port)
|
||||||
|
if isClusterProtocol(exposeProtocol) {
|
||||||
|
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
|
||||||
|
}
|
||||||
|
if r.PortAutoAssigned && exposeExternalPort != 0 {
|
||||||
|
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
|
||||||
|
}
|
||||||
|
cmd.Println()
|
||||||
|
cmd.Println("Press Ctrl+C to stop exposing.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
|
||||||
|
for {
|
||||||
|
_, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
cmd.Println("\nService stopped.")
|
||||||
|
//nolint:nilerr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return fmt.Errorf("connection to daemon closed unexpectedly")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("stream error: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -74,12 +75,22 @@ var (
|
|||||||
mtu uint16
|
mtu uint16
|
||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
|
networksDisabled bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
Short: "",
|
Short: "",
|
||||||
Long: "",
|
Long: "",
|
||||||
SilenceUsage: true,
|
SilenceUsage: true,
|
||||||
|
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(cmd.Root())
|
||||||
|
|
||||||
|
// Don't resolve for service commands — they create the socket, not connect to it.
|
||||||
|
if !isServiceCmd(cmd) {
|
||||||
|
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -144,6 +155,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(forwardingRulesCmd)
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
rootCmd.AddCommand(profileCmd)
|
rootCmd.AddCommand(profileCmd)
|
||||||
|
rootCmd.AddCommand(exposeCmd)
|
||||||
|
|
||||||
networksCMD.AddCommand(routesListCmd)
|
networksCMD.AddCommand(routesListCmd)
|
||||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||||
@@ -385,7 +397,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||||
@@ -398,3 +409,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
|||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isServiceCmd returns true if cmd is the "service" command or a child of it.
|
||||||
|
func isServiceCmd(cmd *cobra.Command) bool {
|
||||||
|
for c := cmd; c != nil; c = c.Parent() {
|
||||||
|
if c.Name() == "service" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -41,13 +41,16 @@ func init() {
|
|||||||
defaultServiceName = "Netbird"
|
defaultServiceName = "Netbird"
|
||||||
}
|
}
|
||||||
|
|
||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||||
|
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
||||||
|
|
||||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||||
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
||||||
|
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
|
||||||
|
`Use --service-env "" to clear all saved env vars. ` +
|
||||||
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||||
|
|
||||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
|
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
|
||||||
if err := serverInstance.Start(); err != nil {
|
if err := serverInstance.Start(); err != nil {
|
||||||
log.Fatalf("failed to start daemon: %v", err)
|
log.Fatalf("failed to start daemon: %v", err)
|
||||||
}
|
}
|
||||||
@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
|
|||||||
|
|
||||||
// Common setup for service control commands
|
// Common setup for service control commands
|
||||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
// rootCmd env vars are already applied by PersistentPreRunE.
|
||||||
SetFlagsFromEnvVars(serviceCmd)
|
SetFlagsFromEnvVars(serviceCmd)
|
||||||
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|||||||
@@ -59,6 +59,10 @@ func buildServiceArguments() []string {
|
|||||||
args = append(args, "--disable-update-settings")
|
args = append(args, "--disable-update-settings")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if networksDisabled {
|
||||||
|
args = append(args, "--disable-networks")
|
||||||
|
}
|
||||||
|
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,6 +123,10 @@ var installCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
svcConfig, err := createServiceConfigForInstall()
|
svcConfig, err := createServiceConfigForInstall()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -136,6 +144,10 @@ var installCmd = &cobra.Command{
|
|||||||
return fmt.Errorf("install service: %w", err)
|
return fmt.Errorf("install service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Println("NetBird service has been installed")
|
cmd.Println("NetBird service has been installed")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -187,6 +199,10 @@ This command will temporarily stop the service, update its configuration, and re
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
wasRunning, err := isServiceRunning()
|
wasRunning, err := isServiceRunning()
|
||||||
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
||||||
return fmt.Errorf("check service status: %w", err)
|
return fmt.Errorf("check service status: %w", err)
|
||||||
@@ -222,6 +238,10 @@ This command will temporarily stop the service, update its configuration, and re
|
|||||||
return fmt.Errorf("install service with new config: %w", err)
|
return fmt.Errorf("install service with new config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
if wasRunning {
|
if wasRunning {
|
||||||
cmd.Println("Starting NetBird service...")
|
cmd.Println("Starting NetBird service...")
|
||||||
if err := s.Start(); err != nil {
|
if err := s.Start(); err != nil {
|
||||||
|
|||||||
218
client/cmd/service_params.go
Normal file
218
client/cmd/service_params.go
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const serviceParamsFile = "service.json"
|
||||||
|
|
||||||
|
// serviceParams holds install-time service parameters that persist across
|
||||||
|
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
|
||||||
|
type serviceParams struct {
|
||||||
|
LogLevel string `json:"log_level"`
|
||||||
|
DaemonAddr string `json:"daemon_addr"`
|
||||||
|
ManagementURL string `json:"management_url,omitempty"`
|
||||||
|
ConfigPath string `json:"config_path,omitempty"`
|
||||||
|
LogFiles []string `json:"log_files,omitempty"`
|
||||||
|
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||||
|
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||||
|
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||||
|
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// serviceParamsPath returns the path to the service params file.
|
||||||
|
func serviceParamsPath() string {
|
||||||
|
return filepath.Join(configs.StateDir, serviceParamsFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadServiceParams reads saved service parameters from disk.
|
||||||
|
// Returns nil with no error if the file does not exist.
|
||||||
|
func loadServiceParams() (*serviceParams, error) {
|
||||||
|
path := serviceParamsPath()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil //nolint:nilnil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("read service params %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var params serviceParams
|
||||||
|
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse service params %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ¶ms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveServiceParams writes current service parameters to disk atomically
|
||||||
|
// with restricted permissions.
|
||||||
|
func saveServiceParams(params *serviceParams) error {
|
||||||
|
path := serviceParamsPath()
|
||||||
|
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
|
||||||
|
return fmt.Errorf("save service params: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// currentServiceParams captures the current state of all package-level
|
||||||
|
// variables into a serviceParams struct.
|
||||||
|
func currentServiceParams() *serviceParams {
|
||||||
|
params := &serviceParams{
|
||||||
|
LogLevel: logLevel,
|
||||||
|
DaemonAddr: daemonAddr,
|
||||||
|
ManagementURL: managementURL,
|
||||||
|
ConfigPath: configPath,
|
||||||
|
LogFiles: logFiles,
|
||||||
|
DisableProfiles: profilesDisabled,
|
||||||
|
DisableUpdateSettings: updateSettingsDisabled,
|
||||||
|
DisableNetworks: networksDisabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(serviceEnvVars) > 0 {
|
||||||
|
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
if err == nil {
|
||||||
|
params.ServiceEnvVars = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAndApplyServiceParams loads saved params from disk and applies them
|
||||||
|
// to any flags that were not explicitly set.
|
||||||
|
func loadAndApplyServiceParams(cmd *cobra.Command) error {
|
||||||
|
params, err := loadServiceParams()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
applyServiceParams(cmd, params)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyServiceParams merges saved parameters into package-level variables
|
||||||
|
// for any flag that was not explicitly set by the user (via CLI or env var).
|
||||||
|
// Flags that were Changed() are left untouched.
|
||||||
|
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||||
|
if params == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For fields with non-empty defaults (log-level, daemon-addr), keep the
|
||||||
|
// != "" guard so that an older service.json missing the field doesn't
|
||||||
|
// clobber the default with an empty string.
|
||||||
|
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
|
||||||
|
logLevel = params.LogLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
|
||||||
|
daemonAddr = params.DaemonAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// For optional fields where empty means "use default", always apply so
|
||||||
|
// that an explicit clear (--management-url "") persists across reinstalls.
|
||||||
|
if !rootCmd.PersistentFlags().Changed("management-url") {
|
||||||
|
managementURL = params.ManagementURL
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rootCmd.PersistentFlags().Changed("config") {
|
||||||
|
configPath = params.ConfigPath
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rootCmd.PersistentFlags().Changed("log-file") {
|
||||||
|
logFiles = params.LogFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
|
||||||
|
profilesDisabled = params.DisableProfiles
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
|
||||||
|
updateSettingsDisabled = params.DisableUpdateSettings
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
|
||||||
|
networksDisabled = params.DisableNetworks
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyServiceEnvParams merges saved service environment variables.
|
||||||
|
// If --service-env was explicitly set with values, explicit values win on key
|
||||||
|
// conflict but saved keys not in the explicit set are carried over.
|
||||||
|
// If --service-env was explicitly set to empty, all saved env vars are cleared.
|
||||||
|
// If --service-env was not set, saved env vars are used entirely.
|
||||||
|
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
||||||
|
if !cmd.Flags().Changed("service-env") {
|
||||||
|
if len(params.ServiceEnvVars) > 0 {
|
||||||
|
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||||
|
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flag was explicitly set: parse what the user provided.
|
||||||
|
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the user passed an empty value (e.g. --service-env ""), clear all
|
||||||
|
// saved env vars rather than merging.
|
||||||
|
if len(explicit) == 0 {
|
||||||
|
serviceEnvVars = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params.ServiceEnvVars) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge saved values underneath explicit ones.
|
||||||
|
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
||||||
|
maps.Copy(merged, params.ServiceEnvVars)
|
||||||
|
maps.Copy(merged, explicit) // explicit wins on conflict
|
||||||
|
serviceEnvVars = envMapToSlice(merged)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resetParamsCmd = &cobra.Command{
|
||||||
|
Use: "reset-params",
|
||||||
|
Short: "Remove saved service install parameters",
|
||||||
|
Long: "Removes the saved service.json file so the next install uses default parameters.",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
path := serviceParamsPath()
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
cmd.Println("No saved service parameters found")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("remove service params: %w", err)
|
||||||
|
}
|
||||||
|
cmd.Printf("Removed saved service parameters (%s)\n", path)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
|
||||||
|
func envMapToSlice(m map[string]string) []string {
|
||||||
|
s := make([]string, 0, len(m))
|
||||||
|
for k, v := range m {
|
||||||
|
s = append(s, k+"="+v)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
559
client/cmd/service_params_test.go
Normal file
559
client/cmd/service_params_test.go
Normal file
@@ -0,0 +1,559 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"go/ast"
|
||||||
|
"go/parser"
|
||||||
|
"go/token"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/pflag"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServiceParamsPath(t *testing.T) {
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
|
||||||
|
configs.StateDir = "/var/lib/netbird"
|
||||||
|
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
|
||||||
|
|
||||||
|
configs.StateDir = "/custom/state"
|
||||||
|
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveAndLoadServiceParams(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
configs.StateDir = tmpDir
|
||||||
|
|
||||||
|
params := &serviceParams{
|
||||||
|
LogLevel: "debug",
|
||||||
|
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||||
|
ManagementURL: "https://my.server.com",
|
||||||
|
ConfigPath: "/etc/netbird/config.json",
|
||||||
|
LogFiles: []string{"/var/log/netbird/client.log", "console"},
|
||||||
|
DisableProfiles: true,
|
||||||
|
DisableUpdateSettings: false,
|
||||||
|
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := saveServiceParams(params)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the file exists and is valid JSON.
|
||||||
|
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, json.Valid(data))
|
||||||
|
|
||||||
|
loaded, err := loadServiceParams()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, loaded)
|
||||||
|
|
||||||
|
assert.Equal(t, params.LogLevel, loaded.LogLevel)
|
||||||
|
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
|
||||||
|
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
|
||||||
|
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
|
||||||
|
assert.Equal(t, params.LogFiles, loaded.LogFiles)
|
||||||
|
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
|
||||||
|
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
|
||||||
|
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadServiceParams_FileNotExists(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
configs.StateDir = tmpDir
|
||||||
|
|
||||||
|
params, err := loadServiceParams()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
configs.StateDir = tmpDir
|
||||||
|
|
||||||
|
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
params, err := loadServiceParams()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCurrentServiceParams(t *testing.T) {
|
||||||
|
origLogLevel := logLevel
|
||||||
|
origDaemonAddr := daemonAddr
|
||||||
|
origManagementURL := managementURL
|
||||||
|
origConfigPath := configPath
|
||||||
|
origLogFiles := logFiles
|
||||||
|
origProfilesDisabled := profilesDisabled
|
||||||
|
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() {
|
||||||
|
logLevel = origLogLevel
|
||||||
|
daemonAddr = origDaemonAddr
|
||||||
|
managementURL = origManagementURL
|
||||||
|
configPath = origConfigPath
|
||||||
|
logFiles = origLogFiles
|
||||||
|
profilesDisabled = origProfilesDisabled
|
||||||
|
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||||
|
serviceEnvVars = origServiceEnvVars
|
||||||
|
})
|
||||||
|
|
||||||
|
logLevel = "trace"
|
||||||
|
daemonAddr = "tcp://127.0.0.1:9999"
|
||||||
|
managementURL = "https://mgmt.example.com"
|
||||||
|
configPath = "/tmp/test-config.json"
|
||||||
|
logFiles = []string{"/tmp/test.log"}
|
||||||
|
profilesDisabled = true
|
||||||
|
updateSettingsDisabled = true
|
||||||
|
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
|
||||||
|
|
||||||
|
params := currentServiceParams()
|
||||||
|
|
||||||
|
assert.Equal(t, "trace", params.LogLevel)
|
||||||
|
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
|
||||||
|
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
|
||||||
|
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
|
||||||
|
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
|
||||||
|
assert.True(t, params.DisableProfiles)
|
||||||
|
assert.True(t, params.DisableUpdateSettings)
|
||||||
|
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
|
||||||
|
origLogLevel := logLevel
|
||||||
|
origDaemonAddr := daemonAddr
|
||||||
|
origManagementURL := managementURL
|
||||||
|
origConfigPath := configPath
|
||||||
|
origLogFiles := logFiles
|
||||||
|
origProfilesDisabled := profilesDisabled
|
||||||
|
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() {
|
||||||
|
logLevel = origLogLevel
|
||||||
|
daemonAddr = origDaemonAddr
|
||||||
|
managementURL = origManagementURL
|
||||||
|
configPath = origConfigPath
|
||||||
|
logFiles = origLogFiles
|
||||||
|
profilesDisabled = origProfilesDisabled
|
||||||
|
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||||
|
serviceEnvVars = origServiceEnvVars
|
||||||
|
})
|
||||||
|
|
||||||
|
// Reset all flags to defaults.
|
||||||
|
logLevel = "info"
|
||||||
|
daemonAddr = "unix:///var/run/netbird.sock"
|
||||||
|
managementURL = ""
|
||||||
|
configPath = "/etc/netbird/config.json"
|
||||||
|
logFiles = []string{"/var/log/netbird/client.log"}
|
||||||
|
profilesDisabled = false
|
||||||
|
updateSettingsDisabled = false
|
||||||
|
serviceEnvVars = nil
|
||||||
|
|
||||||
|
// Reset Changed state on all relevant flags.
|
||||||
|
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate user explicitly setting --log-level via CLI.
|
||||||
|
logLevel = "warn"
|
||||||
|
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
LogLevel: "debug",
|
||||||
|
DaemonAddr: "tcp://127.0.0.1:5555",
|
||||||
|
ManagementURL: "https://saved.example.com",
|
||||||
|
ConfigPath: "/saved/config.json",
|
||||||
|
LogFiles: []string{"/saved/client.log"},
|
||||||
|
DisableProfiles: true,
|
||||||
|
DisableUpdateSettings: true,
|
||||||
|
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
applyServiceParams(cmd, saved)
|
||||||
|
|
||||||
|
// log-level was Changed, so it should keep "warn", not use saved "debug".
|
||||||
|
assert.Equal(t, "warn", logLevel)
|
||||||
|
|
||||||
|
// All other fields were not Changed, so they should use saved values.
|
||||||
|
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
|
||||||
|
assert.Equal(t, "https://saved.example.com", managementURL)
|
||||||
|
assert.Equal(t, "/saved/config.json", configPath)
|
||||||
|
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
|
||||||
|
assert.True(t, profilesDisabled)
|
||||||
|
assert.True(t, updateSettingsDisabled)
|
||||||
|
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
|
||||||
|
origProfilesDisabled := profilesDisabled
|
||||||
|
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||||
|
t.Cleanup(func() {
|
||||||
|
profilesDisabled = origProfilesDisabled
|
||||||
|
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate current state where booleans are true (e.g. set by previous install).
|
||||||
|
profilesDisabled = true
|
||||||
|
updateSettingsDisabled = true
|
||||||
|
|
||||||
|
// Reset Changed state so flags appear unset.
|
||||||
|
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saved params have both as false.
|
||||||
|
saved := &serviceParams{
|
||||||
|
DisableProfiles: false,
|
||||||
|
DisableUpdateSettings: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
applyServiceParams(cmd, saved)
|
||||||
|
|
||||||
|
assert.False(t, profilesDisabled, "saved false should override current true")
|
||||||
|
assert.False(t, updateSettingsDisabled, "saved false should override current true")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
|
||||||
|
origManagementURL := managementURL
|
||||||
|
t.Cleanup(func() { managementURL = origManagementURL })
|
||||||
|
|
||||||
|
managementURL = "https://leftover.example.com"
|
||||||
|
|
||||||
|
// Simulate saved params where management URL was explicitly cleared.
|
||||||
|
saved := &serviceParams{
|
||||||
|
LogLevel: "info",
|
||||||
|
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||||
|
// ManagementURL intentionally empty: was cleared with --management-url "".
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
applyServiceParams(cmd, saved)
|
||||||
|
|
||||||
|
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_NilParams(t *testing.T) {
|
||||||
|
origLogLevel := logLevel
|
||||||
|
t.Cleanup(func() { logLevel = origLogLevel })
|
||||||
|
|
||||||
|
logLevel = "info"
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
|
||||||
|
// Should be a no-op.
|
||||||
|
applyServiceParams(cmd, nil)
|
||||||
|
assert.Equal(t, "info", logLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
// Set up a command with --service-env marked as Changed.
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
|
||||||
|
|
||||||
|
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
ServiceEnvVars: map[string]string{
|
||||||
|
"SAVED": "val",
|
||||||
|
"OVERLAP": "saved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, saved)
|
||||||
|
|
||||||
|
// Parse result for easier assertion.
|
||||||
|
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "yes", result["EXPLICIT"])
|
||||||
|
assert.Equal(t, "val", result["SAVED"])
|
||||||
|
// Explicit wins on conflict.
|
||||||
|
assert.Equal(t, "explicit", result["OVERLAP"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
serviceEnvVars = nil
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, saved)
|
||||||
|
|
||||||
|
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
// Simulate --service-env "" which produces [""] in the slice.
|
||||||
|
serviceEnvVars = []string{""}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
require.NoError(t, cmd.Flags().Set("service-env", ""))
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, saved)
|
||||||
|
|
||||||
|
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
// Simulate --service-env "" which produces [""] in the slice.
|
||||||
|
serviceEnvVars = []string{""}
|
||||||
|
|
||||||
|
params := currentServiceParams()
|
||||||
|
|
||||||
|
// After parsing, the empty string is skipped, resulting in an empty map.
|
||||||
|
// The map should still be set (not nil) so it overwrites saved values.
|
||||||
|
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
|
||||||
|
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
||||||
|
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
||||||
|
// added to serviceParams but not wired into these functions, this test fails.
|
||||||
|
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
|
||||||
|
fset := token.NewFileSet()
|
||||||
|
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Collect all JSON field names from the serviceParams struct.
|
||||||
|
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||||
|
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
|
||||||
|
|
||||||
|
// Collect field names referenced in currentServiceParams and applyServiceParams.
|
||||||
|
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
|
||||||
|
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
|
||||||
|
// applyServiceEnvParams handles ServiceEnvVars indirectly.
|
||||||
|
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
|
||||||
|
for k, v := range applyEnvFields {
|
||||||
|
applyFields[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range structFields {
|
||||||
|
assert.Contains(t, currentFields, field,
|
||||||
|
"serviceParams field %q is not captured in currentServiceParams()", field)
|
||||||
|
assert.Contains(t, applyFields, field,
|
||||||
|
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
|
||||||
|
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
|
||||||
|
// it flows through newSVCConfig() EnvVars, not CLI args.
|
||||||
|
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
|
||||||
|
fset := token.NewFileSet()
|
||||||
|
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||||
|
require.NotEmpty(t, structFields)
|
||||||
|
|
||||||
|
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
|
||||||
|
fieldsNotInArgs := map[string]bool{
|
||||||
|
"ServiceEnvVars": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
|
||||||
|
|
||||||
|
// Forward: every struct field must appear in buildServiceArguments.
|
||||||
|
for _, field := range structFields {
|
||||||
|
if fieldsNotInArgs[field] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
globalVar := fieldToGlobalVar(field)
|
||||||
|
assert.Contains(t, buildFields, globalVar,
|
||||||
|
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverse: every service-related global used in buildServiceArguments must
|
||||||
|
// have a corresponding serviceParams field. This catches a developer adding
|
||||||
|
// a new flag to buildServiceArguments without adding it to the struct.
|
||||||
|
globalToField := make(map[string]string, len(structFields))
|
||||||
|
for _, field := range structFields {
|
||||||
|
globalToField[fieldToGlobalVar(field)] = field
|
||||||
|
}
|
||||||
|
// Identifiers in buildServiceArguments that are not service params
|
||||||
|
// (builtins, boilerplate, loop variables).
|
||||||
|
nonParamGlobals := map[string]bool{
|
||||||
|
"args": true, "append": true, "string": true, "_": true,
|
||||||
|
"logFile": true, // range variable over logFiles
|
||||||
|
}
|
||||||
|
for ref := range buildFields {
|
||||||
|
if nonParamGlobals[ref] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, inStruct := globalToField[ref]
|
||||||
|
assert.True(t, inStruct,
|
||||||
|
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractStructJSONFields returns field names from a named struct type.
|
||||||
|
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
|
||||||
|
t.Helper()
|
||||||
|
var fields []string
|
||||||
|
ast.Inspect(file, func(n ast.Node) bool {
|
||||||
|
ts, ok := n.(*ast.TypeSpec)
|
||||||
|
if !ok || ts.Name.Name != structName {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
st, ok := ts.Type.(*ast.StructType)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, f := range st.Fields.List {
|
||||||
|
if len(f.Names) > 0 {
|
||||||
|
fields = append(fields, f.Names[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFuncFieldRefs returns which of the given field names appear inside the
|
||||||
|
// named function, either as selector expressions (params.FieldName) or as
|
||||||
|
// composite literal keys (&serviceParams{FieldName: ...}).
|
||||||
|
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
|
||||||
|
t.Helper()
|
||||||
|
fieldSet := make(map[string]bool, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
fieldSet[f] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
found := make(map[string]bool)
|
||||||
|
fn := findFuncDecl(file, funcName)
|
||||||
|
require.NotNil(t, fn, "function %s not found", funcName)
|
||||||
|
|
||||||
|
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||||
|
switch v := n.(type) {
|
||||||
|
case *ast.SelectorExpr:
|
||||||
|
if fieldSet[v.Sel.Name] {
|
||||||
|
found[v.Sel.Name] = true
|
||||||
|
}
|
||||||
|
case *ast.KeyValueExpr:
|
||||||
|
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
|
||||||
|
found[ident.Name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
|
||||||
|
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
|
||||||
|
t.Helper()
|
||||||
|
fn := findFuncDecl(file, funcName)
|
||||||
|
require.NotNil(t, fn, "function %s not found", funcName)
|
||||||
|
|
||||||
|
refs := make(map[string]bool)
|
||||||
|
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||||
|
if ident, ok := n.(*ast.Ident); ok {
|
||||||
|
refs[ident.Name] = true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return refs
|
||||||
|
}
|
||||||
|
|
||||||
|
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
|
||||||
|
for _, decl := range file.Decls {
|
||||||
|
fn, ok := decl.(*ast.FuncDecl)
|
||||||
|
if ok && fn.Name.Name == name {
|
||||||
|
return fn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldToGlobalVar maps serviceParams field names to the package-level variable
|
||||||
|
// names used in buildServiceArguments and applyServiceParams.
|
||||||
|
func fieldToGlobalVar(field string) string {
|
||||||
|
m := map[string]string{
|
||||||
|
"LogLevel": "logLevel",
|
||||||
|
"DaemonAddr": "daemonAddr",
|
||||||
|
"ManagementURL": "managementURL",
|
||||||
|
"ConfigPath": "configPath",
|
||||||
|
"LogFiles": "logFiles",
|
||||||
|
"DisableProfiles": "profilesDisabled",
|
||||||
|
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||||
|
"DisableNetworks": "networksDisabled",
|
||||||
|
"ServiceEnvVars": "serviceEnvVars",
|
||||||
|
}
|
||||||
|
if v, ok := m[field]; ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
// Default: lowercase first letter.
|
||||||
|
return strings.ToLower(field[:1]) + field[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvMapToSlice(t *testing.T) {
|
||||||
|
m := map[string]string{"A": "1", "B": "2"}
|
||||||
|
s := envMapToSlice(m)
|
||||||
|
assert.Len(t, s, 2)
|
||||||
|
assert.Contains(t, s, "A=1")
|
||||||
|
assert.Contains(t, s, "B=2")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvMapToSlice_Empty(t *testing.T) {
|
||||||
|
s := envMapToSlice(map[string]string{})
|
||||||
|
assert.Empty(t, s)
|
||||||
|
}
|
||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -13,6 +15,22 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TestMain intercepts when this test binary is run as a daemon subprocess.
|
||||||
|
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
|
||||||
|
// "service run ..." arguments. Since the test binary can't handle cobra CLI
|
||||||
|
// args, it exits immediately, causing daemon -r to respawn rapidly until
|
||||||
|
// hitting the rate limit and exiting. This makes service restart unreliable.
|
||||||
|
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
|
||||||
|
sig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
|
||||||
|
<-sig
|
||||||
|
return
|
||||||
|
}
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
serviceStartTimeout = 10 * time.Second
|
serviceStartTimeout = 10 * time.Second
|
||||||
serviceStopTimeout = 5 * time.Second
|
serviceStopTimeout = 5 * time.Second
|
||||||
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
|
|||||||
logLevel = "info"
|
logLevel = "info"
|
||||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||||
|
|
||||||
|
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cleanup: create service config: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cleanup: create service: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the subtests already cleaned up, there's nothing to do.
|
||||||
|
if _, err := s.Status(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.Stop(); err != nil {
|
||||||
|
t.Errorf("cleanup: stop service: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.Uninstall(); err != nil {
|
||||||
|
t.Errorf("cleanup: uninstall service: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
t.Run("Install", func(t *testing.T) {
|
t.Run("Install", func(t *testing.T) {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ var (
|
|||||||
ipsFilterMap map[string]struct{}
|
ipsFilterMap map[string]struct{}
|
||||||
prefixNamesFilterMap map[string]struct{}
|
prefixNamesFilterMap map[string]struct{}
|
||||||
connectionTypeFilter string
|
connectionTypeFilter string
|
||||||
|
checkFlag string
|
||||||
)
|
)
|
||||||
|
|
||||||
var statusCmd = &cobra.Command{
|
var statusCmd = &cobra.Command{
|
||||||
@@ -49,6 +50,7 @@ func init() {
|
|||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||||
|
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -56,6 +58,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
|
if checkFlag != "" {
|
||||||
|
return runHealthCheck(cmd)
|
||||||
|
}
|
||||||
|
|
||||||
err := parseFilters()
|
err := parseFilters()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -68,15 +74,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
ctx := internal.CtxInitState(cmd.Context())
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
resp, err := getStatus(ctx, false)
|
resp, err := getStatus(ctx, true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
status := resp.GetStatus()
|
status := resp.GetStatus()
|
||||||
|
|
||||||
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||||
status == string(internal.StatusSessionExpired) {
|
status == string(internal.StatusSessionExpired)
|
||||||
|
|
||||||
|
if needsAuth && !jsonFlag && !yamlFlag {
|
||||||
cmd.Printf("Daemon status: %s\n\n"+
|
cmd.Printf("Daemon status: %s\n\n"+
|
||||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||||
" netbird up \n\n"+
|
" netbird up \n\n"+
|
||||||
@@ -99,7 +107,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
profName = activeProf.Name
|
profName = activeProf.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||||
|
Anonymize: anonymizeFlag,
|
||||||
|
DaemonVersion: resp.GetDaemonVersion(),
|
||||||
|
DaemonStatus: nbstatus.ParseDaemonStatus(status),
|
||||||
|
StatusFilter: statusFilter,
|
||||||
|
PrefixNamesFilter: prefixNamesFilter,
|
||||||
|
PrefixNamesFilterMap: prefixNamesFilterMap,
|
||||||
|
IPsFilter: ipsFilterMap,
|
||||||
|
ConnectionTypeFilter: connectionTypeFilter,
|
||||||
|
ProfileName: profName,
|
||||||
|
})
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
@@ -121,7 +139,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//nolint
|
//nolint
|
||||||
@@ -131,7 +149,7 @@ func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse
|
|||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
|
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
@@ -185,6 +203,83 @@ func enableDetailFlagWhenFilterFlag() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runHealthCheck(cmd *cobra.Command) error {
|
||||||
|
check := strings.ToLower(checkFlag)
|
||||||
|
switch check {
|
||||||
|
case "live", "ready", "startup":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
|
isStartup := check == "startup"
|
||||||
|
resp, err := getStatus(ctx, isStartup, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch check {
|
||||||
|
case "live":
|
||||||
|
return nil
|
||||||
|
case "ready":
|
||||||
|
return checkReadiness(resp)
|
||||||
|
case "startup":
|
||||||
|
return checkStartup(resp)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkReadiness(resp *proto.StatusResponse) error {
|
||||||
|
daemonStatus := internal.StatusType(resp.GetStatus())
|
||||||
|
switch daemonStatus {
|
||||||
|
case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected:
|
||||||
|
return nil
|
||||||
|
case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired:
|
||||||
|
return fmt.Errorf("readiness check: daemon status is %s", daemonStatus)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkStartup(resp *proto.StatusResponse) error {
|
||||||
|
fullStatus := resp.GetFullStatus()
|
||||||
|
if fullStatus == nil {
|
||||||
|
return fmt.Errorf("startup check: no full status available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fullStatus.GetManagementState().GetConnected() {
|
||||||
|
return fmt.Errorf("startup check: management not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fullStatus.GetSignalState().GetConnected() {
|
||||||
|
return fmt.Errorf("startup check: signal not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
var relayCount, relaysConnected int
|
||||||
|
for _, r := range fullStatus.GetRelays() {
|
||||||
|
uri := r.GetURI()
|
||||||
|
if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
relayCount++
|
||||||
|
if r.GetAvailable() {
|
||||||
|
relaysConnected++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relayCount > 0 && relaysConnected == 0 {
|
||||||
|
return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func parseInterfaceIP(interfaceIP string) string {
|
func parseInterfaceIP(interfaceIP string) string {
|
||||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
@@ -100,9 +102,16 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersmanager)
|
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||||
|
|
||||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
ctx := context.Background()
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
@@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
Return(&types.Settings{}, nil).
|
Return(&types.Settings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -152,7 +160,7 @@ func startClientDaemon(
|
|||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
|
|
||||||
server := client.New(ctx,
|
server := client.New(ctx,
|
||||||
"", "", false, false)
|
"", "", false, false, false)
|
||||||
if err := server.Start(); err != nil {
|
if err := server.Start(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
|||||||
r := peer.NewRecorder(config.ManagementURL.String())
|
r := peer.NewRecorder(config.ManagementURL.String())
|
||||||
r.GetFullStatus()
|
r.GetFullStatus()
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,14 +33,14 @@ var (
|
|||||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerConnStatus is a peer's connection status.
|
|
||||||
type PeerConnStatus = peer.ConnStatus
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// PeerStatusConnected indicates the peer is in connected state.
|
// PeerStatusConnected indicates the peer is in connected state.
|
||||||
PeerStatusConnected = peer.StatusConnected
|
PeerStatusConnected = peer.StatusConnected
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PeerConnStatus is a peer's connection status.
|
||||||
|
type PeerConnStatus = peer.ConnStatus
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
@@ -81,6 +83,14 @@ type Options struct {
|
|||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||||
WireguardPort *int
|
WireguardPort *int
|
||||||
|
// MTU is the MTU for the WireGuard interface.
|
||||||
|
// Valid values are in the range 576..8192 bytes.
|
||||||
|
// If non-nil, this value overrides any value stored in the config file.
|
||||||
|
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
|
||||||
|
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
|
||||||
|
MTU *uint16
|
||||||
|
// DNSLabels defines additional DNS labels configured in the peer.
|
||||||
|
DNSLabels []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -112,6 +122,12 @@ func New(opts Options) (*Client, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.MTU != nil {
|
||||||
|
if err := iface.ValidateMTU(*opts.MTU); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid MTU: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if opts.LogOutput != nil {
|
if opts.LogOutput != nil {
|
||||||
logrus.SetOutput(opts.LogOutput)
|
logrus.SetOutput(opts.LogOutput)
|
||||||
}
|
}
|
||||||
@@ -140,9 +156,14 @@ func New(opts Options) (*Client, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var parsedLabels domain.List
|
||||||
|
if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid dns labels: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
t := true
|
t := true
|
||||||
var config *profilemanager.Config
|
var config *profilemanager.Config
|
||||||
var err error
|
|
||||||
input := profilemanager.ConfigInput{
|
input := profilemanager.ConfigInput{
|
||||||
ConfigPath: opts.ConfigPath,
|
ConfigPath: opts.ConfigPath,
|
||||||
ManagementURL: opts.ManagementURL,
|
ManagementURL: opts.ManagementURL,
|
||||||
@@ -151,6 +172,8 @@ func New(opts Options) (*Client, error) {
|
|||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
BlockInbound: &opts.BlockInbound,
|
BlockInbound: &opts.BlockInbound,
|
||||||
WireguardPort: opts.WireguardPort,
|
WireguardPort: opts.WireguardPort,
|
||||||
|
MTU: opts.MTU,
|
||||||
|
DNSLabels: parsedLabels,
|
||||||
}
|
}
|
||||||
if opts.ConfigPath != "" {
|
if opts.ConfigPath != "" {
|
||||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||||
@@ -202,7 +225,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
client := internal.NewConnectClient(ctx, c.config, c.recorder)
|
||||||
client.SetSyncResponsePersistence(true)
|
client.SetSyncResponsePersistence(true)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
@@ -352,6 +375,32 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
|
||||||
|
// It returns an ExposeSession. Call Wait on the session to keep it alive.
|
||||||
|
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := engine.GetExposeManager()
|
||||||
|
if mgr == nil {
|
||||||
|
return nil, fmt.Errorf("expose manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := mgr.Expose(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("expose: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ExposeSession{
|
||||||
|
Domain: resp.Domain,
|
||||||
|
ServiceName: resp.ServiceName,
|
||||||
|
ServiceURL: resp.ServiceURL,
|
||||||
|
mgr: mgr,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
|
|||||||
45
client/embed/expose.go
Normal file
45
client/embed/expose.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package embed
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/expose"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ExposeProtocolHTTP exposes the service as HTTP.
|
||||||
|
ExposeProtocolHTTP = expose.ProtocolHTTP
|
||||||
|
// ExposeProtocolHTTPS exposes the service as HTTPS.
|
||||||
|
ExposeProtocolHTTPS = expose.ProtocolHTTPS
|
||||||
|
// ExposeProtocolTCP exposes the service as TCP.
|
||||||
|
ExposeProtocolTCP = expose.ProtocolTCP
|
||||||
|
// ExposeProtocolUDP exposes the service as UDP.
|
||||||
|
ExposeProtocolUDP = expose.ProtocolUDP
|
||||||
|
// ExposeProtocolTLS exposes the service as TLS.
|
||||||
|
ExposeProtocolTLS = expose.ProtocolTLS
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
|
||||||
|
type ExposeRequest = expose.Request
|
||||||
|
|
||||||
|
// ExposeProtocolType represents the protocol used for exposing a service.
|
||||||
|
type ExposeProtocolType = expose.ProtocolType
|
||||||
|
|
||||||
|
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
|
||||||
|
type ExposeSession struct {
|
||||||
|
Domain string
|
||||||
|
ServiceName string
|
||||||
|
ServiceURL string
|
||||||
|
|
||||||
|
mgr *expose.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait blocks while keeping the expose session alive.
|
||||||
|
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
|
||||||
|
func (s *ExposeSession) Wait(ctx context.Context) error {
|
||||||
|
if s == nil || s.mgr == nil {
|
||||||
|
return errors.New("expose session is not initialized")
|
||||||
|
}
|
||||||
|
return s.mgr.KeepAlive(ctx, s.Domain)
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
@@ -35,20 +36,34 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
|
||||||
// in any case, because we need to allow netbird interface traffic
|
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
log.Info("forcing userspace firewall")
|
||||||
// for the userspace packet filtering firewall
|
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
||||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||||
|
|
||||||
|
// Kernel cannot fall back to anything else, need to return error
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return fm, err
|
return fm, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fall back to the userspace packet filter if native is unavailable
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
|
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
|
||||||
|
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||||
|
// needs a device filter for DNS interception hooks. Install a minimal
|
||||||
|
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||||
|
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||||
|
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
@@ -160,3 +175,17 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|||||||
_, err := client.ListChains("filter")
|
_, err := client.ListChains("filter")
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func forceUserspaceFirewall() bool {
|
||||||
|
val := os.Getenv(EnvForceUserspaceFirewall)
|
||||||
|
if val == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
force, err := strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return force
|
||||||
|
}
|
||||||
|
|||||||
11
client/firewall/firewalld/firewalld.go
Normal file
11
client/firewall/firewalld/firewalld.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
||||||
|
// its wg interface into firewalld's "trusted" zone. This is required because
|
||||||
|
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
||||||
|
// versions, which returns EPERM to any other process that tries to insert
|
||||||
|
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
||||||
|
// itself add the accept rules to its own chains by trusting the interface.
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
||||||
|
// should bypass firewalld filtering.
|
||||||
|
const TrustedZone = "trusted"
|
||||||
260
client/firewall/firewalld/firewalld_linux.go
Normal file
260
client/firewall/firewalld/firewalld_linux.go
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/godbus/dbus/v5"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dbusDest = "org.fedoraproject.FirewallD1"
|
||||||
|
dbusPath = "/org/fedoraproject/FirewallD1"
|
||||||
|
dbusRootIface = "org.fedoraproject.FirewallD1"
|
||||||
|
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
||||||
|
|
||||||
|
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
||||||
|
errAlreadyEnabled = "ALREADY_ENABLED"
|
||||||
|
errUnknownIface = "UNKNOWN_INTERFACE"
|
||||||
|
errNotEnabled = "NOT_ENABLED"
|
||||||
|
|
||||||
|
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
||||||
|
// A fresh context is created for each call so a slow DBus probe can't
|
||||||
|
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
||||||
|
callTimeout = 3 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
||||||
|
|
||||||
|
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
||||||
|
// Info level only for the first successful add per process; repeat adds
|
||||||
|
// from other init paths are quieter.
|
||||||
|
trustLogOnce sync.Once
|
||||||
|
|
||||||
|
parentCtxMu sync.RWMutex
|
||||||
|
parentCtx context.Context = context.Background()
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetParentContext installs a parent context whose cancellation aborts any
|
||||||
|
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
||||||
|
// always uses a fresh Background-rooted timeout so cleanup can still run
|
||||||
|
// during engine shutdown when the engine context is already cancelled.
|
||||||
|
func SetParentContext(ctx context.Context) {
|
||||||
|
parentCtxMu.Lock()
|
||||||
|
parentCtx = ctx
|
||||||
|
parentCtxMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getParentContext() context.Context {
|
||||||
|
parentCtxMu.RLock()
|
||||||
|
defer parentCtxMu.RUnlock()
|
||||||
|
return parentCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
||||||
|
// running. It is idempotent and best-effort: errors are returned so callers
|
||||||
|
// can log, but a non-running firewalld is not an error. Only the first
|
||||||
|
// successful call per process logs at Info. Respects the parent context set
|
||||||
|
// via SetParentContext so startup-time cancellation unblocks it.
|
||||||
|
func TrustInterface(iface string) error {
|
||||||
|
parent := getParentContext()
|
||||||
|
if !isRunning(parent) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := addTrusted(parent, iface); err != nil {
|
||||||
|
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
||||||
|
}
|
||||||
|
trustLogOnce.Do(func() {
|
||||||
|
log.Infof("added %s to firewalld trusted zone", iface)
|
||||||
|
})
|
||||||
|
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
||||||
|
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
||||||
|
// during shutdown after the engine context has been cancelled.
|
||||||
|
func UntrustInterface(iface string) error {
|
||||||
|
if !isRunning(context.Background()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := removeTrusted(context.Background(), iface); err != nil {
|
||||||
|
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||||
|
return context.WithTimeout(parent, callTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRunning(parent context.Context) bool {
|
||||||
|
ctx, cancel := newCallContext(parent)
|
||||||
|
ok, err := isRunningDBus(ctx)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
ctx, cancel = newCallContext(parent)
|
||||||
|
defer cancel()
|
||||||
|
return isRunningCLI(ctx)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func addTrusted(parent context.Context, iface string) error {
|
||||||
|
ctx, cancel := newCallContext(parent)
|
||||||
|
err := addDBus(ctx, iface)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errDBusUnavailable) {
|
||||||
|
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
||||||
|
}
|
||||||
|
ctx, cancel = newCallContext(parent)
|
||||||
|
defer cancel()
|
||||||
|
return addCLI(ctx, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeTrusted(parent context.Context, iface string) error {
|
||||||
|
ctx, cancel := newCallContext(parent)
|
||||||
|
err := removeDBus(ctx, iface)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errDBusUnavailable) {
|
||||||
|
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
||||||
|
}
|
||||||
|
ctx, cancel = newCallContext(parent)
|
||||||
|
defer cancel()
|
||||||
|
return removeCLI(ctx, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRunningDBus(ctx context.Context) (bool, error) {
|
||||||
|
conn, err := dbus.SystemBus()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||||
|
}
|
||||||
|
obj := conn.Object(dbusDest, dbusPath)
|
||||||
|
|
||||||
|
var zone string
|
||||||
|
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
||||||
|
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRunningCLI(ctx context.Context) bool {
|
||||||
|
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDBus(ctx context.Context, iface string) error {
|
||||||
|
conn, err := dbus.SystemBus()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||||
|
}
|
||||||
|
obj := conn.Object(dbusDest, dbusPath)
|
||||||
|
|
||||||
|
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
||||||
|
if call.Err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
||||||
|
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
||||||
|
if move.Err != nil {
|
||||||
|
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeDBus(ctx context.Context, iface string) error {
|
||||||
|
conn, err := dbus.SystemBus()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||||
|
}
|
||||||
|
obj := conn.Object(dbusDest, dbusPath)
|
||||||
|
|
||||||
|
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
||||||
|
if call.Err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addCLI(ctx context.Context, iface string) error {
|
||||||
|
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||||
|
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --change-interface (no --permanent) binds the interface for the
|
||||||
|
// current runtime only; we do not want membership to persist across
|
||||||
|
// reboots because netbird re-asserts it on every startup.
|
||||||
|
out, err := exec.CommandContext(ctx,
|
||||||
|
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
||||||
|
).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeCLI(ctx context.Context, iface string) error {
|
||||||
|
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||||
|
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.CommandContext(ctx,
|
||||||
|
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
||||||
|
).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
msg := strings.TrimSpace(string(out))
|
||||||
|
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbusErrContains(err error, code string) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var de dbus.Error
|
||||||
|
if errors.As(err, &de) {
|
||||||
|
for _, b := range de.Body {
|
||||||
|
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Contains(err.Error(), code)
|
||||||
|
}
|
||||||
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/godbus/dbus/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDBusErrContains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
code string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"nil error", nil, errZoneAlreadySet, false},
|
||||||
|
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
||||||
|
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
||||||
|
{
|
||||||
|
"dbus.Error body match",
|
||||||
|
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
||||||
|
errZoneAlreadySet,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dbus.Error body miss",
|
||||||
|
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
||||||
|
errAlreadyEnabled,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dbus.Error non-string body falls back to Error()",
|
||||||
|
dbus.Error{Name: "x", Body: []any{123}},
|
||||||
|
"x",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := dbusErrContains(tc.err, tc.code)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
25
client/firewall/firewalld/firewalld_other.go
Normal file
25
client/firewall/firewalld/firewalld_other.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
||||||
|
// runs on Linux.
|
||||||
|
func SetParentContext(context.Context) {
|
||||||
|
// intentionally empty: firewalld is a Linux-only daemon
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||||
|
// runs on Linux.
|
||||||
|
func TrustInterface(string) error {
|
||||||
|
// intentionally empty: firewalld is a Linux-only daemon
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||||
|
// runs on Linux.
|
||||||
|
func UntrustInterface(string) error {
|
||||||
|
// intentionally empty: firewalld is a Linux-only daemon
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -7,6 +7,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
|
||||||
|
// native iptables/nftables is available. This only applies when the WireGuard interface
|
||||||
|
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
|
||||||
|
// kernel netfilter rules.
|
||||||
|
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ const (
|
|||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||||
|
|
||||||
|
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
|
||||||
|
// external DNAT from bypassing ACL rules.
|
||||||
|
mangleFwdKey = "MANGLE-FORWARD"
|
||||||
)
|
)
|
||||||
|
|
||||||
type aclEntries map[string][][]string
|
type aclEntries map[string][][]string
|
||||||
@@ -274,6 +278,12 @@ func (m *aclManager) cleanChains() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, rule := range m.entries[mangleFwdKey] {
|
||||||
|
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
if err := m.flushIPSet(ipsetName); err != nil {
|
if err := m.flushIPSet(ipsetName); err != nil {
|
||||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
@@ -303,6 +313,10 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for chainName, rules := range m.entries {
|
for chainName, rules := range m.entries {
|
||||||
|
// mangle FORWARD guard rules are handled separately below
|
||||||
|
if chainName == mangleFwdKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||||
@@ -322,6 +336,13 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
}
|
}
|
||||||
clear(m.optionalEntries)
|
clear(m.optionalEntries)
|
||||||
|
|
||||||
|
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||||
|
for _, rule := range m.entries[mangleFwdKey] {
|
||||||
|
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
||||||
|
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -343,6 +364,22 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
|
|
||||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||||
|
|
||||||
|
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
|
||||||
|
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
|
||||||
|
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
|
||||||
|
// ACL mark check where it cannot be overridden.
|
||||||
|
m.appendToEntries(mangleFwdKey, []string{
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||||
|
"-j", "ACCEPT",
|
||||||
|
})
|
||||||
|
m.appendToEntries(mangleFwdKey, []string{
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "DNAT",
|
||||||
|
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||||
|
"-j", "DROP",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -26,13 +27,13 @@ type Manager struct {
|
|||||||
ipv4Client *iptables.IPTables
|
ipv4Client *iptables.IPTables
|
||||||
aclMgr *aclManager
|
aclMgr *aclManager
|
||||||
router *router
|
router *router
|
||||||
|
rawSupported bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
@@ -65,7 +66,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
|
||||||
MTU: m.router.mtu,
|
MTU: m.router.mtu,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -84,7 +84,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := m.initNoTrackChain(); err != nil {
|
if err := m.initNoTrackChain(); err != nil {
|
||||||
return fmt.Errorf("init notrack chain: %w", err)
|
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trust after all fatal init steps so a later failure doesn't leave the
|
||||||
|
// interface in firewalld's trusted zone without a corresponding Close.
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
@@ -192,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||||
|
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
||||||
|
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
// attempt to delete state only if all other operations succeeded
|
// attempt to delete state only if all other operations succeeded
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
@@ -202,12 +214,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic.
|
||||||
|
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||||
|
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
if !m.wgIface.IsUserspaceBind() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
@@ -220,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,6 +300,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
|
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
|
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chainNameRaw = "NETBIRD-RAW"
|
chainNameRaw = "NETBIRD-RAW"
|
||||||
chainOUTPUT = "OUTPUT"
|
chainOUTPUT = "OUTPUT"
|
||||||
@@ -318,6 +349,10 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if !m.rawSupported {
|
||||||
|
return fmt.Errorf("raw table not available")
|
||||||
|
}
|
||||||
|
|
||||||
wgPortStr := fmt.Sprintf("%d", wgPort)
|
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||||
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||||
|
|
||||||
@@ -375,12 +410,16 @@ func (m *Manager) initNoTrackChain() error {
|
|||||||
return fmt.Errorf("add prerouting jump rule: %w", err)
|
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.rawSupported = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) cleanupNoTrackChain() error {
|
func (m *Manager) cleanupNoTrackChain() error {
|
||||||
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if !m.rawSupported {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return fmt.Errorf("check chain exists: %w", err)
|
return fmt.Errorf("check chain exists: %w", err)
|
||||||
}
|
}
|
||||||
if !exists {
|
if !exists {
|
||||||
@@ -401,6 +440,7 @@ func (m *Manager) cleanupNoTrackChain() error {
|
|||||||
return fmt.Errorf("clear and delete chain: %w", err)
|
return fmt.Errorf("clear and delete chain: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.rawSupported = false
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -47,8 +47,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
|||||||
panic("AddressFunc is not set")
|
panic("AddressFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
|
||||||
|
|
||||||
func TestIptablesManager(t *testing.T) {
|
func TestIptablesManager(t *testing.T) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ const (
|
|||||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||||
chainRTPRE = "NETBIRD-RT-PRE"
|
chainRTPRE = "NETBIRD-RT-PRE"
|
||||||
chainRTRDR = "NETBIRD-RT-RDR"
|
chainRTRDR = "NETBIRD-RT-RDR"
|
||||||
|
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
@@ -43,6 +44,7 @@ const (
|
|||||||
jumpManglePre = "jump-mangle-pre"
|
jumpManglePre = "jump-mangle-pre"
|
||||||
jumpNatPre = "jump-nat-pre"
|
jumpNatPre = "jump-nat-pre"
|
||||||
jumpNatPost = "jump-nat-post"
|
jumpNatPost = "jump-nat-post"
|
||||||
|
jumpNatOutput = "jump-nat-output"
|
||||||
jumpMSSClamp = "jump-mss-clamp"
|
jumpMSSClamp = "jump-mss-clamp"
|
||||||
markManglePre = "mark-mangle-pre"
|
markManglePre = "mark-mangle-pre"
|
||||||
markManglePost = "mark-mangle-post"
|
markManglePost = "mark-mangle-post"
|
||||||
@@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("flushing routing related tables")
|
log.Debug("flushing routing related tables")
|
||||||
|
|
||||||
|
// Remove jump rules from built-in chains before deleting custom chains,
|
||||||
|
// otherwise the chain deletion fails with "device or resource busy".
|
||||||
|
jumpRule := []string{"-j", chainNATOutput}
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||||
|
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
for _, chainInfo := range []struct {
|
for _, chainInfo := range []struct {
|
||||||
chain string
|
chain string
|
||||||
table string
|
table string
|
||||||
@@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTNAT, tableNat},
|
||||||
{chainRTRDR, tableNat},
|
{chainRTRDR, tableNat},
|
||||||
|
{chainNATOutput, tableNat},
|
||||||
{chainRTMSSCLAMP, tableMangle},
|
{chainRTMSSCLAMP, tableMangle},
|
||||||
} {
|
} {
|
||||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||||
@@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||||
|
func (r *router) ensureNATOutputChain() error {
|
||||||
|
if _, exists := r.rules[jumpNatOutput]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||||
|
}
|
||||||
|
if !chainExists {
|
||||||
|
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||||
|
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jumpRule := []string{"-j", chainNATOutput}
|
||||||
|
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
|
||||||
|
if !chainExists {
|
||||||
|
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||||
|
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpNatOutput] = jumpRule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
|
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.ensureNATOutputChain(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatRule := []string{
|
||||||
|
"-p", strings.ToLower(string(protocol)),
|
||||||
|
"--dport", strconv.Itoa(int(sourcePort)),
|
||||||
|
"-d", localAddr.String(),
|
||||||
|
"-j", "DNAT",
|
||||||
|
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||||
|
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[ruleID] = dnatRule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
|
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
|
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||||
|
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func applyPort(flag string, port *firewall.Port) []string {
|
func applyPort(flag string, port *firewall.Port) []string {
|
||||||
if port == nil {
|
if port == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress wgaddr.Address `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
|
||||||
MTU uint16 `json:"mtu"`
|
MTU uint16 `json:"mtu"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,10 +22,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
|||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
|
||||||
return i.UserspaceBind
|
|
||||||
}
|
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
|
|||||||
@@ -169,6 +169,14 @@ type Manager interface {
|
|||||||
// RemoveInboundDNAT removes inbound DNAT rule
|
// RemoveInboundDNAT removes inbound DNAT rule
|
||||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
|
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||||
|
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
|
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||||
|
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -40,7 +41,6 @@ func getTableName() string {
|
|||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
@@ -95,7 +95,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := m.initNoTrackChains(workTable); err != nil {
|
if err := m.initNoTrackChains(workTable); err != nil {
|
||||||
return fmt.Errorf("init notrack chains: %w", err)
|
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
@@ -108,7 +108,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
|
||||||
MTU: m.router.mtu,
|
MTU: m.router.mtu,
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -205,12 +204,10 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return m.router.RemoveNatRule(pair)
|
return m.router.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic.
|
||||||
|
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||||
|
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
if !m.wgIface.IsUserspaceBind() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -221,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -346,6 +347,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
|
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
|
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chainNameRawOutput = "netbird-raw-out"
|
chainNameRawOutput = "netbird-raw-out"
|
||||||
chainNameRawPrerouting = "netbird-raw-pre"
|
chainNameRawPrerouting = "netbird-raw-pre"
|
||||||
|
|||||||
@@ -52,8 +52,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
|||||||
panic("AddressFunc is not set")
|
panic("AddressFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
|
||||||
|
|
||||||
func TestNftablesManager(t *testing.T) {
|
func TestNftablesManager(t *testing.T) {
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
@@ -36,9 +37,12 @@ const (
|
|||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||||
|
chainNameNATOutput = "netbird-nat-output"
|
||||||
chainNameForward = "FORWARD"
|
chainNameForward = "FORWARD"
|
||||||
chainNameMangleForward = "netbird-mangle-forward"
|
chainNameMangleForward = "netbird-mangle-forward"
|
||||||
|
|
||||||
|
firewalldTableName = "firewalld"
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
userDataAcceptInputRule = "inputaccept"
|
userDataAcceptInputRule = "inputaccept"
|
||||||
@@ -132,6 +136,10 @@ func (r *router) Reset() error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatPreroutingRules(); err != nil {
|
if err := r.removeNatPreroutingRules(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||||
}
|
}
|
||||||
@@ -279,6 +287,10 @@ func (r *router) createContainers() error {
|
|||||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to refresh rules: %s", err)
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
}
|
}
|
||||||
@@ -1318,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||||
|
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||||
|
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||||
|
if chain.Table.Name == firewalldTableName {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Skip all iptables-managed tables in the ip family
|
// Skip all iptables-managed tables in the ip family
|
||||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
return false
|
return false
|
||||||
@@ -1853,6 +1872,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||||
|
func (r *router) ensureNATOutputChain() error {
|
||||||
|
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameNATOutput,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookOutput,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
delete(r.chains, chainNameNATOutput)
|
||||||
|
return fmt.Errorf("create NAT output chain: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
|
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.ensureNATOutputChain(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoNum, err := protoToInt(protocol)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 2,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||||
|
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: localAddr.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||||
|
},
|
||||||
|
&expr.NAT{
|
||||||
|
Type: expr.NATTypeDestNAT,
|
||||||
|
Family: uint32(nftables.TableFamilyIPv4),
|
||||||
|
RegAddrMin: 1,
|
||||||
|
RegProtoMin: 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameNATOutput],
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(ruleID),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleID] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
|
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
|
rule, exists := r.rules[ruleID]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||||
func (r *router) applyNetwork(
|
func (r *router) applyNetwork(
|
||||||
network firewall.Network,
|
network firewall.Network,
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress wgaddr.Address `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
|
||||||
MTU uint16 `json:"mtu"`
|
MTU uint16 `json:"mtu"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,10 +21,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
|||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
|
||||||
return i.UserspaceBind
|
|
||||||
}
|
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,9 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
}
|
}
|
||||||
|
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AllowNetbird()
|
return m.nativeFirewall.AllowNetbird()
|
||||||
}
|
}
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
37
client/firewall/uspfilter/common/hooks.go
Normal file
37
client/firewall/uspfilter/common/hooks.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PacketHook stores a registered hook for a specific IP:port.
|
||||||
|
type PacketHook struct {
|
||||||
|
IP netip.Addr
|
||||||
|
Port uint16
|
||||||
|
Fn func([]byte) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookMatches checks if a packet's destination matches the hook and invokes it.
|
||||||
|
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||||
|
if h == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.IP == dstIP && h.Port == dport {
|
||||||
|
return h.Fn(packetData)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHook atomically stores a hook, handling nil removal.
|
||||||
|
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||||
|
if hook == nil {
|
||||||
|
ptr.Store(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ptr.Store(&PacketHook{
|
||||||
|
IP: ip,
|
||||||
|
Port: dPort,
|
||||||
|
Fn: hook,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
|
Name() string
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
GetWGDevice() *wgdevice.Device
|
GetWGDevice() *wgdevice.Device
|
||||||
|
|||||||
@@ -140,6 +140,10 @@ type Manager struct {
|
|||||||
mtu uint16
|
mtu uint16
|
||||||
mssClampValue uint16
|
mssClampValue uint16
|
||||||
mssClampEnabled bool
|
mssClampEnabled bool
|
||||||
|
|
||||||
|
// Only one hook per protocol is supported. Outbound direction only.
|
||||||
|
udpHookOut atomic.Pointer[common.PacketHook]
|
||||||
|
tcpHookOut atomic.Pointer[common.PacketHook]
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -594,6 +598,8 @@ func (m *Manager) resetState() {
|
|||||||
maps.Clear(m.incomingRules)
|
maps.Clear(m.incomingRules)
|
||||||
maps.Clear(m.routeRulesMap)
|
maps.Clear(m.routeRulesMap)
|
||||||
m.routeRules = m.routeRules[:0]
|
m.routeRules = m.routeRules[:0]
|
||||||
|
m.udpHookOut.Store(nil)
|
||||||
|
m.tcpHookOut.Store(nil)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
@@ -713,6 +719,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
|
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||||
if m.mssClampEnabled {
|
if m.mssClampEnabled {
|
||||||
@@ -895,39 +904,12 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
d.dnatOrigPort = 0
|
d.dnatOrigPort = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
|
||||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
m.mutex.RLock()
|
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||||
defer m.mutex.RUnlock()
|
}
|
||||||
|
|
||||||
// Check specific destination IP first
|
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||||
for _, rule := range rules {
|
|
||||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
|
||||||
return rule.udpHook(packetData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check IPv4 unspecified address
|
|
||||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
|
||||||
return rule.udpHook(packetData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check IPv6 unspecified address
|
|
||||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
|
||||||
return rule.udpHook(packetData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterInbound implements filtering logic for incoming packets.
|
// filterInbound implements filtering logic for incoming packets.
|
||||||
@@ -1278,12 +1260,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
|||||||
return rule.mgmtId, rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
|
||||||
// we ignore rule.drop and call this hook
|
|
||||||
if rule.udpHook != nil {
|
|
||||||
return rule.mgmtId, rule.udpHook(packetData), true
|
|
||||||
}
|
|
||||||
|
|
||||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
return rule.mgmtId, rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
@@ -1342,65 +1318,14 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
|||||||
return sourceMatched
|
return sourceMatched
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||||
//
|
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
common.SetHook(&m.udpHookOut, ip, dPort, hook)
|
||||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
|
||||||
r := PeerRule{
|
|
||||||
id: uuid.New().String(),
|
|
||||||
ip: ip,
|
|
||||||
protoLayer: layers.LayerTypeUDP,
|
|
||||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
|
||||||
udpHook: hook,
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip.Is4() {
|
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
if in {
|
|
||||||
// Incoming UDP hooks are stored in allow rules map
|
|
||||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
|
||||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
|
||||||
}
|
|
||||||
m.incomingRules[r.ip][r.id] = r
|
|
||||||
} else {
|
|
||||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
|
||||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
|
||||||
}
|
|
||||||
m.outgoingRules[r.ip][r.id] = r
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
return r.id
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook removes packet hook by given ID
|
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||||
m.mutex.Lock()
|
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
// Check incoming hooks (stored in allow rules)
|
|
||||||
for _, arr := range m.incomingRules {
|
|
||||||
for _, r := range arr {
|
|
||||||
if r.id == hookID {
|
|
||||||
delete(arr, r.id)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Check outgoing hooks
|
|
||||||
for _, arr := range m.outgoingRules {
|
|
||||||
for _, r := range arr {
|
|
||||||
if r.id == hookID {
|
|
||||||
delete(arr, r.id)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("hook with given id not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLogLevel sets the log level for the firewall manager
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
@@ -30,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
|||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
|
NameFunc func() string
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() wgaddr.Address
|
AddressFunc func() wgaddr.Address
|
||||||
GetWGDeviceFunc func() *wgdevice.Device
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) Name() string {
|
||||||
|
if i.NameFunc == nil {
|
||||||
|
return "wgtest"
|
||||||
|
}
|
||||||
|
return i.NameFunc()
|
||||||
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||||
if i.GetWGDeviceFunc == nil {
|
if i.GetWGDeviceFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -186,81 +195,52 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddUDPPacketHook(t *testing.T) {
|
func TestSetUDPPacketHook(t *testing.T) {
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in bool
|
|
||||||
expDir fw.RuleDirection
|
|
||||||
ip netip.Addr
|
|
||||||
dPort uint16
|
|
||||||
hook func([]byte) bool
|
|
||||||
expectedID string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Test Outgoing UDP Packet Hook",
|
|
||||||
in: false,
|
|
||||||
expDir: fw.RuleDirectionOUT,
|
|
||||||
ip: netip.MustParseAddr("10.168.0.1"),
|
|
||||||
dPort: 8000,
|
|
||||||
hook: func([]byte) bool { return true },
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Test Incoming UDP Packet Hook",
|
|
||||||
in: true,
|
|
||||||
expDir: fw.RuleDirectionIN,
|
|
||||||
ip: netip.MustParseAddr("::1"),
|
|
||||||
dPort: 9000,
|
|
||||||
hook: func([]byte) bool { return false },
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger, nbiface.DefaultMTU)
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
var called bool
|
||||||
|
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
|
||||||
var addedRule PeerRule
|
called = true
|
||||||
if tt.in {
|
return true
|
||||||
// Incoming UDP hooks are stored in allow rules map
|
|
||||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, rule := range manager.incomingRules[tt.ip] {
|
|
||||||
addedRule = rule
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if len(manager.outgoingRules[tt.ip]) != 1 {
|
|
||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
|
||||||
addedRule = rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
|
||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.dPort != addedRule.dPort.Values[0] {
|
|
||||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
|
||||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if addedRule.udpHook == nil {
|
|
||||||
t.Errorf("expected udpHook to be set")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
h := manager.udpHookOut.Load()
|
||||||
|
require.NotNil(t, h)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||||
|
assert.Equal(t, uint16(8000), h.Port)
|
||||||
|
assert.True(t, h.Fn(nil))
|
||||||
|
assert.True(t, called)
|
||||||
|
|
||||||
|
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||||
|
assert.Nil(t, manager.udpHookOut.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetTCPPacketHook(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||||
|
|
||||||
|
var called bool
|
||||||
|
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
|
||||||
|
called = true
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
h := manager.tcpHookOut.Load()
|
||||||
|
require.NotNil(t, h)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||||
|
assert.Equal(t, uint16(53), h.Port)
|
||||||
|
assert.True(t, h.Fn(nil))
|
||||||
|
assert.True(t, called)
|
||||||
|
|
||||||
|
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||||
|
assert.Nil(t, manager.tcpHookOut.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||||
@@ -530,39 +510,12 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Add a UDP packet hook
|
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
|
||||||
hookFunc := func(data []byte) bool { return true }
|
|
||||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
|
||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
|
||||||
found := false
|
|
||||||
for _, arr := range manager.outgoingRules {
|
|
||||||
for _, rule := range arr {
|
|
||||||
if rule.id == hookID {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
|
||||||
t.Fatalf("The hook was not added properly.")
|
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
|
||||||
}
|
|
||||||
|
|
||||||
// Now remove the packet hook
|
|
||||||
err = manager.RemovePacketHook(hookID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to remove hook: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
|
||||||
for _, arr := range manager.outgoingRules {
|
|
||||||
for _, rule := range arr {
|
|
||||||
if rule.id == hookID {
|
|
||||||
t.Fatalf("The hook was not removed properly.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
@@ -592,8 +545,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hookCalled := false
|
hookCalled := false
|
||||||
hookID := manager.AddUDPPacketHook(
|
manager.SetUDPPacketHook(
|
||||||
false,
|
|
||||||
netip.MustParseAddr("100.10.0.100"),
|
netip.MustParseAddr("100.10.0.100"),
|
||||||
53,
|
53,
|
||||||
func([]byte) bool {
|
func([]byte) bool {
|
||||||
@@ -601,7 +553,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
require.NotEmpty(t, hookID)
|
|
||||||
|
|
||||||
// Create test UDP packet
|
// Create test UDP packet
|
||||||
ipv4 := &layers.IPv4{
|
ipv4 := &layers.IPv4{
|
||||||
|
|||||||
90
client/firewall/uspfilter/hooks_filter.go
Normal file
90
client/firewall/uspfilter/hooks_filter.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ipv4HeaderMinLen = 20
|
||||||
|
ipv4ProtoOffset = 9
|
||||||
|
ipv4FlagsOffset = 6
|
||||||
|
ipv4DstOffset = 16
|
||||||
|
ipProtoUDP = 17
|
||||||
|
ipProtoTCP = 6
|
||||||
|
ipv4FragOffMask = 0x1fff
|
||||||
|
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
|
||||||
|
dstPortOffset = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
|
||||||
|
// It is installed on the WireGuard interface when the userspace bind is active
|
||||||
|
// but a full firewall filter (Manager) is not needed because a native kernel
|
||||||
|
// firewall (nftables/iptables) handles packet filtering.
|
||||||
|
type HooksFilter struct {
|
||||||
|
udpHook atomic.Pointer[common.PacketHook]
|
||||||
|
tcpHook atomic.Pointer[common.PacketHook]
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ device.PacketFilter = (*HooksFilter)(nil)
|
||||||
|
|
||||||
|
// FilterOutbound checks outbound packets for DNS hook matches.
|
||||||
|
// Only IPv4 packets matching the registered hook IP:port are intercepted.
|
||||||
|
// IPv6 and non-IP packets pass through unconditionally.
|
||||||
|
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
|
||||||
|
if len(packetData) < ipv4HeaderMinLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process IPv4 packets, let everything else pass through.
|
||||||
|
if packetData[0]>>4 != 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ihl := int(packetData[0]&0x0f) * 4
|
||||||
|
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip non-first fragments: they don't carry L4 headers.
|
||||||
|
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
|
||||||
|
if flagsAndOffset&ipv4FragOffMask != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := packetData[ipv4ProtoOffset]
|
||||||
|
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
|
||||||
|
|
||||||
|
switch proto {
|
||||||
|
case ipProtoUDP:
|
||||||
|
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
|
||||||
|
case ipProtoTCP:
|
||||||
|
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterInbound allows all inbound packets (native firewall handles filtering).
|
||||||
|
func (f *HooksFilter) FilterInbound([]byte, int) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUDPPacketHook registers the UDP packet hook.
|
||||||
|
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||||
|
common.SetHook(&f.udpHook, ip, dPort, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTCPPacketHook registers the TCP packet hook.
|
||||||
|
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||||
|
common.SetHook(&f.tcpHook, ip, dPort, hook)
|
||||||
|
}
|
||||||
@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get interfaces: %v", err)
|
log.Warnf("failed to get interfaces: %v", err)
|
||||||
} else {
|
} else {
|
||||||
|
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||||
|
// case where an interface comes up between refreshes.
|
||||||
for _, intf := range interfaces {
|
for _, intf := range interfaces {
|
||||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
|
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
var layerType gopacket.LayerType
|
var layerType gopacket.LayerType
|
||||||
switch protocol {
|
switch protocol {
|
||||||
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddOutputDNAT delegates to the native firewall if available.
|
||||||
|
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||||
|
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
|
|
||||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
if !m.portDNATEnabled.Load() {
|
if !m.portDNATEnabled.Load() {
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ type PeerRule struct {
|
|||||||
sPort *firewall.Port
|
sPort *firewall.Port
|
||||||
dPort *firewall.Port
|
dPort *firewall.Port
|
||||||
drop bool
|
drop bool
|
||||||
|
|
||||||
udpHook func([]byte) bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
|
|||||||
@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "UDPTraffic_WithHook",
|
name: "UDPTraffic_WithHook",
|
||||||
setup: func(m *Manager) {
|
setup: func(m *Manager) {
|
||||||
hookFunc := func([]byte) bool {
|
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
|
||||||
return true
|
return true // drop (intercepted by hook)
|
||||||
}
|
})
|
||||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
StageInboundPortDNAT,
|
StageOutbound1to1NAT,
|
||||||
StageInbound1to1NAT,
|
StageOutboundPortReverse,
|
||||||
StageConntrack,
|
|
||||||
StageRouting,
|
|
||||||
StagePeerACL,
|
|
||||||
StageCompleted,
|
StageCompleted,
|
||||||
},
|
},
|
||||||
expectedAllow: false,
|
expectedAllow: false,
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) {
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
// for js, the outer websocket layer takes care of tls
|
// for js, the outer websocket layer takes care of tls
|
||||||
if tlsEnabled && runtime.GOOS != "js" {
|
if tlsEnabled && runtime.GOOS != "js" {
|
||||||
@@ -46,9 +46,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
opts := []grpc.DialOption{
|
||||||
connCtx,
|
|
||||||
addr,
|
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
@@ -56,7 +54,10 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
}
|
||||||
|
opts = append(opts, extraOpts...)
|
||||||
|
|
||||||
|
conn, err := grpc.DialContext(connCtx, addr, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("dial context: %w", err)
|
return nil, fmt.Errorf("dial context: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
|||||||
ipv6Count++
|
ipv6Count++
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
||||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
// routing-correctness checks above are the real assertions; the counts
|
||||||
|
// are a sanity bound to catch a totally silent path.
|
||||||
|
minDelivered := packetsPerFamily * 80 / 100
|
||||||
|
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
||||||
|
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||||
|
|||||||
@@ -5,20 +5,18 @@ package configurer
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func openUAPI(deviceName string) (net.Listener, error) {
|
func openUAPI(deviceName string) (net.Listener, error) {
|
||||||
uapiSock, err := ipc.UAPIOpen(deviceName)
|
uapiSock, err := ipc.UAPIOpen(deviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to open uapi socket: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to listen on uapi socket: %v", err)
|
_ = uapiSock.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,6 +54,14 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
|
|||||||
return wgCfg
|
return wgCfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||||
|
return &WGUSPConfigurer{
|
||||||
|
device: device,
|
||||||
|
deviceName: deviceName,
|
||||||
|
activityRecorder: activityRecorder,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||||
log.Debugf("adding Wireguard private key")
|
log.Debugf("adding Wireguard private key")
|
||||||
key, err := wgtypes.ParseKey(privateKey)
|
key, err := wgtypes.ParseKey(privateKey)
|
||||||
|
|||||||
@@ -15,14 +15,17 @@ type PacketFilter interface {
|
|||||||
// FilterInbound filter incoming packets from external sources to host
|
// FilterInbound filter incoming packets from external sources to host
|
||||||
FilterInbound(packetData []byte, size int) bool
|
FilterInbound(packetData []byte, size int) bool
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
|
||||||
//
|
// Hook function returns true if the packet should be dropped.
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
// Only one UDP hook is supported; calling again replaces the previous hook.
|
||||||
// Hook function receives raw network packet data as argument.
|
// Pass nil hook to remove.
|
||||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
|
||||||
RemovePacketHook(hookID string) error
|
// Hook function returns true if the packet should be dropped.
|
||||||
|
// Only one TCP hook is supported; calling again replaces the previous hook.
|
||||||
|
// Pass nil hook to remove.
|
||||||
|
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteredDevice to override Read or Write of packets
|
// FilteredDevice to override Read or Write of packets
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||||
)
|
)
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cErr := tunIface.Close(); cErr != nil {
|
if cErr := tunIface.Close(); cErr != nil {
|
||||||
|
|||||||
@@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
|||||||
// Close closes the tunnel interface
|
// Close closes the tunnel interface
|
||||||
func (w *WGIface) Close() error {
|
func (w *WGIface) Close() error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
@@ -225,7 +224,15 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := w.tun.Close(); err != nil {
|
// Release w.mu before calling w.tun.Close(): the underlying
|
||||||
|
// wireguard-go device.Close() waits for its send/receive goroutines
|
||||||
|
// to drain. Some of those goroutines re-enter WGIface methods that
|
||||||
|
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
|
||||||
|
// holding the mutex here would deadlock the shutdown path.
|
||||||
|
tun := w.tun
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
if err := tun.Close(); err != nil {
|
||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
113
client/iface/iface_close_test.go
Normal file
113
client/iface/iface_close_test.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeTunDevice implements WGTunDevice and lets the test control when
|
||||||
|
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
|
||||||
|
// until its goroutines drain. Some of those goroutines (e.g. the packet
|
||||||
|
// filter DNS hook in client/internal/dns) call back into WGIface, so if
|
||||||
|
// WGIface.Close() held w.mu across tun.Close() the shutdown would
|
||||||
|
// deadlock.
|
||||||
|
type fakeTunDevice struct {
|
||||||
|
closeStarted chan struct{}
|
||||||
|
unblockClose chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
|
||||||
|
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
|
||||||
|
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
|
||||||
|
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
|
||||||
|
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
|
||||||
|
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
|
||||||
|
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
|
||||||
|
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
|
||||||
|
|
||||||
|
func (f *fakeTunDevice) Close() error {
|
||||||
|
close(f.closeStarted)
|
||||||
|
<-f.unblockClose
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeProxyFactory struct{}
|
||||||
|
|
||||||
|
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
|
||||||
|
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
|
||||||
|
func (fakeProxyFactory) Free() error { return nil }
|
||||||
|
|
||||||
|
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
|
||||||
|
// that surfaces as a macOS test-timeout in
|
||||||
|
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
|
||||||
|
// while waiting for the wireguard-go device goroutines to finish, and
|
||||||
|
// one of those goroutines (the DNS filter hook) calls back into
|
||||||
|
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
|
||||||
|
// the lock before tun.Close() returns control.
|
||||||
|
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
|
||||||
|
tun := &fakeTunDevice{
|
||||||
|
closeStarted: make(chan struct{}),
|
||||||
|
unblockClose: make(chan struct{}),
|
||||||
|
}
|
||||||
|
w := &WGIface{
|
||||||
|
tun: tun,
|
||||||
|
wgProxyFactory: fakeProxyFactory{},
|
||||||
|
}
|
||||||
|
|
||||||
|
closeDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
closeDone <- w.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-tun.closeStarted:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
close(tun.unblockClose)
|
||||||
|
t.Fatal("tun.Close() was never invoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate the WireGuard read goroutine calling back into WGIface
|
||||||
|
// via the packet filter's DNS hook. If Close() still held w.mu
|
||||||
|
// during tun.Close(), this would block until the test timeout.
|
||||||
|
getDeviceDone := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = w.GetDevice()
|
||||||
|
close(getDeviceDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-getDeviceDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
close(tun.unblockClose)
|
||||||
|
wg.Wait()
|
||||||
|
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(tun.unblockClose)
|
||||||
|
select {
|
||||||
|
case <-closeDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
|||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddUDPPacketHook mocks base method.
|
// SetUDPPacketHook mocks base method.
|
||||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
|
||||||
ret0, _ := ret[0].(string)
|
|
||||||
return ret0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
|
||||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTCPPacketHook mocks base method.
|
||||||
|
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterInbound mocks base method.
|
// FilterInbound mocks base method.
|
||||||
@@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook mocks base method.
|
|
||||||
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePacketHook indicates an expected call of RemovePacketHook.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
|
|
||||||
|
|
||||||
// Package mocks is a generated GoMock package.
|
|
||||||
package mocks
|
|
||||||
|
|
||||||
import (
|
|
||||||
net "net"
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockPacketFilter is a mock of PacketFilter interface.
|
|
||||||
type MockPacketFilter struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockPacketFilterMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
|
|
||||||
type MockPacketFilterMockRecorder struct {
|
|
||||||
mock *MockPacketFilter
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockPacketFilter creates a new mock instance.
|
|
||||||
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
|
|
||||||
mock := &MockPacketFilter{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockPacketFilterMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
|
||||||
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUDPPacketHook mocks base method.
|
|
||||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FilterInbound mocks base method.
|
|
||||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
|
||||||
ret0, _ := ret[0].(bool)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// FilterInbound indicates an expected call of FilterInbound.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FilterOutbound mocks base method.
|
|
||||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
|
||||||
ret0, _ := ret[0].(bool)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
|
||||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetwork indicates an expected call of SetNetwork.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
|
||||||
}
|
|
||||||
@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if u.address.Network.Contains(a) {
|
if u.address.Network.Contains(a) {
|
||||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
|||||||
u.addrCache.Store(addr.String(), isRouted)
|
u.addrCache.Store(addr.String(), isRouted)
|
||||||
if isRouted {
|
if isRouted {
|
||||||
// Extra log, as the error only shows up with ICE logging enabled
|
// Extra log, as the error only shows up with ICE logging enabled
|
||||||
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
|
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ import (
|
|||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
func TestDefaultManager(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||||
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
{
|
{
|
||||||
@@ -135,6 +138,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
func TestDefaultManagerStateless(t *testing.T) {
|
func TestDefaultManagerStateless(t *testing.T) {
|
||||||
// stateless currently only in userspace, so we have to disable kernel
|
// stateless currently only in userspace, so we have to disable kernel
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||||
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||||
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
@@ -194,6 +198,7 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
// This tests the full ACL manager -> uspfilter integration.
|
// This tests the full ACL manager -> uspfilter integration.
|
||||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||||
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
@@ -258,6 +263,7 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
|||||||
// up when they're removed from the network map in a subsequent update.
|
// up when they're removed from the network map in a subsequent update.
|
||||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
@@ -339,6 +345,7 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
|||||||
// one added without leaking.
|
// one added without leaking.
|
||||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
|||||||
var needsLogin bool
|
var needsLogin bool
|
||||||
|
|
||||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
if isLoginNeeded(err) {
|
if isLoginNeeded(err) {
|
||||||
needsLogin = true
|
needsLogin = true
|
||||||
return nil
|
return nil
|
||||||
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
|||||||
var isAuthError bool
|
var isAuthError bool
|
||||||
|
|
||||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
if serverKey != nil && isRegistrationNeeded(err) {
|
if isRegistrationNeeded(err) {
|
||||||
log.Debugf("peer registration required")
|
log.Debugf("peer registration required")
|
||||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
|||||||
|
|
||||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||||
serverKey, err := client.GetServerPublicKey()
|
protoFlow, err := client.GetPKCEAuthorizationFlow()
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||||
@@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
|||||||
config := &PKCEAuthProviderConfig{
|
config := &PKCEAuthProviderConfig{
|
||||||
Audience: protoConfig.GetAudience(),
|
Audience: protoConfig.GetAudience(),
|
||||||
ClientID: protoConfig.GetClientID(),
|
ClientID: protoConfig.GetClientID(),
|
||||||
ClientSecret: protoConfig.GetClientSecret(),
|
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||||
Scope: protoConfig.GetScope(),
|
Scope: protoConfig.GetScope(),
|
||||||
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
|||||||
|
|
||||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||||
serverKey, err := client.GetServerPublicKey()
|
protoFlow, err := client.GetDeviceAuthorizationFlow()
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||||
@@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
|||||||
config := &DeviceAuthProviderConfig{
|
config := &DeviceAuthProviderConfig{
|
||||||
Audience: protoConfig.GetAudience(),
|
Audience: protoConfig.GetAudience(),
|
||||||
ClientID: protoConfig.GetClientID(),
|
ClientID: protoConfig.GetClientID(),
|
||||||
ClientSecret: protoConfig.GetClientSecret(),
|
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||||
Domain: protoConfig.Domain,
|
Domain: protoConfig.Domain,
|
||||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||||
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// doMgmLogin performs the actual login operation with the management service
|
// doMgmLogin performs the actual login operation with the management service
|
||||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
|
||||||
serverKey, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
a.setSystemInfoFlags(sysInfo)
|
a.setSystemInfoFlags(sysInfo)
|
||||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||||
return serverKey, loginResp, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
validSetupKey, err := uuid.Parse(setupKey)
|
validSetupKey, err := uuid.Parse(setupKey)
|
||||||
if err != nil && jwtToken == "" {
|
if err != nil && jwtToken == "" {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||||
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
|
|||||||
log.Debugf("sending peer registration request to Management Service")
|
log.Debugf("sending peer registration request to Management Service")
|
||||||
info := system.GetInfo(ctx)
|
info := system.GetInfo(ctx)
|
||||||
a.setSystemInfoFlags(info)
|
a.setSystemInfoFlags(info)
|
||||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed registering peer %v", err)
|
log.Errorf("failed registering peer %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -23,12 +23,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
"github.com/netbirdio/netbird/client/internal/updater"
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
@@ -43,14 +44,19 @@ import (
|
|||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// androidRunOverride is set on Android to inject mobile dependencies
|
||||||
|
// when using embed.Client (which calls Run() with empty MobileDependency).
|
||||||
|
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
||||||
|
|
||||||
type ConnectClient struct {
|
type ConnectClient struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
doInitialAutoUpdate bool
|
|
||||||
|
|
||||||
engine *Engine
|
engine *Engine
|
||||||
engineMutex sync.Mutex
|
engineMutex sync.Mutex
|
||||||
|
clientMetrics *metrics.ClientMetrics
|
||||||
|
updateManager *updater.Manager
|
||||||
|
|
||||||
persistSyncResponse bool
|
persistSyncResponse bool
|
||||||
}
|
}
|
||||||
@@ -59,19 +65,24 @@ func NewConnectClient(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *profilemanager.Config,
|
config *profilemanager.Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
doInitalAutoUpdate bool,
|
|
||||||
) *ConnectClient {
|
) *ConnectClient {
|
||||||
return &ConnectClient{
|
return &ConnectClient{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
config: config,
|
config: config,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
doInitialAutoUpdate: doInitalAutoUpdate,
|
|
||||||
engineMutex: sync.Mutex{},
|
engineMutex: sync.Mutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||||
|
c.updateManager = um
|
||||||
|
}
|
||||||
|
|
||||||
// Run with main logic.
|
// Run with main logic.
|
||||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||||
|
if androidRunOverride != nil {
|
||||||
|
return androidRunOverride(c, runningChan, logPath)
|
||||||
|
}
|
||||||
return c.run(MobileDependency{}, runningChan, logPath)
|
return c.run(MobileDependency{}, runningChan, logPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
dnsAddresses []netip.AddrPort,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
stateFilePath string,
|
stateFilePath string,
|
||||||
|
cacheDir string,
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
@@ -92,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
|
TempDir: cacheDir,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil, "")
|
||||||
}
|
}
|
||||||
@@ -100,6 +113,7 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
fileDescriptor int32,
|
fileDescriptor int32,
|
||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsManager dns.IosDnsManager,
|
dnsManager dns.IosDnsManager,
|
||||||
|
dnsAddresses []netip.AddrPort,
|
||||||
stateFilePath string,
|
stateFilePath string,
|
||||||
) error {
|
) error {
|
||||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||||
@@ -109,6 +123,7 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
FileDescriptor: fileDescriptor,
|
FileDescriptor: fileDescriptor,
|
||||||
NetworkChangeListener: networkChangeListener,
|
NetworkChangeListener: networkChangeListener,
|
||||||
DnsManager: dnsManager,
|
DnsManager: dnsManager,
|
||||||
|
HostDNSAddresses: dnsAddresses,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil, "")
|
||||||
@@ -131,10 +146,34 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Stop metrics push on exit
|
||||||
|
defer func() {
|
||||||
|
if c.clientMetrics != nil {
|
||||||
|
c.clientMetrics.StopPush()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||||
|
|
||||||
nbnet.Init()
|
nbnet.Init()
|
||||||
|
|
||||||
|
// Initialize metrics once at startup (always active for debug bundles)
|
||||||
|
if c.clientMetrics == nil {
|
||||||
|
agentInfo := metrics.AgentInfo{
|
||||||
|
DeploymentType: metrics.DeploymentTypeUnknown,
|
||||||
|
Version: version.NetbirdVersion(),
|
||||||
|
OS: runtime.GOOS,
|
||||||
|
Arch: runtime.GOARCH,
|
||||||
|
}
|
||||||
|
c.clientMetrics = metrics.NewClientMetrics(agentInfo)
|
||||||
|
log.Debugf("initialized client metrics")
|
||||||
|
|
||||||
|
// Start metrics push if enabled (uses daemon context, persists across engine restarts)
|
||||||
|
if metrics.IsMetricsPushEnabled() {
|
||||||
|
c.clientMetrics.StartPush(c.ctx, metrics.PushConfigFromEnv())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@@ -187,15 +226,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
stateManager := statemanager.New(path)
|
stateManager := statemanager.New(path)
|
||||||
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||||
|
|
||||||
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
|
if c.updateManager != nil {
|
||||||
if err == nil {
|
c.updateManager.CheckUpdateSuccess(c.ctx)
|
||||||
updateManager.CheckUpdateSuccess(c.ctx)
|
}
|
||||||
|
|
||||||
inst := installer.New()
|
inst := installer.New()
|
||||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
defer c.statusRecorder.ClientStop()
|
defer c.statusRecorder.ClientStop()
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
@@ -222,6 +260,16 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
||||||
mgmClient.SetConnStateListener(mgmNotifier)
|
mgmClient.SetConnStateListener(mgmNotifier)
|
||||||
|
|
||||||
|
// Update metrics with actual deployment type after connection
|
||||||
|
deploymentType := metrics.DetermineDeploymentType(mgmClient.GetServerURL())
|
||||||
|
agentInfo := metrics.AgentInfo{
|
||||||
|
DeploymentType: deploymentType,
|
||||||
|
Version: version.NetbirdVersion(),
|
||||||
|
OS: runtime.GOOS,
|
||||||
|
Arch: runtime.GOARCH,
|
||||||
|
}
|
||||||
|
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
|
||||||
|
|
||||||
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err = mgmClient.Close(); err != nil {
|
if err = mgmClient.Close(); err != nil {
|
||||||
@@ -230,8 +278,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||||
|
loginStarted := time.Now()
|
||||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
|
||||||
log.Debug(err)
|
log.Debug(err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
state.Set(StatusNeedsLogin)
|
state.Set(StatusNeedsLogin)
|
||||||
@@ -240,6 +290,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true)
|
||||||
c.statusRecorder.MarkManagementConnected()
|
c.statusRecorder.MarkManagementConnected()
|
||||||
|
|
||||||
localPeerState := peer.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
@@ -289,6 +340,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
engineConfig.TempDir = mobileDependency.TempDir
|
||||||
|
|
||||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||||
c.statusRecorder.SetRelayMgr(relayManager)
|
c.statusRecorder.SetRelayMgr(relayManager)
|
||||||
@@ -308,7 +360,16 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{
|
||||||
|
SignalClient: signalClient,
|
||||||
|
MgmClient: mgmClient,
|
||||||
|
RelayManager: relayManager,
|
||||||
|
StatusRecorder: c.statusRecorder,
|
||||||
|
Checks: checks,
|
||||||
|
StateManager: stateManager,
|
||||||
|
UpdateManager: c.updateManager,
|
||||||
|
ClientMetrics: c.clientMetrics,
|
||||||
|
}, mobileDependency)
|
||||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
c.engine = engine
|
c.engine = engine
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
@@ -318,21 +379,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
|
||||||
// AutoUpdate will be true when the user click on "Connect" menu on the UI
|
|
||||||
if c.doInitialAutoUpdate {
|
|
||||||
log.Infof("start engine by ui, run auto-update check")
|
|
||||||
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
|
||||||
c.doInitialAutoUpdate = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
if runningChan != nil {
|
if runningChan != nil {
|
||||||
|
select {
|
||||||
|
case <-runningChan:
|
||||||
|
default:
|
||||||
close(runningChan)
|
close(runningChan)
|
||||||
runningChan = nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
@@ -567,12 +622,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
|||||||
|
|
||||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
sysInfo.SetFlags(
|
sysInfo.SetFlags(
|
||||||
config.RosenpassEnabled,
|
config.RosenpassEnabled,
|
||||||
@@ -591,12 +640,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.EnableSSHRemotePortForwarding,
|
config.EnableSSHRemotePortForwarding,
|
||||||
config.DisableSSHAuth,
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return loginResp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||||
|
|||||||
73
client/internal/connect_android_default.go
Normal file
73
client/internal/connect_android_default.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
|
||||||
|
// It returns an empty interface list, which means ICE P2P candidates won't be
|
||||||
|
// discovered — connections will fall back to relay. Applications that need P2P
|
||||||
|
// should provide a real implementation via runOnAndroidEmbed that uses
|
||||||
|
// Android's ConnectivityManager to enumerate network interfaces.
|
||||||
|
type noopIFaceDiscover struct{}
|
||||||
|
|
||||||
|
func (noopIFaceDiscover) IFaces() (string, error) {
|
||||||
|
// Return empty JSON array — no local interfaces advertised for ICE.
|
||||||
|
// This is intentional: without Android's ConnectivityManager, we cannot
|
||||||
|
// reliably enumerate interfaces (netlink is restricted on Android 11+).
|
||||||
|
// Relay connections still work; only P2P hole-punching is disabled.
|
||||||
|
return "[]", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// noopNetworkChangeListener is a stub for embed.Client on Android.
|
||||||
|
// Network change events are ignored since the embed client manages its own
|
||||||
|
// reconnection logic via the engine's built-in retry mechanism.
|
||||||
|
type noopNetworkChangeListener struct{}
|
||||||
|
|
||||||
|
func (noopNetworkChangeListener) OnNetworkChanged(string) {
|
||||||
|
// No-op: embed.Client relies on the engine's internal reconnection
|
||||||
|
// logic rather than OS-level network change notifications.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (noopNetworkChangeListener) SetInterfaceIP(string) {
|
||||||
|
// No-op: in netstack mode, the overlay IP is managed by the userspace
|
||||||
|
// network stack, not by OS-level interface configuration.
|
||||||
|
}
|
||||||
|
|
||||||
|
// noopDnsReadyListener is a stub for embed.Client on Android.
|
||||||
|
// DNS readiness notifications are not needed in netstack/embed mode
|
||||||
|
// since system DNS is disabled and DNS resolution happens externally.
|
||||||
|
type noopDnsReadyListener struct{}
|
||||||
|
|
||||||
|
func (noopDnsReadyListener) OnReady() {
|
||||||
|
// No-op: embed.Client does not need DNS readiness notifications.
|
||||||
|
// System DNS is disabled in netstack mode.
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
|
||||||
|
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
||||||
|
var _ dns.ReadyListener = noopDnsReadyListener{}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Wire up the default override so embed.Client.Start() works on Android
|
||||||
|
// with netstack mode. Provides complete no-op stubs for all mobile
|
||||||
|
// dependencies so the engine's existing Android code paths work unchanged.
|
||||||
|
// Applications that need P2P ICE or real DNS should replace this by
|
||||||
|
// setting androidRunOverride before calling Start().
|
||||||
|
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
||||||
|
return c.runOnAndroidEmbed(
|
||||||
|
noopIFaceDiscover{},
|
||||||
|
noopNetworkChangeListener{},
|
||||||
|
[]netip.AddrPort{},
|
||||||
|
noopDnsReadyListener{},
|
||||||
|
runningChan,
|
||||||
|
logPath,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
32
client/internal/connect_android_embed.go
Normal file
32
client/internal/connect_android_embed.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
||||||
|
// so embed.Client.Start() can detect when the engine is ready.
|
||||||
|
// It provides complete MobileDependency so the engine's existing
|
||||||
|
// Android code paths work unchanged.
|
||||||
|
func (c *ConnectClient) runOnAndroidEmbed(
|
||||||
|
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||||
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
|
dnsAddresses []netip.AddrPort,
|
||||||
|
dnsReadyListener dns.ReadyListener,
|
||||||
|
runningChan chan struct{},
|
||||||
|
logPath string,
|
||||||
|
) error {
|
||||||
|
mobileDependency := MobileDependency{
|
||||||
|
IFaceDiscover: iFaceDiscover,
|
||||||
|
NetworkChangeListener: networkChangeListener,
|
||||||
|
HostDNSAddresses: dnsAddresses,
|
||||||
|
DnsReadyListener: dnsReadyListener,
|
||||||
|
}
|
||||||
|
return c.run(mobileDependency, runningChan, logPath)
|
||||||
|
}
|
||||||
60
client/internal/daemonaddr/resolve.go
Normal file
60
client/internal/daemonaddr/resolve.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
//go:build !windows && !ios && !android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var scanDir = "/var/run/netbird"
|
||||||
|
|
||||||
|
// setScanDir overrides the scan directory (used by tests).
|
||||||
|
func setScanDir(dir string) {
|
||||||
|
scanDir = dir
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
|
||||||
|
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
|
||||||
|
// mismatch between the netbird@.service template (which places the socket under
|
||||||
|
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
|
||||||
|
func ResolveUnixDaemonAddr(addr string) string {
|
||||||
|
if !strings.HasPrefix(addr, "unix://") {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
sockPath := strings.TrimPrefix(addr, "unix://")
|
||||||
|
if _, err := os.Stat(sockPath); err == nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(scanDir)
|
||||||
|
if err != nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var found []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(e.Name(), ".sock") {
|
||||||
|
found = append(found, filepath.Join(scanDir, e.Name()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch len(found) {
|
||||||
|
case 1:
|
||||||
|
resolved := "unix://" + found[0]
|
||||||
|
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||||
|
return resolved
|
||||||
|
case 0:
|
||||||
|
return addr
|
||||||
|
default:
|
||||||
|
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
}
|
||||||
8
client/internal/daemonaddr/resolve_stub.go
Normal file
8
client/internal/daemonaddr/resolve_stub.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build windows || ios || android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
|
||||||
|
func ResolveUnixDaemonAddr(addr string) string {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
121
client/internal/daemonaddr/resolve_test.go
Normal file
121
client/internal/daemonaddr/resolve_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
//go:build !windows && !ios && !android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSockFile creates a regular file with a .sock extension.
|
||||||
|
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
|
||||||
|
// sufficient and avoids Unix socket path-length limits on macOS.
|
||||||
|
func createSockFile(t *testing.T, path string) {
|
||||||
|
t.Helper()
|
||||||
|
if err := os.WriteFile(path, nil, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to create test sock file at %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
sock := filepath.Join(tmp, "netbird.sock")
|
||||||
|
createSockFile(t, sock)
|
||||||
|
|
||||||
|
addr := "unix://" + sock
|
||||||
|
got := ResolveUnixDaemonAddr(addr)
|
||||||
|
if got != addr {
|
||||||
|
t.Errorf("expected %s, got %s", addr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
// Default socket does not exist
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
// Create a scan dir with one socket
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
instanceSock := filepath.Join(sd, "main.sock")
|
||||||
|
createSockFile(t, instanceSock)
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
expected := "unix://" + instanceSock
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
createSockFile(t, filepath.Join(sd, "main.sock"))
|
||||||
|
createSockFile(t, filepath.Join(sd, "other.sock"))
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
|
||||||
|
addr := "tcp://127.0.0.1:41731"
|
||||||
|
got := ResolveUnixDaemonAddr(addr)
|
||||||
|
if got != addr {
|
||||||
|
t.Errorf("expected %s, got %s", addr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(filepath.Join(tmp, "nonexistent"))
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"slices"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -25,13 +24,12 @@ import (
|
|||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const readmeContent = `Netbird debug bundle
|
const readmeContent = `Netbird debug bundle
|
||||||
@@ -53,6 +51,8 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
|
|||||||
config.txt: Anonymized configuration information of the NetBird client.
|
config.txt: Anonymized configuration information of the NetBird client.
|
||||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||||
|
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
|
||||||
|
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
|
||||||
mutex.prof: Mutex profiling information.
|
mutex.prof: Mutex profiling information.
|
||||||
goroutine.prof: Goroutine profiling information.
|
goroutine.prof: Goroutine profiling information.
|
||||||
block.prof: Block profiling information.
|
block.prof: Block profiling information.
|
||||||
@@ -219,6 +219,11 @@ const (
|
|||||||
darwinStdoutLogPath = "/var/log/netbird.err.log"
|
darwinStdoutLogPath = "/var/log/netbird.err.log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MetricsExporter is an interface for exporting metrics
|
||||||
|
type MetricsExporter interface {
|
||||||
|
Export(w io.Writer) error
|
||||||
|
}
|
||||||
|
|
||||||
type BundleGenerator struct {
|
type BundleGenerator struct {
|
||||||
anonymizer *anonymize.Anonymizer
|
anonymizer *anonymize.Anonymizer
|
||||||
|
|
||||||
@@ -227,8 +232,10 @@ type BundleGenerator struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
syncResponse *mgmProto.SyncResponse
|
syncResponse *mgmProto.SyncResponse
|
||||||
logPath string
|
logPath string
|
||||||
|
tempDir string
|
||||||
cpuProfile []byte
|
cpuProfile []byte
|
||||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
|
clientMetrics MetricsExporter
|
||||||
|
|
||||||
anonymize bool
|
anonymize bool
|
||||||
includeSystemInfo bool
|
includeSystemInfo bool
|
||||||
@@ -248,8 +255,10 @@ type GeneratorDependencies struct {
|
|||||||
StatusRecorder *peer.Status
|
StatusRecorder *peer.Status
|
||||||
SyncResponse *mgmProto.SyncResponse
|
SyncResponse *mgmProto.SyncResponse
|
||||||
LogPath string
|
LogPath string
|
||||||
|
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||||
CPUProfile []byte
|
CPUProfile []byte
|
||||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
|
ClientMetrics MetricsExporter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||||
@@ -266,8 +275,10 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
statusRecorder: deps.StatusRecorder,
|
statusRecorder: deps.StatusRecorder,
|
||||||
syncResponse: deps.SyncResponse,
|
syncResponse: deps.SyncResponse,
|
||||||
logPath: deps.LogPath,
|
logPath: deps.LogPath,
|
||||||
|
tempDir: deps.TempDir,
|
||||||
cpuProfile: deps.CPUProfile,
|
cpuProfile: deps.CPUProfile,
|
||||||
refreshStatus: deps.RefreshStatus,
|
refreshStatus: deps.RefreshStatus,
|
||||||
|
clientMetrics: deps.ClientMetrics,
|
||||||
|
|
||||||
anonymize: cfg.Anonymize,
|
anonymize: cfg.Anonymize,
|
||||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||||
@@ -277,7 +288,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
|
|
||||||
// Generate creates a debug bundle and returns the location.
|
// Generate creates a debug bundle and returns the location.
|
||||||
func (g *BundleGenerator) Generate() (resp string, err error) {
|
func (g *BundleGenerator) Generate() (resp string, err error) {
|
||||||
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("create zip file: %w", err)
|
return "", fmt.Errorf("create zip file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -351,19 +362,20 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := g.addServiceParams(); err != nil {
|
||||||
|
log.Errorf("failed to add service params to debug bundle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addMetrics(); err != nil {
|
||||||
|
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := g.addWgShow(); err != nil {
|
if err := g.addWgShow(); err != nil {
|
||||||
log.Errorf("failed to add wg show output: %v", err)
|
log.Errorf("failed to add wg show output: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
if err := g.addPlatformLog(); err != nil {
|
||||||
if err := g.addLogfile(); err != nil {
|
log.Errorf("failed to add logs to debug bundle: %v", err)
|
||||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
|
||||||
if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
log.Errorf("failed to add systemd logs as fallback: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
log.Errorf("failed to add systemd logs: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addUpdateLogs(); err != nil {
|
if err := g.addUpdateLogs(); err != nil {
|
||||||
@@ -418,7 +430,10 @@ func (g *BundleGenerator) addStatus() error {
|
|||||||
fullStatus := g.statusRecorder.GetFullStatus()
|
fullStatus := g.statusRecorder.GetFullStatus()
|
||||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
|
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
|
||||||
|
Anonymize: g.anonymize,
|
||||||
|
ProfileName: profName,
|
||||||
|
})
|
||||||
statusOutput := overview.FullDetailSummary()
|
statusOutput := overview.FullDetailSummary()
|
||||||
|
|
||||||
statusReader := strings.NewReader(statusOutput)
|
statusReader := strings.NewReader(statusOutput)
|
||||||
@@ -473,6 +488,90 @@ func (g *BundleGenerator) addConfig() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
serviceParamsFile = "service.json"
|
||||||
|
serviceParamsBundle = "service_params.json"
|
||||||
|
maskedValue = "***"
|
||||||
|
envVarPrefix = "NB_"
|
||||||
|
jsonKeyManagementURL = "management_url"
|
||||||
|
jsonKeyServiceEnv = "service_env_vars"
|
||||||
|
)
|
||||||
|
|
||||||
|
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
|
||||||
|
|
||||||
|
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
|
||||||
|
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
|
||||||
|
func (g *BundleGenerator) addServiceParams() error {
|
||||||
|
path := filepath.Join(configs.StateDir, serviceParamsFile)
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("read service params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var params map[string]any
|
||||||
|
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||||
|
return fmt.Errorf("parse service params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.anonymize {
|
||||||
|
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
|
||||||
|
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
g.sanitizeServiceEnvVars(params)
|
||||||
|
|
||||||
|
sanitizedData, err := json.MarshalIndent(params, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal sanitized service params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
|
||||||
|
return fmt.Errorf("add service params to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
|
||||||
|
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
|
||||||
|
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
|
||||||
|
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
|
||||||
|
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sanitized := make(map[string]any, len(envVars))
|
||||||
|
for k, v := range envVars {
|
||||||
|
val, _ := v.(string)
|
||||||
|
switch {
|
||||||
|
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
|
||||||
|
sanitized[k] = maskedValue
|
||||||
|
case g.anonymize:
|
||||||
|
sanitized[k] = g.anonymizer.AnonymizeString(val)
|
||||||
|
default:
|
||||||
|
sanitized[k] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params[jsonKeyServiceEnv] = sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
// isSensitiveEnvVar returns true for env var names that may contain secrets.
|
||||||
|
func isSensitiveEnvVar(key string) bool {
|
||||||
|
lower := strings.ToLower(key)
|
||||||
|
for _, s := range sensitiveEnvSubstrings {
|
||||||
|
if strings.Contains(lower, s) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||||
|
|
||||||
@@ -744,6 +843,30 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addMetrics() error {
|
||||||
|
if g.clientMetrics == nil {
|
||||||
|
log.Debugf("skipping metrics in debug bundle: no metrics collector")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := g.clientMetrics.Export(&buf); err != nil {
|
||||||
|
return fmt.Errorf("export metrics: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if buf.Len() == 0 {
|
||||||
|
log.Debugf("skipping metrics.txt in debug bundle: no metrics data")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addFileToZip(&buf, "metrics.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add metrics file to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("added metrics to debug bundle")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addLogfile() error {
|
func (g *BundleGenerator) addLogfile() error {
|
||||||
if g.logPath == "" {
|
if g.logPath == "" {
|
||||||
log.Debugf("skipping empty log file in debug bundle")
|
log.Debugf("skipping empty log file in debug bundle")
|
||||||
|
|||||||
41
client/internal/debug/debug_android.go
Normal file
41
client/internal/debug/debug_android.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addPlatformLog() error {
|
||||||
|
cmd := exec.Command("/system/bin/logcat", "-d")
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("logcat stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("start logcat: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var logReader io.Reader = stdout
|
||||||
|
if g.anonymize {
|
||||||
|
var pw *io.PipeWriter
|
||||||
|
logReader, pw = io.Pipe()
|
||||||
|
go anonymizeLog(stdout, pw, g.anonymizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add logcat to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
return fmt.Errorf("wait logcat: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("added logcat output to debug bundle")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
25
client/internal/debug/debug_nonandroid.go
Normal file
25
client/internal/debug/debug_nonandroid.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addPlatformLog() error {
|
||||||
|
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||||
|
if err := g.addLogfile(); err != nil {
|
||||||
|
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||||
|
if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,8 +1,12 @@
|
|||||||
package debug
|
package debug
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"archive/zip"
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -10,6 +14,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsSensitiveEnvVar(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
key string
|
||||||
|
sensitive bool
|
||||||
|
}{
|
||||||
|
{"NB_SETUP_KEY", true},
|
||||||
|
{"NB_API_TOKEN", true},
|
||||||
|
{"NB_CLIENT_SECRET", true},
|
||||||
|
{"NB_PASSWORD", true},
|
||||||
|
{"NB_CREDENTIAL", true},
|
||||||
|
{"NB_LOG_LEVEL", false},
|
||||||
|
{"NB_MANAGEMENT_URL", false},
|
||||||
|
{"NB_HOSTNAME", false},
|
||||||
|
{"HOME", false},
|
||||||
|
{"PATH", false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.key, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
anonymize bool
|
||||||
|
input map[string]any
|
||||||
|
check func(t *testing.T, params map[string]any)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no env vars key",
|
||||||
|
anonymize: false,
|
||||||
|
input: map[string]any{"management_url": "https://mgmt.example.com"},
|
||||||
|
check: func(t *testing.T, params map[string]any) {
|
||||||
|
t.Helper()
|
||||||
|
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
|
||||||
|
_, ok := params[jsonKeyServiceEnv]
|
||||||
|
assert.False(t, ok, "service_env_vars should not be added")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-NB vars are masked",
|
||||||
|
anonymize: false,
|
||||||
|
input: map[string]any{
|
||||||
|
jsonKeyServiceEnv: map[string]any{
|
||||||
|
"HOME": "/root",
|
||||||
|
"PATH": "/usr/bin",
|
||||||
|
"NB_LOG_LEVEL": "debug",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: func(t *testing.T, params map[string]any) {
|
||||||
|
t.Helper()
|
||||||
|
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||||
|
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
|
||||||
|
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
|
||||||
|
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sensitive NB vars are masked",
|
||||||
|
anonymize: false,
|
||||||
|
input: map[string]any{
|
||||||
|
jsonKeyServiceEnv: map[string]any{
|
||||||
|
"NB_SETUP_KEY": "abc123",
|
||||||
|
"NB_API_TOKEN": "tok_xyz",
|
||||||
|
"NB_LOG_LEVEL": "info",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: func(t *testing.T, params map[string]any) {
|
||||||
|
t.Helper()
|
||||||
|
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||||
|
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
|
||||||
|
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
|
||||||
|
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "safe NB vars anonymized when anonymize is true",
|
||||||
|
anonymize: true,
|
||||||
|
input: map[string]any{
|
||||||
|
jsonKeyServiceEnv: map[string]any{
|
||||||
|
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
|
||||||
|
"NB_LOG_LEVEL": "debug",
|
||||||
|
"NB_SETUP_KEY": "secret",
|
||||||
|
"SOME_OTHER": "val",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
check: func(t *testing.T, params map[string]any) {
|
||||||
|
t.Helper()
|
||||||
|
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||||
|
// Safe NB_ values should be anonymized (not the original, not masked)
|
||||||
|
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
|
||||||
|
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
|
||||||
|
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
|
||||||
|
|
||||||
|
logVal := env["NB_LOG_LEVEL"].(string)
|
||||||
|
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
|
||||||
|
|
||||||
|
// Sensitive and non-NB_ still masked
|
||||||
|
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
|
||||||
|
assert.Equal(t, maskedValue, env["SOME_OTHER"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||||
|
g := &BundleGenerator{
|
||||||
|
anonymize: tt.anonymize,
|
||||||
|
anonymizer: anonymizer,
|
||||||
|
}
|
||||||
|
g.sanitizeServiceEnvVars(tt.input)
|
||||||
|
tt.check(t, tt.input)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddServiceParams(t *testing.T) {
|
||||||
|
t.Run("missing service.json returns nil", func(t *testing.T) {
|
||||||
|
g := &BundleGenerator{
|
||||||
|
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||||
|
}
|
||||||
|
|
||||||
|
origStateDir := configs.StateDir
|
||||||
|
configs.StateDir = t.TempDir()
|
||||||
|
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||||
|
|
||||||
|
err := g.addServiceParams()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
origStateDir := configs.StateDir
|
||||||
|
configs.StateDir = dir
|
||||||
|
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||||
|
|
||||||
|
input := map[string]any{
|
||||||
|
jsonKeyManagementURL: "https://api.example.com:443",
|
||||||
|
jsonKeyServiceEnv: map[string]any{
|
||||||
|
"NB_LOG_LEVEL": "trace",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
zw := zip.NewWriter(&buf)
|
||||||
|
|
||||||
|
g := &BundleGenerator{
|
||||||
|
anonymize: true,
|
||||||
|
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||||
|
archive: zw,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, g.addServiceParams())
|
||||||
|
require.NoError(t, zw.Close())
|
||||||
|
|
||||||
|
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, zr.File, 1)
|
||||||
|
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
|
||||||
|
|
||||||
|
rc, err := zr.File[0].Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
var result map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||||
|
|
||||||
|
mgmt := result[jsonKeyManagementURL].(string)
|
||||||
|
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
|
||||||
|
assert.NotEmpty(t, mgmt)
|
||||||
|
|
||||||
|
env := result[jsonKeyServiceEnv].(map[string]any)
|
||||||
|
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
origStateDir := configs.StateDir
|
||||||
|
configs.StateDir = dir
|
||||||
|
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||||
|
|
||||||
|
input := map[string]any{
|
||||||
|
jsonKeyManagementURL: "https://api.example.com:443",
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
zw := zip.NewWriter(&buf)
|
||||||
|
|
||||||
|
g := &BundleGenerator{
|
||||||
|
anonymize: false,
|
||||||
|
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||||
|
archive: zw,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, g.addServiceParams())
|
||||||
|
require.NoError(t, zw.Close())
|
||||||
|
|
||||||
|
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rc, err := zr.File[0].Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
var result map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||||
|
|
||||||
|
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to check if IP is in CGNAT range
|
// Helper function to check if IP is in CGNAT range
|
||||||
func isInCGNATRange(ip net.IP) bool {
|
func isInCGNATRange(ip net.IP) bool {
|
||||||
cgnat := net.IPNet{
|
cgnat := net.IPNet{
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ package debug
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Skip("Skipping upload test on docker ci")
|
t.Skip("Skipping upload test on docker ci")
|
||||||
}
|
}
|
||||||
testDir := t.TempDir()
|
testDir := t.TempDir()
|
||||||
testURL := "http://localhost:8080"
|
addr := reserveLoopbackPort(t)
|
||||||
|
testURL := "http://" + addr
|
||||||
t.Setenv("SERVER_URL", testURL)
|
t.Setenv("SERVER_URL", testURL)
|
||||||
|
t.Setenv("SERVER_ADDRESS", addr)
|
||||||
t.Setenv("STORE_DIR", testDir)
|
t.Setenv("STORE_DIR", testDir)
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Errorf("Failed to stop server: %v", err)
|
t.Errorf("Failed to stop server: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
waitForServer(t, addr)
|
||||||
|
|
||||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||||
fileContent := []byte("test file content")
|
fileContent := []byte("test file content")
|
||||||
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, fileContent, createdFileContent)
|
require.Equal(t, fileContent, createdFileContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
||||||
|
// address, then releases it so the server under test can rebind. The close/
|
||||||
|
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
||||||
|
// it's essentially never contended in practice.
|
||||||
|
func reserveLoopbackPort(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
addr := l.Addr().String()
|
||||||
|
require.NoError(t, l.Close())
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForServer(t *testing.T, addr string) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
_ = c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatalf("server did not start listening on %s in time", addr)
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultResolvConfPath = "/etc/resolv.conf"
|
defaultResolvConfPath = "/etc/resolv.conf"
|
||||||
|
nsswitchConfPath = "/etc/nsswitch.conf"
|
||||||
)
|
)
|
||||||
|
|
||||||
type resolvConf struct {
|
type resolvConf struct {
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -73,6 +76,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
w.response = m
|
w.response = m
|
||||||
|
if m.MsgHdr.Truncated {
|
||||||
|
w.SetMeta("truncated", "true")
|
||||||
|
}
|
||||||
return w.ResponseWriter.WriteMsg(m)
|
return w.ResponseWriter.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,16 +195,26 @@ func (c *HandlerChain) logHandlers() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
c.dispatch(w, r, math.MaxInt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatch routes a DNS request through the chain, skipping handlers with
|
||||||
|
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
||||||
|
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
requestID := resutil.GenerateRequestID()
|
requestID := resutil.GenerateRequestID()
|
||||||
logger := log.WithFields(log.Fields{
|
fields := log.Fields{
|
||||||
"request_id": requestID,
|
"request_id": requestID,
|
||||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||||
})
|
}
|
||||||
|
if addr := w.RemoteAddr(); addr != nil {
|
||||||
|
fields["client"] = addr.String()
|
||||||
|
}
|
||||||
|
logger := log.WithFields(fields)
|
||||||
|
|
||||||
question := r.Question[0]
|
question := r.Question[0]
|
||||||
qname := strings.ToLower(question.Name)
|
qname := strings.ToLower(question.Name)
|
||||||
@@ -209,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// Try handlers in priority order
|
// Try handlers in priority order
|
||||||
for _, entry := range handlers {
|
for _, entry := range handlers {
|
||||||
|
if entry.Priority > maxPriority {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !c.isHandlerMatch(qname, entry) {
|
if !c.isHandlerMatch(qname, entry) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -261,9 +280,58 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
|||||||
meta += " " + k + "=" + v
|
meta += " " + k + "=" + v
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
|
||||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||||
meta, time.Since(startTime))
|
cw.response.Len(), meta, time.Since(startTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
||||||
|
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
||||||
|
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
||||||
|
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
||||||
|
// (bounded by the invoked handler's internal timeout).
|
||||||
|
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty question")
|
||||||
|
}
|
||||||
|
|
||||||
|
base := &internalResponseWriter{}
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
c.dispatch(base, r, maxPriority)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
||||||
|
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
||||||
|
strings.ToLower(r.Question[0].Name), maxPriority)
|
||||||
|
}
|
||||||
|
return base.response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
||||||
|
// priority ≤ maxPriority.
|
||||||
|
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, h := range c.handlers {
|
||||||
|
if h.Pattern == "." && h.Priority <= maxPriority {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
@@ -284,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
||||||
|
type internalResponseWriter struct {
|
||||||
|
response *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
||||||
|
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
|
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
|
|
||||||
|
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
||||||
|
// still surface their answer to ResolveInternal.
|
||||||
|
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(p); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
w.response = msg
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *internalResponseWriter) Close() error { return nil }
|
||||||
|
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
||||||
|
|
||||||
|
// TsigTimersOnly is part of dns.ResponseWriter.
|
||||||
|
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
||||||
|
// no-op: in-process queries carry no TSIG state.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack is part of dns.ResponseWriter.
|
||||||
|
func (w *internalResponseWriter) Hijack() {
|
||||||
|
// no-op: in-process queries have no underlying connection to hand off.
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package dns_test
|
package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
||||||
|
// which handler ResolveInternal dispatches to.
|
||||||
|
type answeringHandler struct {
|
||||||
|
name string
|
||||||
|
ip string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
resp.Answer = []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP(h.ip).To4(),
|
||||||
|
}}
|
||||||
|
_ = w.WriteMsg(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *answeringHandler) String() string { return h.name }
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||||
|
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||||
|
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, resp)
|
||||||
|
assert.Equal(t, 1, len(resp.Answer))
|
||||||
|
a, ok := resp.Answer[0].(*dns.A)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||||
|
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||||
|
assert.Error(t, err, "no handler at or below maxPriority should error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
||||||
|
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
||||||
|
type rawWriteHandler struct {
|
||||||
|
ip string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
resp.Answer = []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP(h.ip).To4(),
|
||||||
|
}}
|
||||||
|
packed, err := resp.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = w.Write(packed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
a, ok := resp.Answer[0].(*dns.A)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
||||||
|
type hangingHandler struct {
|
||||||
|
block chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
<-h.block
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
_ = w.WriteMsg(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hangingHandler) String() string { return "hangingHandler" }
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
h := &hangingHandler{block: make(chan struct{})}
|
||||||
|
defer close(h.block)
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
||||||
|
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
||||||
|
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
||||||
|
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
||||||
|
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
||||||
|
|
||||||
|
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||||
|
|
||||||
|
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
||||||
|
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
||||||
|
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
||||||
|
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
||||||
|
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
@@ -22,6 +24,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
||||||
|
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
|
||||||
globalIPv4State = "State:/Network/Global/IPv4"
|
globalIPv4State = "State:/Network/Global/IPv4"
|
||||||
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
||||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||||
@@ -35,6 +38,14 @@ const (
|
|||||||
searchSuffix = "Search"
|
searchSuffix = "Search"
|
||||||
matchSuffix = "Match"
|
matchSuffix = "Match"
|
||||||
localSuffix = "Local"
|
localSuffix = "Local"
|
||||||
|
|
||||||
|
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
|
||||||
|
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
|
||||||
|
maxDomainsPerResolverEntry = 50
|
||||||
|
|
||||||
|
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
|
||||||
|
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
|
||||||
|
maxDomainBytesPerResolverEntry = 1500
|
||||||
)
|
)
|
||||||
|
|
||||||
type systemConfigurator struct {
|
type systemConfigurator struct {
|
||||||
@@ -84,29 +95,24 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
if err := s.removeKeysContaining(matchSuffix); err != nil {
|
||||||
var err error
|
log.Warnf("failed to remove old match keys: %v", err)
|
||||||
if len(matchDomains) != 0 {
|
|
||||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
|
||||||
} else {
|
|
||||||
log.Infof("removing match domains from the system")
|
|
||||||
err = s.removeKeyFromSystemConfig(matchKey)
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if len(matchDomains) != 0 {
|
||||||
|
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
|
||||||
return fmt.Errorf("add match domains: %w", err)
|
return fmt.Errorf("add match domains: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
s.updateState(stateManager)
|
s.updateState(stateManager)
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
if err := s.removeKeysContaining(searchSuffix); err != nil {
|
||||||
if len(searchDomains) != 0 {
|
log.Warnf("failed to remove old search keys: %v", err)
|
||||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
|
||||||
} else {
|
|
||||||
log.Infof("removing search domains from the system")
|
|
||||||
err = s.removeKeyFromSystemConfig(searchKey)
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if len(searchDomains) != 0 {
|
||||||
|
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
|
||||||
return fmt.Errorf("add search domains: %w", err)
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
s.updateState(stateManager)
|
s.updateState(stateManager)
|
||||||
|
|
||||||
if err := s.flushDNSCache(); err != nil {
|
if err := s.flushDNSCache(); err != nil {
|
||||||
@@ -149,8 +155,7 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
|||||||
|
|
||||||
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||||
if len(s.createdKeys) == 0 {
|
if len(s.createdKeys) == 0 {
|
||||||
// return defaults for startup calls
|
return s.discoverExistingKeys()
|
||||||
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
keys := make([]string, 0, len(s.createdKeys))
|
keys := make([]string, 0, len(s.createdKeys))
|
||||||
@@ -160,6 +165,47 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
|||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
|
||||||
|
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
|
||||||
|
func (s *systemConfigurator) discoverExistingKeys() []string {
|
||||||
|
dnsKeys, err := getSystemDNSKeys()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get system DNS keys: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys []string
|
||||||
|
|
||||||
|
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
|
||||||
|
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
|
||||||
|
if strings.Contains(dnsKeys, key) {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, suffix := range []string{searchSuffix, matchSuffix} {
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||||
|
if !strings.Contains(dnsKeys, key) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemDNSKeys gets all DNS keys
|
||||||
|
func getSystemDNSKeys() (string, error) {
|
||||||
|
command := "list .*DNS\nquit\n"
|
||||||
|
out, err := runSystemConfigCommand(command)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(out), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||||
line := buildRemoveKeyOperation(key)
|
line := buildRemoveKeyOperation(key)
|
||||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||||
@@ -184,12 +230,11 @@ func (s *systemConfigurator) addLocalDNS() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.addSearchDomains(
|
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
|
||||||
localKey,
|
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
|
||||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
return fmt.Errorf("add local dns state: %w", err)
|
||||||
); err != nil {
|
|
||||||
return fmt.Errorf("add search domains: %w", err)
|
|
||||||
}
|
}
|
||||||
|
s.createdKeys[localKey] = struct{}{}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -280,28 +325,77 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
|||||||
return slices.Clone(s.origNameservers)
|
return slices.Clone(s.origNameservers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
|
||||||
err := s.addDNSState(key, domains, ip, port, true)
|
func splitDomainsIntoBatches(domains []string) [][]string {
|
||||||
if err != nil {
|
if len(domains) == 0 {
|
||||||
return fmt.Errorf("add dns state: %w", err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
var batches [][]string
|
||||||
|
var current []string
|
||||||
|
currentBytes := 0
|
||||||
|
|
||||||
s.createdKeys[key] = struct{}{}
|
for _, d := range domains {
|
||||||
|
domainLen := len(d)
|
||||||
|
newBytes := currentBytes + domainLen
|
||||||
|
if currentBytes > 0 {
|
||||||
|
newBytes++ // space separator
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
|
||||||
|
batches = append(batches, current)
|
||||||
|
current = nil
|
||||||
|
currentBytes = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
current = append(current, d)
|
||||||
|
if currentBytes > 0 {
|
||||||
|
currentBytes += 1 + domainLen
|
||||||
|
} else {
|
||||||
|
currentBytes = domainLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(current) > 0 {
|
||||||
|
batches = append(batches, current)
|
||||||
|
}
|
||||||
|
|
||||||
|
return batches
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
|
// removeKeysContaining removes all created keys that contain the given substring.
|
||||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
|
||||||
if err != nil {
|
var toRemove []string
|
||||||
return fmt.Errorf("add dns state: %w", err)
|
for key := range s.createdKeys {
|
||||||
|
if strings.Contains(key, suffix) {
|
||||||
|
toRemove = append(toRemove, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var multiErr *multierror.Error
|
||||||
|
for _, key := range toRemove {
|
||||||
|
if err := s.removeKeyFromSystemConfig(key); err != nil {
|
||||||
|
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(multiErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
|
||||||
|
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
|
||||||
|
batches := splitDomainsIntoBatches(domains)
|
||||||
|
|
||||||
|
for i, batch := range batches {
|
||||||
|
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||||
|
domainsStr := strings.Join(batch, " ")
|
||||||
|
|
||||||
|
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
|
||||||
|
return fmt.Errorf("add dns state for batch %d: %w", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
|
||||||
|
|
||||||
s.createdKeys[key] = struct{}{}
|
s.createdKeys[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -364,7 +458,6 @@ func (s *systemConfigurator) flushDNSCache() error {
|
|||||||
if out, err := cmd.CombinedOutput(); err != nil {
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("flushed DNS cache")
|
log.Info("flushed DNS cache")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,10 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -49,17 +52,22 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, sm.PersistState(context.Background()))
|
require.NoError(t, sm.PersistState(context.Background()))
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
|
||||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
// Collect all created keys for cleanup verification
|
||||||
|
createdKeys := make([]string, 0, len(configurator.createdKeys))
|
||||||
|
for key := range configurator.createdKeys {
|
||||||
|
createdKeys = append(createdKeys, key)
|
||||||
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for _, key := range createdKeys {
|
||||||
_ = removeTestDNSKey(key)
|
_ = removeTestDNSKey(key)
|
||||||
}
|
}
|
||||||
|
_ = removeTestDNSKey(localKey)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for _, key := range createdKeys {
|
||||||
exists, err := checkDNSKeyExists(key)
|
exists, err := checkDNSKeyExists(key)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if exists {
|
if exists {
|
||||||
@@ -83,13 +91,223 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
|||||||
err = shutdownState.Cleanup()
|
err = shutdownState.Cleanup()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for _, key := range createdKeys {
|
||||||
exists, err := checkDNSKeyExists(key)
|
exists, err := checkDNSKeyExists(key)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
|
||||||
|
func generateShortDomains(count int) []string {
|
||||||
|
domains := make([]string, 0, count)
|
||||||
|
for i := range count {
|
||||||
|
label := ""
|
||||||
|
n := i
|
||||||
|
for {
|
||||||
|
label = string(rune('a'+n%26)) + label
|
||||||
|
n = n/26 - 1
|
||||||
|
if n < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
domains = append(domains, label+".com")
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
|
||||||
|
func generateLongDomains(count int) []string {
|
||||||
|
domains := make([]string, 0, count)
|
||||||
|
for i := range count {
|
||||||
|
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
|
||||||
|
func readDomainsFromKey(t *testing.T, key string) []string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
cmd := exec.Command(scutilPath)
|
||||||
|
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
|
||||||
|
out, err := cmd.Output()
|
||||||
|
require.NoError(t, err, "scutil show should succeed")
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
inArray := false
|
||||||
|
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
|
||||||
|
inArray = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if inArray {
|
||||||
|
if line == "}" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// lines look like: "0 : a.com"
|
||||||
|
parts := strings.SplitN(line, " : ", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
domains = append(domains, parts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NoError(t, scanner.Err())
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitDomainsIntoBatches(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
domains []string
|
||||||
|
expectedCount int
|
||||||
|
checkAllPresent bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
domains: nil,
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "under_limit",
|
||||||
|
domains: generateShortDomains(10),
|
||||||
|
expectedCount: 1,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "at_element_limit",
|
||||||
|
domains: generateShortDomains(50),
|
||||||
|
expectedCount: 1,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "over_element_limit",
|
||||||
|
domains: generateShortDomains(51),
|
||||||
|
expectedCount: 2,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "triple_element_limit",
|
||||||
|
domains: generateShortDomains(150),
|
||||||
|
expectedCount: 3,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "long_domains_hit_byte_limit",
|
||||||
|
domains: generateLongDomains(50),
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500_short_domains",
|
||||||
|
domains: generateShortDomains(500),
|
||||||
|
expectedCount: 10,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500_long_domains",
|
||||||
|
domains: generateLongDomains(500),
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
batches := splitDomainsIntoBatches(tc.domains)
|
||||||
|
|
||||||
|
if tc.expectedCount > 0 {
|
||||||
|
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify each batch respects limits
|
||||||
|
for i, batch := range batches {
|
||||||
|
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
|
||||||
|
"batch %d exceeds element limit", i)
|
||||||
|
|
||||||
|
totalBytes := 0
|
||||||
|
for j, d := range batch {
|
||||||
|
if j > 0 {
|
||||||
|
totalBytes++
|
||||||
|
}
|
||||||
|
totalBytes += len(d)
|
||||||
|
}
|
||||||
|
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
|
||||||
|
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.checkAllPresent {
|
||||||
|
var all []string
|
||||||
|
for _, batch := range batches {
|
||||||
|
all = append(all, batch...)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.domains, all, "all domains should be present in order")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
|
||||||
|
// and verifies all domains are readable across multiple scutil keys.
|
||||||
|
func TestMatchDomainBatching(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping scutil integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
count int
|
||||||
|
generator func(int) []string
|
||||||
|
}{
|
||||||
|
{"short_10", 10, generateShortDomains},
|
||||||
|
{"short_50", 50, generateShortDomains},
|
||||||
|
{"short_100", 100, generateShortDomains},
|
||||||
|
{"short_200", 200, generateShortDomains},
|
||||||
|
{"short_500", 500, generateShortDomains},
|
||||||
|
{"long_10", 10, generateLongDomains},
|
||||||
|
{"long_50", 50, generateLongDomains},
|
||||||
|
{"long_100", 100, generateLongDomains},
|
||||||
|
{"long_200", 200, generateLongDomains},
|
||||||
|
{"long_500", 500, generateLongDomains},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
for key := range configurator.createdKeys {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
domains := tc.generator(tc.count)
|
||||||
|
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
batches := splitDomainsIntoBatches(domains)
|
||||||
|
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
|
||||||
|
|
||||||
|
// Read back all domains from all batched keys
|
||||||
|
var got []string
|
||||||
|
for i := range batches {
|
||||||
|
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
|
||||||
|
exists, err := checkDNSKeyExists(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, exists, "key %s should exist", key)
|
||||||
|
|
||||||
|
got = append(got, readDomainsFromKey(t, key)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
|
||||||
|
assert.Equal(t, tc.count, len(got), "all domains should be readable")
|
||||||
|
assert.Equal(t, domains, got, "domains should match in order")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkDNSKeyExists(key string) (bool, error) {
|
func checkDNSKeyExists(key string) (bool, error) {
|
||||||
cmd := exec.Command(scutilPath)
|
cmd := exec.Command(scutilPath)
|
||||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||||
@@ -158,15 +376,15 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man
|
|||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
|
||||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
|
||||||
|
|
||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
_ = sm.Stop(context.Background())
|
_ = sm.Stop(context.Background())
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for key := range configurator.createdKeys {
|
||||||
_ = removeTestDNSKey(key)
|
_ = removeTestDNSKey(key)
|
||||||
}
|
}
|
||||||
|
// Also clean up old-format keys and local key in case they exist
|
||||||
|
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
|
||||||
|
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
|
||||||
|
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
|
||||||
}
|
}
|
||||||
|
|
||||||
return configurator, sm, cleanup
|
return configurator, sm, cleanup
|
||||||
|
|||||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, reason, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("System DNS manager discovered: %s", osManager)
|
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
||||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOSDNSManagerType() (osManagerType, error) {
|
func getOSDNSManagerType() (osManagerType, string, error) {
|
||||||
|
resolved := isSystemdResolvedRunning()
|
||||||
|
nss := isLibnssResolveUsed()
|
||||||
|
stub := checkStub()
|
||||||
|
|
||||||
|
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
||||||
|
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
||||||
|
// that go through nss-resolve, and in foreign mode they can loop back
|
||||||
|
// through resolved as an upstream.
|
||||||
|
if resolved && (nss || stub) {
|
||||||
|
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, reason, rejected, err := scanResolvConfHeader()
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", err
|
||||||
|
}
|
||||||
|
if reason != "" {
|
||||||
|
return mgr, reason, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
||||||
|
if len(rejected) > 0 {
|
||||||
|
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
||||||
|
}
|
||||||
|
return fileManager, fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
||||||
|
// matching manager. If reason is empty the caller should pick file mode and
|
||||||
|
// use rejected for diagnostics.
|
||||||
|
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||||
file, err := os.Open(defaultResolvConfPath)
|
file, err := os.Open(defaultResolvConfPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := file.Close(); err != nil {
|
if cerr := file.Close(); cerr != nil {
|
||||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var rejected []string
|
||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
text := scanner.Text()
|
text := scanner.Text()
|
||||||
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if text[0] != '#' {
|
if text[0] != '#' {
|
||||||
return fileManager, nil
|
break
|
||||||
}
|
}
|
||||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
||||||
return netbirdManager, nil
|
return mgr, reason, nil, nil
|
||||||
}
|
} else if rej != "" {
|
||||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
rejected = append(rejected, rej)
|
||||||
return networkManager, nil
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
|
||||||
if checkStub() {
|
|
||||||
return systemdManager, nil
|
|
||||||
} else {
|
|
||||||
return fileManager, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "resolvconf") {
|
|
||||||
if isSystemdResolveConfMode() {
|
|
||||||
return systemdManager, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return resolvConfManager, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||||
return 0, fmt.Errorf("scan: %w", err)
|
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
||||||
}
|
}
|
||||||
|
return 0, "", rejected, nil
|
||||||
return fileManager, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
// matchResolvConfHeader inspects a single comment line. Returns either a
|
||||||
|
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
||||||
|
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
||||||
|
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||||
|
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "NetworkManager") {
|
||||||
|
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||||
|
return networkManager, "NetworkManager header + supported version on dbus", ""
|
||||||
|
}
|
||||||
|
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "resolvconf") {
|
||||||
|
if isSystemdResolveConfMode() {
|
||||||
|
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
||||||
|
}
|
||||||
|
return resolvConfManager, "resolvconf header detected", ""
|
||||||
|
}
|
||||||
|
return 0, "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
||||||
|
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
||||||
|
// into file mode while resolved is active.
|
||||||
func checkStub() bool {
|
func checkStub() bool {
|
||||||
rConf, err := parseDefaultResolvConf()
|
rConf, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse resolv conf: %s", err)
|
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,3 +178,36 @@ func checkStub() bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
||||||
|
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
||||||
|
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
||||||
|
func isLibnssResolveUsed() bool {
|
||||||
|
bs, err := os.ReadFile(nsswitchConfPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return parseNsswitchResolveAhead(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseNsswitchResolveAhead(data []byte) bool {
|
||||||
|
for _, line := range strings.Split(string(data), "\n") {
|
||||||
|
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||||
|
line = line[:i]
|
||||||
|
}
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) < 2 || fields[0] != "hosts:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, module := range fields[1:] {
|
||||||
|
switch module {
|
||||||
|
case "dns":
|
||||||
|
return false
|
||||||
|
case "resolve":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
76
client/internal/dns/host_unix_test.go
Normal file
76
client/internal/dns/host_unix_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestParseNsswitchResolveAhead(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "resolve before dns with action token",
|
||||||
|
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dns before resolve",
|
||||||
|
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "debian default with only dns",
|
||||||
|
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "neither resolve nor dns",
|
||||||
|
in: "hosts: files myhostname\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no hosts line",
|
||||||
|
in: "passwd: files systemd\ngroup: files systemd\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
in: "",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "comments and blank lines ignored",
|
||||||
|
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trailing inline comment",
|
||||||
|
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hosts token must be the first field",
|
||||||
|
in: " hosts: resolve dns\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other db line mentioning resolve is ignored",
|
||||||
|
in: "networks: resolve\nhosts: dns\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only resolve, no dns",
|
||||||
|
in: "hosts: files resolve\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
||||||
|
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user