mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 14:06:41 +00:00
Compare commits
345 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af973b2440 | ||
|
|
dd9bff9a4b | ||
|
|
1be5e454ba | ||
|
|
4a25a0d413 | ||
|
|
7fc3c7088e | ||
|
|
1869e70894 | ||
|
|
79783cc3dc | ||
|
|
584298e3bd | ||
|
|
f683afa647 | ||
|
|
ba2631d388 | ||
|
|
6ae4e2b691 | ||
|
|
51eee9dcf5 | ||
|
|
660e9e0e35 | ||
|
|
4ef6089053 | ||
|
|
c4e297cc96 | ||
|
|
e3f5497176 | ||
|
|
6a5dcc01a6 | ||
|
|
18b6d3bb0f | ||
|
|
ccbfdc5265 | ||
|
|
ab04537278 | ||
|
|
29c36c9837 | ||
|
|
c47e9bf547 | ||
|
|
abb682c935 | ||
|
|
79e8a4a8bb | ||
|
|
f2e81c024a | ||
|
|
6d10650e70 | ||
|
|
a81c683c66 | ||
|
|
25cb50901e | ||
|
|
a8e0844758 | ||
|
|
8b9ee6f26a | ||
|
|
82e8fcc3a7 | ||
|
|
e2b7777ba7 | ||
|
|
4e4d1a39f6 | ||
|
|
17dc1b0be1 | ||
|
|
a06436eeab | ||
|
|
a83cc2a3a3 | ||
|
|
d56537d0fd | ||
|
|
31bb483e40 | ||
|
|
cd91ae6e3a | ||
|
|
a9ec1e61d3 | ||
|
|
a13010c4af | ||
|
|
cfac3cdd53 | ||
|
|
5ecba61718 | ||
|
|
2ea12ce258 | ||
|
|
0b46289136 | ||
|
|
71044165d0 | ||
|
|
eafd816159 | ||
|
|
e1a687407e | ||
|
|
bd8031651e | ||
|
|
a63439543d | ||
|
|
90cd6e7f6e | ||
|
|
ea4a63c9b3 | ||
|
|
e047330ffd | ||
|
|
9dcc0796a6 | ||
|
|
4b6999e06a | ||
|
|
69952ee5c5 | ||
|
|
3710880ce0 | ||
|
|
17b75bf58f | ||
|
|
3ba1714524 | ||
|
|
3470da76fc | ||
|
|
c86df2c041 | ||
|
|
0e8315b149 | ||
|
|
2ab9790588 | ||
|
|
1ecb97306f | ||
|
|
15e96a779c | ||
|
|
dada0cc124 | ||
|
|
9c0b4fcd5f | ||
|
|
8a788ef238 | ||
|
|
20e0c18845 | ||
|
|
5b637bb4ca | ||
|
|
c565a46a6f | ||
|
|
7b7eae617a | ||
|
|
1ed27fec1a | ||
|
|
83edde3449 | ||
|
|
1b43f029a9 | ||
|
|
aeb908b68c | ||
|
|
f08b17c7bd | ||
|
|
cce8742490 | ||
|
|
c56696bab1 | ||
|
|
7bb004cf50 | ||
|
|
28910ce188 | ||
|
|
f8dc134210 | ||
|
|
148f5fde23 | ||
|
|
b76259bc31 | ||
|
|
88cc57bcef | ||
|
|
385c64c364 | ||
|
|
0b05497c25 | ||
|
|
4e3e824276 | ||
|
|
effc1a31ac | ||
|
|
03051a37fe | ||
|
|
8cf2a28b6f | ||
|
|
9f3422de1b | ||
|
|
e6d0e9bb13 | ||
|
|
da0ad21fd4 | ||
|
|
2940f16f19 | ||
|
|
44c8d871c2 | ||
|
|
96a88057f9 | ||
|
|
d96fe6391e | ||
|
|
fe7fd31955 | ||
|
|
86b19f243e | ||
|
|
d0940d03c4 | ||
|
|
5a51753dbf | ||
|
|
70be82d68a | ||
|
|
dde79bb2dc | ||
|
|
3822b1a065 | ||
|
|
8b68f00f59 | ||
|
|
fe197f0a0b | ||
|
|
675c934ce1 | ||
|
|
708c761fa6 | ||
|
|
78dc6508a4 | ||
|
|
7f6c824122 | ||
|
|
9ba3569573 | ||
|
|
fd38f4cc59 | ||
|
|
c5d5fcedd9 | ||
|
|
13c0a082b5 | ||
|
|
48962d4b65 | ||
|
|
c469707986 | ||
|
|
13c40f6b2c | ||
|
|
6071be0d08 | ||
|
|
4b269782ea | ||
|
|
518bf0e36a | ||
|
|
c80bb9740a | ||
|
|
3ceef1ef74 | ||
|
|
acb0b4a9a5 | ||
|
|
29aa68ecf7 | ||
|
|
50a97b19d1 | ||
|
|
229ce7504f | ||
|
|
b4f3619aff | ||
|
|
e77a4fbd66 | ||
|
|
f8f368a981 | ||
|
|
153b986100 | ||
|
|
1c47c0981c | ||
|
|
defd85e118 | ||
|
|
ec1085f5f7 | ||
|
|
d13cc179e8 | ||
|
|
a39e6d4f2b | ||
|
|
4875835024 | ||
|
|
f5a74c36f8 | ||
|
|
c71828f5a1 | ||
|
|
dc83af6c2e | ||
|
|
35544e1081 | ||
|
|
2ddb4a5645 | ||
|
|
c25fb02f1e | ||
|
|
28583c9507 | ||
|
|
ba41602e4b | ||
|
|
3e24a77625 | ||
|
|
4b8b281d5b | ||
|
|
a07a714d93 | ||
|
|
58ce93f6c3 | ||
|
|
293e507000 | ||
|
|
c948208493 | ||
|
|
2106734aa4 | ||
|
|
51162d6be6 | ||
|
|
45ef6e5279 | ||
|
|
3b2ffe006a | ||
|
|
a497f0873f | ||
|
|
6e4ec246ef | ||
|
|
7270b840cf | ||
|
|
fb007e09a9 | ||
|
|
9ce6450351 | ||
|
|
672fff0ad9 | ||
|
|
22474d92ef | ||
|
|
0e4a657700 | ||
|
|
e24ee0e68b | ||
|
|
cea9ab0932 | ||
|
|
229dc6afce | ||
|
|
e2fe7d53f8 | ||
|
|
e8f1fb507c | ||
|
|
7e410cde28 | ||
|
|
afe0d338be | ||
|
|
a18b367e60 | ||
|
|
91e44e112e | ||
|
|
a38d1ef8a8 | ||
|
|
a32e91de24 | ||
|
|
92b551fa4b | ||
|
|
53c1fa117a | ||
|
|
50525aaf8d | ||
|
|
d8ced86d19 | ||
|
|
2718d15825 | ||
|
|
fff234bdd5 | ||
|
|
0802673048 | ||
|
|
d54b7e3f14 | ||
|
|
650084132b | ||
|
|
534631fb27 | ||
|
|
9d34c818d7 | ||
|
|
2436a5be15 | ||
|
|
430f2bf7fa | ||
|
|
16362f285d | ||
|
|
34c7f89804 | ||
|
|
ead8fab70a | ||
|
|
50008f3c12 | ||
|
|
24b5122cc1 | ||
|
|
9099b246dc | ||
|
|
30ff3c06eb | ||
|
|
d02ca20c06 | ||
|
|
7afe842a95 | ||
|
|
47d628af73 | ||
|
|
0f1e51f391 | ||
|
|
5d6024ac59 | ||
|
|
6c7ee31330 | ||
|
|
b38357875e | ||
|
|
c230c7be28 | ||
|
|
d7cd746cc9 | ||
|
|
7941479994 | ||
|
|
e3623fd756 | ||
|
|
68c2744ebe | ||
|
|
a9d8d0e5c6 | ||
|
|
7f94fbc1e4 | ||
|
|
542d7e5d61 | ||
|
|
f93f73f541 | ||
|
|
930bf7e0f2 | ||
|
|
d7345c7dbd | ||
|
|
3e2cb70d58 | ||
|
|
d4c5292e8f | ||
|
|
45047343c4 | ||
|
|
c09fb312e8 | ||
|
|
8dfb4b2b20 | ||
|
|
6b17cb08c0 | ||
|
|
aa866493aa | ||
|
|
7b28137cf6 | ||
|
|
a8383f5612 | ||
|
|
2fc385155e | ||
|
|
b7271b77b6 | ||
|
|
1ef6b7ada6 | ||
|
|
ea454d0528 | ||
|
|
a6670ccab3 | ||
|
|
f226e8f7f3 | ||
|
|
75890ca5a6 | ||
|
|
e6cf631dbc | ||
|
|
b87f90c211 | ||
|
|
3e0cefa3dc | ||
|
|
2fe3359ae8 | ||
|
|
0aa8f07be3 | ||
|
|
36d47a7331 | ||
|
|
10fa5acb0b | ||
|
|
e3a679609f | ||
|
|
7fc09f8ed1 | ||
|
|
079843602c | ||
|
|
70bf22c354 | ||
|
|
3d891cfa97 | ||
|
|
78e3bb374a | ||
|
|
a61c7ca1ee | ||
|
|
7696ba2e36 | ||
|
|
235877c379 | ||
|
|
befab0f8d1 | ||
|
|
914d080a57 | ||
|
|
a274b4b38f | ||
|
|
ce3c585514 | ||
|
|
963d8abad5 | ||
|
|
38eb56381f | ||
|
|
43b3822090 | ||
|
|
b0fb370c4d | ||
|
|
99328ee76f | ||
|
|
36fc3ea253 | ||
|
|
a7979259f3 | ||
|
|
ea6fa72bc0 | ||
|
|
f9adde6b1d | ||
|
|
ba25586646 | ||
|
|
952ab63e8d | ||
|
|
5e84f802ed | ||
|
|
f40b0ff820 | ||
|
|
95a4840374 | ||
|
|
27424170e4 | ||
|
|
b7a04dc511 | ||
|
|
a8ace6f64a | ||
|
|
3fa1073f49 | ||
|
|
76d86c10ff | ||
|
|
2d34c6c8b2 | ||
|
|
a7f3477bdd | ||
|
|
af0a72d296 | ||
|
|
d1e836e760 | ||
|
|
8dd45c4ca2 | ||
|
|
9db009058b | ||
|
|
29c01deb05 | ||
|
|
7224d9824d | ||
|
|
8afc28fdff | ||
|
|
4ba2fb7b53 | ||
|
|
2e6076923d | ||
|
|
4c001dc751 | ||
|
|
2b8e240752 | ||
|
|
bee490713d | ||
|
|
1cb7fd94ab | ||
|
|
dc9a547950 | ||
|
|
2be0933246 | ||
|
|
c0b1cd6bde | ||
|
|
dd00289f8e | ||
|
|
b23a02ee97 | ||
|
|
80f726cfea | ||
|
|
aa8828186f | ||
|
|
0a990d196d | ||
|
|
00e8050949 | ||
|
|
18ee4c93fb | ||
|
|
8fa2da00b6 | ||
|
|
b851cd73c9 | ||
|
|
4c19d7ef6d | ||
|
|
cbecb9a0ce | ||
|
|
4fc8db08ba | ||
|
|
7ca46e0a75 | ||
|
|
a4ea5143af | ||
|
|
e9257b6423 | ||
|
|
3c9d3a1d2c | ||
|
|
b426f14190 | ||
|
|
d48acfba39 | ||
|
|
35b48cd8e5 | ||
|
|
15bca53309 | ||
|
|
898b599db5 | ||
|
|
c07bba18bb | ||
|
|
4c24d3b808 | ||
|
|
ad4ab3d04f | ||
|
|
e21153fae1 | ||
|
|
41c3360e23 | ||
|
|
1960d32443 | ||
|
|
74b83b3303 | ||
|
|
c2c3470868 | ||
|
|
2bda3dc3cc | ||
|
|
52573c8664 | ||
|
|
0d3c34e23f | ||
|
|
891df5c74b | ||
|
|
6f3f162d2b | ||
|
|
f6fa5fd02c | ||
|
|
8f4e0ba29e | ||
|
|
32b7dc7c43 | ||
|
|
78d2ebe1de | ||
|
|
014f8eb4e5 | ||
|
|
cd42803291 | ||
|
|
5c5b303994 | ||
|
|
968873da22 | ||
|
|
e2772f918b | ||
|
|
cdf6a31b67 | ||
|
|
2ce72065a7 | ||
|
|
b3e7aafb58 | ||
|
|
dd610ad850 | ||
|
|
b4b0a832e7 | ||
|
|
79963c1f66 | ||
|
|
5f95282161 | ||
|
|
1cca54f9d5 | ||
|
|
219df22919 | ||
|
|
337d9934fd | ||
|
|
bba4d72a78 | ||
|
|
fb5c793126 | ||
|
|
1821dbb672 | ||
|
|
4fda6fe031 | ||
|
|
f286f0faf6 | ||
|
|
cba3d607bf | ||
|
|
5ca12834a1 |
@@ -1,9 +1,9 @@
|
||||
.gitignore
|
||||
.dockerignore
|
||||
olm
|
||||
*.json
|
||||
README.md
|
||||
Makefile
|
||||
public/
|
||||
LICENSE
|
||||
CONTRIBUTING.md
|
||||
CONTRIBUTING.md
|
||||
bin/
|
||||
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Summary
|
||||
description: A clear and concise summary of the requested feature.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Motivation
|
||||
description: |
|
||||
Why is this feature important?
|
||||
Explain the problem this feature would solve or what use case it would enable.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Proposed Solution
|
||||
description: |
|
||||
How would you like to see this feature implemented?
|
||||
Provide as much detail as possible about the desired behavior, configuration, or changes.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
description: Describe any alternative solutions or workarounds you've thought about.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context, mockups, or screenshots about the feature request here.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before submitting, please:
|
||||
- Check if there is an existing issue for this feature.
|
||||
- Clearly explain the benefit and use case.
|
||||
- Be as specific as possible to help contributors evaluate and implement.
|
||||
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
name: Bug Report
|
||||
description: Create a bug report
|
||||
labels: []
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Describe the Bug
|
||||
description: A clear and concise description of what the bug is.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Environment
|
||||
description: Please fill out the relevant details below for your environment.
|
||||
value: |
|
||||
- OS Type & Version: (e.g., Ubuntu 22.04)
|
||||
- Pangolin Version:
|
||||
- Gerbil Version:
|
||||
- Traefik Version:
|
||||
- Newt Version:
|
||||
- Olm Version: (if applicable)
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: To Reproduce
|
||||
description: |
|
||||
Steps to reproduce the behavior, please provide a clear description of how to reproduce the issue, based on the linked minimal reproduction. Screenshots can be provided in the issue body below.
|
||||
|
||||
If using code blocks, make sure syntax highlighting is correct and double-check that the rendered preview is not broken.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: A clear and concise description of what you expected to happen.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before posting the issue go through the steps you've written down to make sure the steps provided are detailed and clear.
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Contributors should be able to follow the steps provided in order to reproduce the bug.
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Need help or have questions?
|
||||
url: https://github.com/orgs/fosrl/discussions
|
||||
about: Ask questions, get help, and discuss with other community members
|
||||
- name: Request a Feature
|
||||
url: https://github.com/orgs/fosrl/discussions/new?category=feature-requests
|
||||
about: Feature requests should be opened as discussions so others can upvote and comment
|
||||
30
.github/dependabot.yml
vendored
Normal file
30
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "gomod"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
groups:
|
||||
patch-updates:
|
||||
update-types:
|
||||
- "patch"
|
||||
minor-updates:
|
||||
update-types:
|
||||
- "minor"
|
||||
|
||||
- package-ecosystem: "docker"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
groups:
|
||||
patch-updates:
|
||||
update-types:
|
||||
- "patch"
|
||||
minor-updates:
|
||||
update-types:
|
||||
- "minor"
|
||||
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
637
.github/workflows/cicd.yml
vendored
637
.github/workflows/cicd.yml
vendored
@@ -1,44 +1,615 @@
|
||||
name: CI/CD Pipeline
|
||||
|
||||
permissions:
|
||||
contents: write # gh-release
|
||||
packages: write # GHCR push
|
||||
id-token: write # Keyless-Signatures & Attestations
|
||||
attestations: write # actions/attest-build-provenance
|
||||
security-events: write # upload-sarif
|
||||
actions: read
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
push:
|
||||
tags:
|
||||
- "[0-9]+.[0-9]+.[0-9]+"
|
||||
- "[0-9]+.[0-9]+.[0-9]+-rc.[0-9]+"
|
||||
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: "SemVer version to release (e.g., 1.2.3, no leading 'v')"
|
||||
required: true
|
||||
type: string
|
||||
publish_latest:
|
||||
description: "Also publish the 'latest' image tag"
|
||||
required: true
|
||||
type: boolean
|
||||
default: false
|
||||
publish_minor:
|
||||
description: "Also publish the 'major.minor' image tag (e.g., 1.2)"
|
||||
required: true
|
||||
type: boolean
|
||||
default: false
|
||||
target_branch:
|
||||
description: "Branch to tag"
|
||||
required: false
|
||||
default: "main"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event_name == 'workflow_dispatch' && github.event.inputs.version || github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-latest
|
||||
prepare:
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
name: Prepare release (create tag)
|
||||
runs-on: ubuntu-24.04
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: Validate version input
|
||||
shell: bash
|
||||
env:
|
||||
INPUT_VERSION: ${{ inputs.version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if ! [[ "$INPUT_VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-rc\.[0-9]+)?$ ]]; then
|
||||
echo "Invalid version: $INPUT_VERSION (expected X.Y.Z or X.Y.Z-rc.N)" >&2
|
||||
exit 1
|
||||
fi
|
||||
- name: Create and push tag
|
||||
shell: bash
|
||||
env:
|
||||
TARGET_BRANCH: ${{ inputs.target_branch }}
|
||||
VERSION: ${{ inputs.version }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
git fetch --prune origin
|
||||
git checkout "$TARGET_BRANCH"
|
||||
git pull --ff-only origin "$TARGET_BRANCH"
|
||||
if git rev-parse -q --verify "refs/tags/$VERSION" >/dev/null; then
|
||||
echo "Tag $VERSION already exists" >&2
|
||||
exit 1
|
||||
fi
|
||||
git tag -a "$VERSION" -m "Release $VERSION"
|
||||
git push origin "refs/tags/$VERSION"
|
||||
release:
|
||||
if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && github.actor != 'github-actions[bot]') }}
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-24.04
|
||||
timeout-minutes: 120
|
||||
env:
|
||||
DOCKERHUB_IMAGE: docker.io/fosrl/${{ github.event.repository.name }}
|
||||
GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||
|
||||
- name: Extract tag name
|
||||
id: get-tag
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: 1.23.1
|
||||
- name: Capture created timestamp
|
||||
run: echo "IMAGE_CREATED=$(date -u +%Y-%m-%dT%H:%M:%SZ)" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
# - name: Update version in main.go
|
||||
# run: |
|
||||
# TAG=${{ env.TAG }}
|
||||
# if [ -f main.go ]; then
|
||||
# sed -i 's/Olm version replaceme/Olm version '"$TAG"'/' main.go
|
||||
# echo "Updated main.go with version $TAG"
|
||||
# else
|
||||
# echo "main.go not found"
|
||||
# fi
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
|
||||
|
||||
- name: Build binaries
|
||||
run: |
|
||||
make go-build-release
|
||||
- name: Set up 1.2.0 Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: docker.io
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Normalize image names to lowercase
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "GHCR_IMAGE=${GHCR_IMAGE,,}" >> "$GITHUB_ENV"
|
||||
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
|
||||
shell: bash
|
||||
|
||||
- name: Extract tag name
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
INPUT_VERSION: ${{ inputs.version }}
|
||||
run: |
|
||||
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
|
||||
echo "TAG=${INPUT_VERSION}" >> $GITHUB_ENV
|
||||
else
|
||||
echo "TAG=${{ github.ref_name }}" >> $GITHUB_ENV
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Validate pushed tag format (no leading 'v')
|
||||
if: ${{ github.event_name == 'push' }}
|
||||
shell: bash
|
||||
env:
|
||||
TAG_GOT: ${{ env.TAG }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if [[ "$TAG_GOT" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-rc\.[0-9]+)?$ ]]; then
|
||||
echo "Tag OK: $TAG_GOT"
|
||||
exit 0
|
||||
fi
|
||||
echo "ERROR: Tag '$TAG_GOT' is not allowed. Use 'X.Y.Z' or 'X.Y.Z-rc.N' (no leading 'v')." >&2
|
||||
exit 1
|
||||
- name: Wait for tag to be visible (dispatch only)
|
||||
if: ${{ github.event_name == 'workflow_dispatch' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for i in {1..90}; do
|
||||
if git ls-remote --tags origin "refs/tags/${TAG}" | grep -qE "refs/tags/${TAG}$"; then
|
||||
echo "Tag ${TAG} is visible on origin"; exit 0
|
||||
fi
|
||||
echo "Tag not yet visible, retrying... ($i/90)"
|
||||
sleep 2
|
||||
done
|
||||
echo "Tag ${TAG} not visible after waiting"; exit 1
|
||||
shell: bash
|
||||
|
||||
- name: Update version in main.go
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
if [ -f main.go ]; then
|
||||
sed -i 's/version_replaceme/'"$TAG"'/' main.go
|
||||
echo "Updated main.go with version $TAG"
|
||||
else
|
||||
echo "main.go not found"
|
||||
fi
|
||||
|
||||
- name: Ensure repository is at the tagged commit (dispatch only)
|
||||
if: ${{ github.event_name == 'workflow_dispatch' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
git fetch --tags --force
|
||||
git checkout "refs/tags/${TAG}"
|
||||
echo "Checked out $(git rev-parse --short HEAD) for tag ${TAG}"
|
||||
shell: bash
|
||||
|
||||
- name: Detect release candidate (rc)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if [[ "${TAG}" =~ ^[0-9]+\.[0-9]+\.[0-9]+-rc\.[0-9]+$ ]]; then
|
||||
echo "IS_RC=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "IS_RC=false" >> $GITHUB_ENV
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Resolve publish-latest flag
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
PL_INPUT: ${{ inputs.publish_latest }}
|
||||
PL_VAR: ${{ vars.PUBLISH_LATEST }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
val="false"
|
||||
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
|
||||
if [ "${PL_INPUT}" = "true" ]; then val="true"; fi
|
||||
else
|
||||
if [ "${PL_VAR}" = "true" ]; then val="true"; fi
|
||||
fi
|
||||
echo "PUBLISH_LATEST=$val" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
- name: Resolve publish-minor flag
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
PM_INPUT: ${{ inputs.publish_minor }}
|
||||
PM_VAR: ${{ vars.PUBLISH_MINOR }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
val="false"
|
||||
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
|
||||
if [ "${PM_INPUT}" = "true" ]; then val="true"; fi
|
||||
else
|
||||
if [ "${PM_VAR}" = "true" ]; then val="true"; fi
|
||||
fi
|
||||
echo "PUBLISH_MINOR=$val" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
- name: Cache Go modules
|
||||
if: ${{ hashFiles('**/go.sum') != '' }}
|
||||
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Go vet & test
|
||||
if: ${{ hashFiles('**/go.mod') != '' }}
|
||||
run: |
|
||||
go version
|
||||
go vet ./...
|
||||
go test ./... -race -covermode=atomic
|
||||
shell: bash
|
||||
|
||||
- name: Resolve license fallback
|
||||
run: echo "IMAGE_LICENSE=${{ github.event.repository.license.spdx_id || 'NOASSERTION' }}" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
- name: Resolve registries list (GHCR always, Docker Hub only if creds)
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
images="${GHCR_IMAGE}"
|
||||
if [ -n "${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}" ] && [ -n "${{ secrets.DOCKER_HUB_USERNAME }}" ]; then
|
||||
images="${images}\n${DOCKERHUB_IMAGE}"
|
||||
fi
|
||||
{
|
||||
echo 'IMAGE_LIST<<EOF'
|
||||
echo -e "$images"
|
||||
echo 'EOF'
|
||||
} >> "$GITHUB_ENV"
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5.10.0
|
||||
with:
|
||||
images: ${{ env.IMAGE_LIST }}
|
||||
tags: |
|
||||
type=semver,pattern={{version}},value=${{ env.TAG }}
|
||||
type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }}
|
||||
type=raw,value=latest,enable=${{ env.IS_RC != 'true' }}
|
||||
flavor: |
|
||||
latest=false
|
||||
labels: |
|
||||
org.opencontainers.image.title=${{ github.event.repository.name }}
|
||||
org.opencontainers.image.version=${{ env.TAG }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
org.opencontainers.image.source=${{ github.event.repository.html_url }}
|
||||
org.opencontainers.image.url=${{ github.event.repository.html_url }}
|
||||
org.opencontainers.image.documentation=${{ github.event.repository.html_url }}
|
||||
org.opencontainers.image.description=${{ github.event.repository.description }}
|
||||
org.opencontainers.image.licenses=${{ env.IMAGE_LICENSE }}
|
||||
org.opencontainers.image.created=${{ env.IMAGE_CREATED }}
|
||||
org.opencontainers.image.ref.name=${{ env.TAG }}
|
||||
org.opencontainers.image.authors=${{ github.repository_owner }}
|
||||
- name: Echo build config (non-secret)
|
||||
shell: bash
|
||||
env:
|
||||
IMAGE_TITLE: ${{ github.event.repository.name }}
|
||||
IMAGE_VERSION: ${{ env.TAG }}
|
||||
IMAGE_REVISION: ${{ github.sha }}
|
||||
IMAGE_SOURCE_URL: ${{ github.event.repository.html_url }}
|
||||
IMAGE_URL: ${{ github.event.repository.html_url }}
|
||||
IMAGE_DESCRIPTION: ${{ github.event.repository.description }}
|
||||
IMAGE_LICENSE: ${{ env.IMAGE_LICENSE }}
|
||||
DOCKERHUB_IMAGE: ${{ env.DOCKERHUB_IMAGE }}
|
||||
GHCR_IMAGE: ${{ env.GHCR_IMAGE }}
|
||||
DOCKER_HUB_USER: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
REPO: ${{ github.repository }}
|
||||
OWNER: ${{ github.repository_owner }}
|
||||
WORKFLOW_REF: ${{ github.workflow_ref }}
|
||||
REF: ${{ github.ref }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "=== OCI Label Values ==="
|
||||
echo "org.opencontainers.image.title=${IMAGE_TITLE}"
|
||||
echo "org.opencontainers.image.version=${IMAGE_VERSION}"
|
||||
echo "org.opencontainers.image.revision=${IMAGE_REVISION}"
|
||||
echo "org.opencontainers.image.source=${IMAGE_SOURCE_URL}"
|
||||
echo "org.opencontainers.image.url=${IMAGE_URL}"
|
||||
echo "org.opencontainers.image.description=${IMAGE_DESCRIPTION}"
|
||||
echo "org.opencontainers.image.licenses=${IMAGE_LICENSE}"
|
||||
echo
|
||||
echo "=== Images ==="
|
||||
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE}"
|
||||
echo "GHCR_IMAGE=${GHCR_IMAGE}"
|
||||
echo "DOCKER_HUB_USERNAME=${DOCKER_HUB_USER}"
|
||||
echo
|
||||
echo "=== GitHub Kontext ==="
|
||||
echo "repository=${REPO}"
|
||||
echo "owner=${OWNER}"
|
||||
echo "workflow_ref=${WORKFLOW_REF}"
|
||||
echo "ref=${REF}"
|
||||
echo "ref_name=${REF_NAME}"
|
||||
echo "run_url=${RUN_URL}"
|
||||
echo
|
||||
echo "=== docker/metadata-action outputs (Tags/Labels), raw ==="
|
||||
echo "::group::tags"
|
||||
echo "${{ steps.meta.outputs.tags }}"
|
||||
echo "::endgroup::"
|
||||
echo "::group::labels"
|
||||
echo "${{ steps.meta.outputs.labels }}"
|
||||
echo "::endgroup::"
|
||||
- name: Build and push (Docker Hub + GHCR)
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64,linux/arm/v7
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha,scope=${{ github.repository }}
|
||||
cache-to: type=gha,mode=max,scope=${{ github.repository }}
|
||||
provenance: mode=max
|
||||
sbom: true
|
||||
|
||||
- name: Compute image digest refs
|
||||
run: |
|
||||
echo "DIGEST=${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
|
||||
echo "GHCR_REF=$GHCR_IMAGE@${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
|
||||
echo "DH_REF=$DOCKERHUB_IMAGE@${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
|
||||
echo "Built digest: ${{ steps.build.outputs.digest }}"
|
||||
shell: bash
|
||||
|
||||
- name: Attest build provenance (GHCR)
|
||||
id: attest-ghcr
|
||||
uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
|
||||
with:
|
||||
subject-name: ${{ env.GHCR_IMAGE }}
|
||||
subject-digest: ${{ steps.build.outputs.digest }}
|
||||
push-to-registry: true
|
||||
show-summary: true
|
||||
|
||||
- name: Attest build provenance (Docker Hub)
|
||||
continue-on-error: true
|
||||
id: attest-dh
|
||||
uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
|
||||
with:
|
||||
subject-name: index.docker.io/fosrl/${{ github.event.repository.name }}
|
||||
subject-digest: ${{ steps.build.outputs.digest }}
|
||||
push-to-registry: true
|
||||
show-summary: true
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||
with:
|
||||
cosign-release: 'v3.0.2'
|
||||
|
||||
- name: Sanity check cosign private key
|
||||
env:
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
cosign public-key --key env://COSIGN_PRIVATE_KEY >/dev/null
|
||||
shell: bash
|
||||
|
||||
- name: Sign GHCR image (digest) with key (recursive)
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Signing ${GHCR_REF} (digest) recursively with provided key"
|
||||
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${GHCR_REF}"
|
||||
echo "Waiting 30 seconds for signatures to propagate..."
|
||||
shell: bash
|
||||
|
||||
- name: Generate SBOM (SPDX JSON)
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.1
|
||||
with:
|
||||
image-ref: ${{ env.GHCR_IMAGE }}@${{ steps.build.outputs.digest }}
|
||||
format: spdx-json
|
||||
output: sbom.spdx.json
|
||||
|
||||
- name: Validate SBOM JSON
|
||||
run: jq -e . sbom.spdx.json >/dev/null
|
||||
shell: bash
|
||||
|
||||
- name: Minify SBOM JSON (optional hardening)
|
||||
run: jq -c . sbom.spdx.json > sbom.min.json && mv sbom.min.json sbom.spdx.json
|
||||
shell: bash
|
||||
|
||||
- name: Create SBOM attestation (GHCR, private key)
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
cosign attest \
|
||||
--key env://COSIGN_PRIVATE_KEY \
|
||||
--type spdxjson \
|
||||
--predicate sbom.spdx.json \
|
||||
"${GHCR_REF}"
|
||||
shell: bash
|
||||
|
||||
- name: Create SBOM attestation (Docker Hub, private key)
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
cosign attest \
|
||||
--key env://COSIGN_PRIVATE_KEY \
|
||||
--type spdxjson \
|
||||
--predicate sbom.spdx.json \
|
||||
"${DH_REF}"
|
||||
shell: bash
|
||||
|
||||
- name: Keyless sign & verify GHCR digest (OIDC)
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
WORKFLOW_REF: ${{ github.workflow_ref }} # owner/repo/.github/workflows/<file>@refs/tags/<tag>
|
||||
ISSUER: https://token.actions.githubusercontent.com
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Keyless signing ${GHCR_REF}"
|
||||
cosign sign --rekor-url https://rekor.sigstore.dev --recursive "${GHCR_REF}"
|
||||
echo "Verify keyless (OIDC) signature policy on ${GHCR_REF}"
|
||||
cosign verify \
|
||||
--certificate-oidc-issuer "${ISSUER}" \
|
||||
--certificate-identity "https://github.com/${WORKFLOW_REF}" \
|
||||
"${GHCR_REF}" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Sign Docker Hub image (digest) with key (recursive)
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Signing ${DH_REF} (digest) recursively with provided key (Docker media types fallback)"
|
||||
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${DH_REF}"
|
||||
shell: bash
|
||||
|
||||
- name: Keyless sign & verify Docker Hub digest (OIDC)
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_YES: "true"
|
||||
ISSUER: https://token.actions.githubusercontent.com
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Keyless signing ${DH_REF} (force public-good Rekor)"
|
||||
cosign sign --rekor-url https://rekor.sigstore.dev --recursive "${DH_REF}"
|
||||
echo "Keyless verify via Rekor (strict identity)"
|
||||
if ! cosign verify \
|
||||
--rekor-url https://rekor.sigstore.dev \
|
||||
--certificate-oidc-issuer "${ISSUER}" \
|
||||
--certificate-identity "https://github.com/${{ github.workflow_ref }}" \
|
||||
"${DH_REF}" -o text; then
|
||||
echo "Rekor verify failed — retry offline bundle verify (no Rekor)"
|
||||
if ! cosign verify \
|
||||
--offline \
|
||||
--certificate-oidc-issuer "${ISSUER}" \
|
||||
--certificate-identity "https://github.com/${{ github.workflow_ref }}" \
|
||||
"${DH_REF}" -o text; then
|
||||
echo "Offline bundle verify failed — ignore tlog (TEMP for debugging)"
|
||||
cosign verify \
|
||||
--insecure-ignore-tlog=true \
|
||||
--certificate-oidc-issuer "${ISSUER}" \
|
||||
--certificate-identity "https://github.com/${{ github.workflow_ref }}" \
|
||||
"${DH_REF}" -o text || true
|
||||
fi
|
||||
fi
|
||||
- name: Verify signature (public key) GHCR digest + tag
|
||||
env:
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG_VAR="${TAG}"
|
||||
echo "Verifying (digest) ${GHCR_REF}"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "$GHCR_REF" -o text
|
||||
echo "Verifying (tag) $GHCR_IMAGE:$TAG_VAR"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "$GHCR_IMAGE:$TAG_VAR" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Verify SBOM attestation (GHCR)
|
||||
env:
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
run: cosign verify-attestation --key env://COSIGN_PUBLIC_KEY --type spdxjson "$GHCR_REF" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Verify SLSA provenance (GHCR)
|
||||
env:
|
||||
ISSUER: https://token.actions.githubusercontent.com
|
||||
WFREF: ${{ github.workflow_ref }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
# (optional) show which predicate types are present to aid debugging
|
||||
cosign download attestation "$GHCR_REF" \
|
||||
| jq -r '.payload | @base64d | fromjson | .predicateType' | sort -u || true
|
||||
# Verify the SLSA v1 provenance attestation (predicate URL)
|
||||
cosign verify-attestation \
|
||||
--type 'https://slsa.dev/provenance/v1' \
|
||||
--certificate-oidc-issuer "$ISSUER" \
|
||||
--certificate-identity "https://github.com/${WFREF}" \
|
||||
--rekor-url https://rekor.sigstore.dev \
|
||||
"$GHCR_REF" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Verify signature (public key) Docker Hub digest
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Verifying (digest) ${DH_REF} with Docker media types"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "${DH_REF}" -o text
|
||||
shell: bash
|
||||
|
||||
- name: Verify signature (public key) Docker Hub tag
|
||||
continue-on-error: true
|
||||
env:
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
COSIGN_DOCKER_MEDIA_TYPES: "1"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Verifying (tag) $DOCKERHUB_IMAGE:$TAG with Docker media types"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "$DOCKERHUB_IMAGE:$TAG" -o text
|
||||
shell: bash
|
||||
|
||||
# - name: Trivy scan (GHCR image)
|
||||
# id: trivy
|
||||
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.1
|
||||
# with:
|
||||
# image-ref: ${{ env.GHCR_IMAGE }}@${{ steps.build.outputs.digest }}
|
||||
# format: sarif
|
||||
# output: trivy-ghcr.sarif
|
||||
# ignore-unfixed: true
|
||||
# vuln-type: os,library
|
||||
# severity: CRITICAL,HIGH
|
||||
# exit-code: ${{ (vars.TRIVY_FAIL || '0') }}
|
||||
|
||||
# - name: Upload SARIF
|
||||
# if: ${{ always() && hashFiles('trivy-ghcr.sarif') != '' }}
|
||||
# uses: github/codeql-action/upload-sarif@fdbfb4d2750291e159f0156def62b853c2798ca2 # v4.31.5
|
||||
# with:
|
||||
# sarif_file: trivy-ghcr.sarif
|
||||
# category: Image Vulnerability Scan
|
||||
|
||||
- name: Build binaries
|
||||
env:
|
||||
CGO_ENABLED: "0"
|
||||
GOFLAGS: "-trimpath"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG_VAR="${TAG}"
|
||||
make go-build-release tag=$TAG_VAR
|
||||
shell: bash
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: softprops/action-gh-release@a06a81a03ee405af7f2048a818ed3f03bbf83c7b # v2.5.0
|
||||
with:
|
||||
tag_name: ${{ env.TAG }}
|
||||
generate_release_notes: true
|
||||
prerelease: ${{ env.IS_RC == 'true' }}
|
||||
files: |
|
||||
bin/*
|
||||
fail_on_unmatched_files: true
|
||||
draft: true
|
||||
body: |
|
||||
## Container Images
|
||||
- GHCR: `${{ env.GHCR_REF }}`
|
||||
- Docker Hub: `${{ env.DH_REF || 'N/A' }}`
|
||||
**Digest:** `${{ steps.build.outputs.digest }}`
|
||||
|
||||
132
.github/workflows/mirror.yaml
vendored
Normal file
132
.github/workflows/mirror.yaml
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
name: Mirror & Sign (Docker Hub to GHCR)
|
||||
|
||||
on:
|
||||
workflow_dispatch: {}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write # for keyless OIDC
|
||||
|
||||
env:
|
||||
SOURCE_IMAGE: docker.io/fosrl/olm
|
||||
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||
|
||||
jobs:
|
||||
mirror-and-dual-sign:
|
||||
runs-on: amd64-runner
|
||||
steps:
|
||||
- name: Install skopeo + jq
|
||||
run: |
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y skopeo jq
|
||||
skopeo --version
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||
|
||||
- name: Input check
|
||||
run: |
|
||||
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
|
||||
echo "Source : ${SOURCE_IMAGE}"
|
||||
echo "Target : ${DEST_IMAGE}"
|
||||
|
||||
# Auth for skopeo (containers-auth)
|
||||
- name: Skopeo login to GHCR
|
||||
run: |
|
||||
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
# Auth for cosign (docker-config)
|
||||
- name: Docker login to GHCR (for cosign)
|
||||
run: |
|
||||
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
|
||||
|
||||
- name: List source tags
|
||||
run: |
|
||||
set -euo pipefail
|
||||
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
|
||||
| jq -r '.Tags[]' | sort -u > src-tags.txt
|
||||
echo "Found source tags: $(wc -l < src-tags.txt)"
|
||||
head -n 20 src-tags.txt || true
|
||||
|
||||
- name: List destination tags (skip existing)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
|
||||
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
|
||||
else
|
||||
: > dst-tags.txt
|
||||
fi
|
||||
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
|
||||
|
||||
- name: Mirror, dual-sign, and verify
|
||||
env:
|
||||
# keyless
|
||||
COSIGN_YES: "true"
|
||||
# key-based
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
# verify
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
copied=0; skipped=0; v_ok=0; errs=0
|
||||
|
||||
issuer="https://token.actions.githubusercontent.com"
|
||||
id_regex="^https://github.com/${{ github.repository }}/.+"
|
||||
|
||||
while read -r tag; do
|
||||
[ -z "$tag" ] && continue
|
||||
|
||||
if grep -Fxq "$tag" dst-tags.txt; then
|
||||
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
|
||||
skipped=$((skipped+1))
|
||||
continue
|
||||
fi
|
||||
|
||||
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
|
||||
if ! skopeo copy --all --retry-times 3 \
|
||||
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
|
||||
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
|
||||
errs=$((errs+1)); continue
|
||||
fi
|
||||
copied=$((copied+1))
|
||||
|
||||
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
|
||||
ref="${DEST_IMAGE}@${digest}"
|
||||
|
||||
echo "==> cosign sign (keyless) --recursive ${ref}"
|
||||
if ! cosign sign --recursive "${ref}"; then
|
||||
echo "::warning title=Keyless sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign sign (key) --recursive ${ref}"
|
||||
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
|
||||
echo "::warning title=Key sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (public key) ${ref}"
|
||||
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
|
||||
echo "::warning title=Verify(pubkey) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (keyless policy) ${ref}"
|
||||
if ! cosign verify \
|
||||
--certificate-oidc-issuer "${issuer}" \
|
||||
--certificate-identity-regexp "${id_regex}" \
|
||||
"${ref}" -o text; then
|
||||
echo "::warning title=Verify(keyless) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
else
|
||||
v_ok=$((v_ok+1))
|
||||
fi
|
||||
done < src-tags.txt
|
||||
|
||||
echo "---- Summary ----"
|
||||
echo "Copied : $copied"
|
||||
echo "Skipped : $skipped"
|
||||
echo "Verified OK : $v_ok"
|
||||
echo "Errors : $errs"
|
||||
37
.github/workflows/stale-bot.yml
vendored
Normal file
37
.github/workflows/stale-bot.yml
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
name: Mark and Close Stale Issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
workflow_dispatch: # Allow manual trigger
|
||||
|
||||
permissions:
|
||||
contents: write # only for delete-branch option
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
|
||||
with:
|
||||
days-before-stale: 14
|
||||
days-before-close: 14
|
||||
stale-issue-message: 'This issue has been automatically marked as stale due to 14 days of inactivity. It will be closed in 14 days if no further activity occurs.'
|
||||
close-issue-message: 'This issue has been automatically closed due to inactivity. If you believe this is still relevant, please open a new issue with up-to-date information.'
|
||||
stale-issue-label: 'stale'
|
||||
|
||||
exempt-issue-labels: 'needs investigating, networking, new feature, reverse proxy, bug, api, authentication, documentation, enhancement, help wanted, good first issue, question'
|
||||
|
||||
exempt-all-issue-assignees: true
|
||||
|
||||
only-labels: ''
|
||||
exempt-pr-labels: ''
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
|
||||
operations-per-run: 100
|
||||
remove-stale-when-updated: true
|
||||
delete-branch: false
|
||||
enable-statistics: true
|
||||
42
.github/workflows/test.yml
vendored
Normal file
42
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: Run Tests
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build-go:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
|
||||
with:
|
||||
go-version: 1.25
|
||||
|
||||
- name: Build binaries
|
||||
run: make go-build-release
|
||||
|
||||
build-docker:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
|
||||
|
||||
- name: Set up 1.2.0 Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
||||
|
||||
- name: Build Docker image
|
||||
run: make docker-build-dev
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,2 @@
|
||||
olm
|
||||
.DS_Store
|
||||
bin/
|
||||
1
.go-version
Normal file
1
.go-version
Normal file
@@ -0,0 +1 @@
|
||||
1.25
|
||||
393
API.md
Normal file
393
API.md
Normal file
@@ -0,0 +1,393 @@
|
||||
## API
|
||||
|
||||
Olm can be controlled with an embedded API server when using `--enable-api`. This allows you to start it as a daemon and trigger it with the following endpoints. The API can listen on either a TCP address or a Unix socket/Windows named pipe.
|
||||
|
||||
### Socket vs TCP
|
||||
|
||||
When `--enable-api` is used, Olm can listen on a TCP address when configured via `--http-addr` (like `:9452`). Alternatively, Olm can listen on a Unix socket (Linux/macOS) or Windows named pipe for local-only communication with better security when using `--socket-path` (like `/var/run/olm.sock`).
|
||||
|
||||
**Unix Socket (Linux/macOS):**
|
||||
- Socket path example: `/var/run/olm/olm.sock`
|
||||
- The directory is created automatically if it doesn't exist
|
||||
- Socket permissions are set to `0666` to allow access
|
||||
- Existing socket files are automatically removed on startup
|
||||
- Socket file is cleaned up when Olm stops
|
||||
|
||||
**Windows Named Pipe:**
|
||||
- Pipe path example: `\\.\pipe\olm`
|
||||
- If the path doesn't start with `\`, it's automatically prefixed with `\\.\pipe\`
|
||||
- Security descriptor grants full access to Everyone and the current owner
|
||||
- Named pipes are automatically cleaned up by Windows
|
||||
|
||||
**Connecting to the Socket:**
|
||||
|
||||
```bash
|
||||
# Linux/macOS - using curl with Unix socket
|
||||
curl --unix-socket /var/run/olm/olm.sock http://localhost/status
|
||||
|
||||
---
|
||||
|
||||
### POST /connect
|
||||
Initiates a new connection request to a Pangolin server.
|
||||
|
||||
**Request Body:**
|
||||
```json
|
||||
{
|
||||
"id": "string",
|
||||
"secret": "string",
|
||||
"endpoint": "string",
|
||||
"userToken": "string",
|
||||
"mtu": 1280,
|
||||
"dns": "8.8.8.8",
|
||||
"dnsProxyIP": "string",
|
||||
"upstreamDNS": ["8.8.8.8:53", "1.1.1.1:53"],
|
||||
"interfaceName": "olm",
|
||||
"holepunch": false,
|
||||
"tlsClientCert": "string",
|
||||
"pingInterval": "3s",
|
||||
"pingTimeout": "5s",
|
||||
"orgId": "string",
|
||||
"fingerprint": {
|
||||
"username": "string",
|
||||
"hostname": "string",
|
||||
"platform": "string",
|
||||
"osVersion": "string",
|
||||
"kernelVersion": "string",
|
||||
"arch": "string",
|
||||
"deviceModel": "string",
|
||||
"serialNumber": "string"
|
||||
},
|
||||
"postures": {}
|
||||
}
|
||||
```
|
||||
|
||||
**Required Fields:**
|
||||
- `id`: Olm ID generated by Pangolin
|
||||
- `secret`: Authentication secret for the Olm ID
|
||||
- `endpoint`: Target Pangolin endpoint URL
|
||||
|
||||
**Optional Fields:**
|
||||
- `userToken`: User authentication token
|
||||
- `mtu`: MTU for the internal WireGuard interface (default: 1280)
|
||||
- `dns`: DNS server to use for resolving the endpoint
|
||||
- `dnsProxyIP`: DNS proxy IP address
|
||||
- `upstreamDNS`: Array of upstream DNS servers
|
||||
- `interfaceName`: Name of the WireGuard interface (default: olm)
|
||||
- `holepunch`: Enable NAT hole punching (default: false)
|
||||
- `tlsClientCert`: TLS client certificate
|
||||
- `pingInterval`: Interval for pinging the server (default: 3s)
|
||||
- `pingTimeout`: Timeout for each ping (default: 5s)
|
||||
- `orgId`: Organization ID to connect to
|
||||
- `fingerprint`: Device fingerprinting information (should be set before connecting)
|
||||
- `username`: Current username on the device
|
||||
- `hostname`: Device hostname
|
||||
- `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
|
||||
- `osVersion`: Operating system version
|
||||
- `kernelVersion`: Kernel version
|
||||
- `arch`: System architecture (e.g., amd64, arm64)
|
||||
- `deviceModel`: Device model identifier
|
||||
- `serialNumber`: Device serial number
|
||||
- `postures`: Device posture/security information
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `202 Accepted`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "connection request accepted"
|
||||
}
|
||||
```
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-POST requests
|
||||
- `400 Bad Request` - Invalid JSON or missing required fields
|
||||
- `409 Conflict` - Already connected to a server (disconnect first)
|
||||
|
||||
---
|
||||
|
||||
### GET /status
|
||||
Returns the current connection status, registration state, and peer information.
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"connected": true,
|
||||
"registered": true,
|
||||
"terminated": false,
|
||||
"version": "1.0.0",
|
||||
"agent": "olm",
|
||||
"orgId": "org_123",
|
||||
"peers": {
|
||||
"10": {
|
||||
"siteId": 10,
|
||||
"name": "Site A",
|
||||
"connected": true,
|
||||
"rtt": 145338339,
|
||||
"lastSeen": "2025-08-13T14:39:17.208334428-07:00",
|
||||
"endpoint": "p.fosrl.io:21820",
|
||||
"isRelay": true,
|
||||
"peerAddress": "100.89.128.5",
|
||||
"holepunchConnected": false
|
||||
},
|
||||
"8": {
|
||||
"siteId": 8,
|
||||
"name": "Site B",
|
||||
"connected": false,
|
||||
"rtt": 0,
|
||||
"lastSeen": "2025-08-13T14:39:19.663823645-07:00",
|
||||
"endpoint": "p.fosrl.io:21820",
|
||||
"isRelay": true,
|
||||
"peerAddress": "100.89.128.10",
|
||||
"holepunchConnected": false
|
||||
}
|
||||
},
|
||||
"networkSettings": {
|
||||
"tunnelIP": "100.89.128.3/20"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Fields:**
|
||||
- `connected`: Boolean indicating if connected to Pangolin
|
||||
- `registered`: Boolean indicating if registered with the server
|
||||
- `terminated`: Boolean indicating if the connection was terminated
|
||||
- `version`: Olm version string
|
||||
- `agent`: Agent identifier
|
||||
- `orgId`: Current organization ID
|
||||
- `peers`: Map of peer statuses by site ID
|
||||
- `siteId`: Peer site identifier
|
||||
- `name`: Site name
|
||||
- `connected`: Boolean peer connection state
|
||||
- `rtt`: Peer round-trip time (integer, nanoseconds)
|
||||
- `lastSeen`: Last time peer was seen (RFC3339 timestamp)
|
||||
- `endpoint`: Peer endpoint address
|
||||
- `isRelay`: Whether the peer is relayed (true) or direct (false)
|
||||
- `peerAddress`: Peer's IP address in the tunnel
|
||||
- `holepunchConnected`: Whether holepunch connection is established
|
||||
- `networkSettings`: Current network configuration including tunnel IP
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-GET requests
|
||||
|
||||
---
|
||||
|
||||
### POST /disconnect
|
||||
Disconnects from the current Pangolin server and tears down the WireGuard tunnel.
|
||||
|
||||
**Request Body:** None required
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "disconnect initiated"
|
||||
}
|
||||
```
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-POST requests
|
||||
- `409 Conflict` - Not currently connected to a server
|
||||
|
||||
---
|
||||
|
||||
### POST /switch-org
|
||||
Switches to a different organization while maintaining the connection.
|
||||
|
||||
**Request Body:**
|
||||
```json
|
||||
{
|
||||
"orgId": "string"
|
||||
}
|
||||
```
|
||||
|
||||
**Required Fields:**
|
||||
- `orgId`: The organization ID to switch to
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "org switch request accepted"
|
||||
}
|
||||
```
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-POST requests
|
||||
- `400 Bad Request` - Invalid JSON or missing orgId field
|
||||
- `500 Internal Server Error` - Org switch failed
|
||||
|
||||
---
|
||||
|
||||
### PUT /metadata
|
||||
Updates device fingerprinting and posture information. This endpoint can be called at any time to update metadata, but it's recommended to provide this information in the initial `/connect` request or immediately before connecting.
|
||||
|
||||
**Request Body:**
|
||||
```json
|
||||
{
|
||||
"fingerprint": {
|
||||
"username": "string",
|
||||
"hostname": "string",
|
||||
"platform": "string",
|
||||
"osVersion": "string",
|
||||
"kernelVersion": "string",
|
||||
"arch": "string",
|
||||
"deviceModel": "string",
|
||||
"serialNumber": "string"
|
||||
},
|
||||
"postures": {}
|
||||
}
|
||||
```
|
||||
|
||||
**Optional Fields:**
|
||||
- `fingerprint`: Device fingerprinting information
|
||||
- `username`: Current username on the device
|
||||
- `hostname`: Device hostname
|
||||
- `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
|
||||
- `osVersion`: Operating system version
|
||||
- `kernelVersion`: Kernel version
|
||||
- `arch`: System architecture (e.g., amd64, arm64)
|
||||
- `deviceModel`: Device model identifier
|
||||
- `serialNumber`: Device serial number
|
||||
- `postures`: Device posture/security information (object with arbitrary key-value pairs)
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "metadata updated"
|
||||
}
|
||||
```
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-PUT requests
|
||||
- `400 Bad Request` - Invalid JSON
|
||||
|
||||
**Note:** It's recommended to call this endpoint BEFORE `/connect` to ensure fingerprinting information is available during the initial connection handshake.
|
||||
|
||||
---
|
||||
|
||||
### POST /exit
|
||||
Initiates a graceful shutdown of the Olm process.
|
||||
|
||||
**Request Body:** None required
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "shutdown initiated"
|
||||
}
|
||||
```
|
||||
|
||||
**Note:** The response is sent before shutdown begins. There is a 100ms delay before the actual shutdown to ensure the response is delivered.
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-POST requests
|
||||
|
||||
---
|
||||
|
||||
### GET /health
|
||||
Simple health check endpoint to verify the API server is running.
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "ok"
|
||||
}
|
||||
```
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-GET requests
|
||||
|
||||
---
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Update metadata before connecting (recommended)
|
||||
```bash
|
||||
curl -X PUT http://localhost:9452/metadata \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"fingerprint": {
|
||||
"username": "john",
|
||||
"hostname": "johns-laptop",
|
||||
"platform": "macos",
|
||||
"osVersion": "14.2.1",
|
||||
"arch": "arm64",
|
||||
"deviceModel": "MacBookPro18,3"
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### Connect to a peer
|
||||
```bash
|
||||
curl -X POST http://localhost:9452/connect \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"id": "31frd0uzbjvp721",
|
||||
"secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
|
||||
"endpoint": "https://example.com"
|
||||
}'
|
||||
```
|
||||
|
||||
### Connect with additional options
|
||||
```bash
|
||||
curl -X POST http://localhost:9452/connect \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"id": "31frd0uzbjvp721",
|
||||
"secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
|
||||
"endpoint": "https://example.com",
|
||||
"mtu": 1400,
|
||||
"holepunch": true,
|
||||
"pingInterval": "5s"
|
||||
}'
|
||||
```
|
||||
|
||||
### Check connection status
|
||||
```bash
|
||||
curl http://localhost:9452/status
|
||||
```
|
||||
|
||||
### Switch organization
|
||||
```bash
|
||||
curl -X POST http://localhost:9452/switch-org \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"orgId": "org_456"}'
|
||||
```
|
||||
|
||||
### Disconnect from server
|
||||
```bash
|
||||
curl -X POST http://localhost:9452/disconnect
|
||||
```
|
||||
|
||||
### Health check
|
||||
```bash
|
||||
curl http://localhost:9452/health
|
||||
```
|
||||
|
||||
### Shutdown Olm
|
||||
```bash
|
||||
curl -X POST http://localhost:9452/exit
|
||||
```
|
||||
|
||||
### Using Unix socket (Linux/macOS)
|
||||
```bash
|
||||
curl --unix-socket /var/run/olm/olm.sock http://localhost/status
|
||||
curl --unix-socket /var/run/olm/olm.sock -X POST http://localhost/disconnect
|
||||
```
|
||||
@@ -4,11 +4,7 @@ Contributions are welcome!
|
||||
|
||||
Please see the contribution and local development guide on the docs page before getting started:
|
||||
|
||||
https://docs.fossorial.io/development
|
||||
|
||||
For ideas about what features to work on and our future plans, please see the roadmap:
|
||||
|
||||
https://docs.fossorial.io/roadmap
|
||||
https://docs.pangolin.net/development/contributing
|
||||
|
||||
### Licensing Considerations
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.23.1-alpine AS builder
|
||||
FROM golang:1.25-alpine AS builder
|
||||
|
||||
# Set the working directory inside the container
|
||||
WORKDIR /app
|
||||
@@ -16,9 +16,9 @@ COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /olm
|
||||
|
||||
# Start a new stage from scratch
|
||||
FROM ubuntu:22.04 AS runner
|
||||
FROM alpine:3.23 AS runner
|
||||
|
||||
RUN apt-get update && apt-get install ca-certificates -y && rm -rf /var/lib/apt/lists/*
|
||||
RUN apk --no-cache add ca-certificates
|
||||
|
||||
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
||||
COPY --from=builder /olm /usr/local/bin/
|
||||
|
||||
66
Makefile
66
Makefile
@@ -1,15 +1,67 @@
|
||||
.PHONY: all local docker-build-release
|
||||
|
||||
all: go-build-release
|
||||
all: local
|
||||
|
||||
local:
|
||||
CGO_ENABLED=0 go build -o olm
|
||||
local:
|
||||
CGO_ENABLED=0 go build -o ./bin/olm
|
||||
|
||||
go-build-release:
|
||||
docker-build:
|
||||
docker build -t fosrl/olm:latest .
|
||||
|
||||
docker-build-release:
|
||||
@if [ -z "$(tag)" ]; then \
|
||||
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
|
||||
exit 1; \
|
||||
fi
|
||||
docker buildx build . \
|
||||
--platform linux/arm/v7,linux/arm64,linux/amd64 \
|
||||
-t fosrl/olm:latest \
|
||||
-t fosrl/olm:$(tag) \
|
||||
-f Dockerfile \
|
||||
--push
|
||||
|
||||
docker-build-dev:
|
||||
docker buildx build . \
|
||||
--platform linux/arm/v7,linux/arm64,linux/amd64 \
|
||||
-t fosrl/olm:latest \
|
||||
-f Dockerfile
|
||||
|
||||
.PHONY: go-build-release \
|
||||
go-build-release-linux-arm64 go-build-release-linux-arm32-v7 \
|
||||
go-build-release-linux-arm32-v6 go-build-release-linux-amd64 \
|
||||
go-build-release-linux-riscv64 go-build-release-darwin-arm64 \
|
||||
go-build-release-darwin-amd64 go-build-release-windows-amd64
|
||||
|
||||
go-build-release: \
|
||||
go-build-release-linux-arm64 \
|
||||
go-build-release-linux-arm32-v7 \
|
||||
go-build-release-linux-arm32-v6 \
|
||||
go-build-release-linux-amd64 \
|
||||
go-build-release-linux-riscv64 \
|
||||
go-build-release-darwin-arm64 \
|
||||
go-build-release-darwin-amd64 \
|
||||
go-build-release-windows-amd64 \
|
||||
|
||||
go-build-release-linux-arm64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/olm_linux_arm64
|
||||
|
||||
go-build-release-linux-arm32-v7:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/olm_linux_arm32
|
||||
|
||||
go-build-release-linux-arm32-v6:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/olm_linux_arm32v6
|
||||
|
||||
go-build-release-linux-amd64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/olm_linux_amd64
|
||||
|
||||
go-build-release-linux-riscv64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/olm_linux_riscv64
|
||||
|
||||
go-build-release-darwin-arm64:
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/olm_darwin_arm64
|
||||
|
||||
go-build-release-darwin-amd64:
|
||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/olm_darwin_amd64
|
||||
|
||||
go-build-release-windows-amd64:
|
||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/olm_windows_amd64.exe
|
||||
|
||||
clean:
|
||||
rm olm
|
||||
117
README.md
117
README.md
@@ -1,12 +1,12 @@
|
||||
# Olm
|
||||
|
||||
Olm is a [WireGuard](https://www.wireguard.com/) tunnel client designed to securely connect you computer to Newt sites running on remote networks.
|
||||
Olm is a [WireGuard](https://www.wireguard.com/) tunnel client designed to securely connect your computer to Newt sites running on remote networks.
|
||||
|
||||
### Installation and Documentation
|
||||
|
||||
Olm is used with Pangolin and Newt as part of the larger system. See documentation below:
|
||||
|
||||
- [Full Documentation](https://docs.fossorial.io)
|
||||
- [Full Documentation](https://docs.pangolin.net/manage/clients/understanding-clients)
|
||||
|
||||
## Key Functions
|
||||
|
||||
@@ -18,127 +18,20 @@ Using the Olm ID and a secret, the olm will make HTTP requests to Pangolin to re
|
||||
|
||||
When Olm receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel on your computer to a remote Newt. It will ping over the tunnel to ensure the peer is brought up.
|
||||
|
||||
## CLI Args
|
||||
|
||||
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
|
||||
- `id`: Olm ID generated by Pangolin to identify the olm.
|
||||
- `secret`: A unique secret (not shared and kept private) used to authenticate the olm ID with the websocket in order to receive commands.
|
||||
- `mtu` (optional): MTU for the internal WG interface. Default: 1280
|
||||
- `dns` (optional): DNS server to use to resolve the endpoint. Default: 8.8.8.8
|
||||
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO
|
||||
- `ping-interval` (optional): Interval for pinging the server. Default: 3s
|
||||
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
|
||||
- `interface` (optional): Name of the WireGuard interface. Default: olm
|
||||
- `enable-http` (optional): Enable HTTP server for receiving connection requests. Default: false
|
||||
- `http-addr` (optional): HTTP server address (e.g., ':9452'). Default: :9452
|
||||
- `holepunch` (optional): Enable hole punching. Default: false
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All CLI arguments can also be set via environment variables:
|
||||
|
||||
- `PANGOLIN_ENDPOINT`: Equivalent to `--endpoint`
|
||||
- `OLM_ID`: Equivalent to `--id`
|
||||
- `OLM_SECRET`: Equivalent to `--secret`
|
||||
- `MTU`: Equivalent to `--mtu`
|
||||
- `DNS`: Equivalent to `--dns`
|
||||
- `LOG_LEVEL`: Equivalent to `--log-level`
|
||||
- `INTERFACE`: Equivalent to `--interface`
|
||||
- `HTTP_ADDR`: Equivalent to `--http-addr`
|
||||
- `PING_INTERVAL`: Equivalent to `--ping-interval`
|
||||
- `PING_TIMEOUT`: Equivalent to `--ping-timeout`
|
||||
- `HOLEPUNCH`: Set to "true" to enable hole punching (equivalent to `--holepunch`)
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
olm \
|
||||
--id 31frd0uzbjvp721 \
|
||||
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||
--endpoint https://example.com
|
||||
```
|
||||
|
||||
## Hole Punching
|
||||
|
||||
In the default mode, olm "relays" traffic through Gerbil in the cloud to get down to newt. This is a little more reliable. Support for NAT hole punching is also EXPERIMENTAL right now using the `--holepunch` flag. This will attempt to orchestrate a NAT hole punch between the two sites so that traffic flows directly. This will save data costs and speed. If it fails it should fall back to relaying.
|
||||
|
||||
Right now, basic NAT hole punching is supported. We plan to add:
|
||||
|
||||
- [ ] Birthday paradox
|
||||
- [ ] UPnP
|
||||
- [ ] LAN detection
|
||||
|
||||
## Windows Service
|
||||
|
||||
On Windows, olm has to be installed and run as a Windows service. When running it with the cli args live above it will attempt to install and run the service to function like a cli tool. You can also run the following:
|
||||
|
||||
### Service Management Commands
|
||||
|
||||
```
|
||||
# Install the service
|
||||
olm.exe install
|
||||
|
||||
# Start the service
|
||||
olm.exe start
|
||||
|
||||
# Stop the service
|
||||
olm.exe stop
|
||||
|
||||
# Check service status
|
||||
olm.exe status
|
||||
|
||||
# Remove the service
|
||||
olm.exe remove
|
||||
|
||||
# Run in debug mode (console output) with our without id & secret
|
||||
olm.exe debug
|
||||
|
||||
# Show help
|
||||
olm.exe help
|
||||
```
|
||||
|
||||
### Service Configuration
|
||||
|
||||
When running as a service, Olm will read configuration from environment variables or you can modify the service to include command-line arguments:
|
||||
|
||||
1. Install the service: `olm.exe install`
|
||||
2. Configure the service with your credentials using Windows Service Manager or by setting system environment variables:
|
||||
- `PANGOLIN_ENDPOINT=https://example.com`
|
||||
- `OLM_ID=your_olm_id`
|
||||
- `OLM_SECRET=your_secret`
|
||||
3. Start the service: `olm.exe start`
|
||||
|
||||
### Service Logs
|
||||
|
||||
When running as a service, logs are written to:
|
||||
|
||||
- Windows Event Log (Application log, source: "OlmWireguardService")
|
||||
- Log files in: `%PROGRAMDATA%\olm\logs\olm.log`
|
||||
|
||||
You can view the Windows Event Log using Event Viewer or PowerShell:
|
||||
|
||||
```powershell
|
||||
Get-EventLog -LogName Application -Source "OlmWireguardService" -Newest 10
|
||||
```
|
||||
In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to Newt. Hole punching attempts to orchestrate a NAT traversal between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil.
|
||||
|
||||
## Build
|
||||
|
||||
### Container
|
||||
### Binary
|
||||
|
||||
Ensure Docker is installed.
|
||||
Make sure to have Go 1.25 installed.
|
||||
|
||||
```bash
|
||||
make
|
||||
```
|
||||
|
||||
### Binary
|
||||
|
||||
Make sure to have Go 1.23.1 installed.
|
||||
|
||||
```bash
|
||||
make local
|
||||
```
|
||||
|
||||
## Licensing
|
||||
|
||||
Olm is dual licensed under the AGPLv3 and the Fossorial Commercial license. For inquiries about commercial licensing, please contact us.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||
|
||||
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
|
||||
- Description and location of the vulnerability.
|
||||
- Potential impact of the vulnerability.
|
||||
|
||||
675
api/api.go
Normal file
675
api/api.go
Normal file
@@ -0,0 +1,675 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
)
|
||||
|
||||
// ConnectionRequest defines the structure for an incoming connection request
|
||||
type ConnectionRequest struct {
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
UserToken string `json:"userToken,omitempty"`
|
||||
MTU int `json:"mtu,omitempty"`
|
||||
DNS string `json:"dns,omitempty"`
|
||||
DNSProxyIP string `json:"dnsProxyIP,omitempty"`
|
||||
UpstreamDNS []string `json:"upstreamDNS,omitempty"`
|
||||
InterfaceName string `json:"interfaceName,omitempty"`
|
||||
Holepunch bool `json:"holepunch,omitempty"`
|
||||
TlsClientCert string `json:"tlsClientCert,omitempty"`
|
||||
PingInterval string `json:"pingInterval,omitempty"`
|
||||
PingTimeout string `json:"pingTimeout,omitempty"`
|
||||
OrgID string `json:"orgId,omitempty"`
|
||||
}
|
||||
|
||||
// SwitchOrgRequest defines the structure for switching organizations
|
||||
type SwitchOrgRequest struct {
|
||||
OrgID string `json:"org_id"`
|
||||
}
|
||||
|
||||
// PowerModeRequest represents a request to change power mode
|
||||
type PowerModeRequest struct {
|
||||
Mode string `json:"mode"` // "normal" or "low"
|
||||
}
|
||||
|
||||
// PeerStatus represents the status of a peer connection
|
||||
type PeerStatus struct {
|
||||
SiteID int `json:"siteId"`
|
||||
Name string `json:"name"`
|
||||
Connected bool `json:"connected"`
|
||||
RTT time.Duration `json:"rtt"`
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
IsRelay bool `json:"isRelay"`
|
||||
PeerIP string `json:"peerAddress,omitempty"`
|
||||
HolepunchConnected bool `json:"holepunchConnected"`
|
||||
}
|
||||
|
||||
// OlmError holds error information from registration failures
|
||||
type OlmError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// StatusResponse is returned by the status endpoint
|
||||
type StatusResponse struct {
|
||||
Connected bool `json:"connected"`
|
||||
Registered bool `json:"registered"`
|
||||
Terminated bool `json:"terminated"`
|
||||
OlmError *OlmError `json:"error,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Agent string `json:"agent,omitempty"`
|
||||
OrgID string `json:"orgId,omitempty"`
|
||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
|
||||
}
|
||||
|
||||
type MetadataChangeRequest struct {
|
||||
Fingerprint map[string]any `json:"fingerprint"`
|
||||
Postures map[string]any `json:"postures"`
|
||||
}
|
||||
|
||||
// API represents the HTTP server and its state
|
||||
type API struct {
|
||||
addr string
|
||||
socketPath string
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
|
||||
onConnect func(ConnectionRequest) error
|
||||
onSwitchOrg func(SwitchOrgRequest) error
|
||||
onMetadataChange func(MetadataChangeRequest) error
|
||||
onDisconnect func() error
|
||||
onExit func() error
|
||||
onRebind func() error
|
||||
onPowerMode func(PowerModeRequest) error
|
||||
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
connectedAt time.Time
|
||||
isConnected bool
|
||||
isRegistered bool
|
||||
isTerminated bool
|
||||
olmError *OlmError
|
||||
|
||||
version string
|
||||
agent string
|
||||
orgID string
|
||||
}
|
||||
|
||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||
func NewAPI(addr string) *API {
|
||||
s := &API{
|
||||
addr: addr,
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
|
||||
func NewAPISocket(socketPath string) *API {
|
||||
s := &API{
|
||||
socketPath: socketPath,
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func NewAPIStub() *API {
|
||||
s := &API{
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetHandlers sets the callback functions for handling API requests
|
||||
func (s *API) SetHandlers(
|
||||
onConnect func(ConnectionRequest) error,
|
||||
onSwitchOrg func(SwitchOrgRequest) error,
|
||||
onMetadataChange func(MetadataChangeRequest) error,
|
||||
onDisconnect func() error,
|
||||
onExit func() error,
|
||||
onRebind func() error,
|
||||
onPowerMode func(PowerModeRequest) error,
|
||||
) {
|
||||
s.onConnect = onConnect
|
||||
s.onSwitchOrg = onSwitchOrg
|
||||
s.onMetadataChange = onMetadataChange
|
||||
s.onDisconnect = onDisconnect
|
||||
s.onExit = onExit
|
||||
s.onRebind = onRebind
|
||||
s.onPowerMode = onPowerMode
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
func (s *API) Start() error {
|
||||
if s.socketPath == "" && s.addr == "" {
|
||||
return fmt.Errorf("either socketPath or addr must be provided to start the API server")
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/connect", s.handleConnect)
|
||||
mux.HandleFunc("/status", s.handleStatus)
|
||||
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||
mux.HandleFunc("/metadata", s.handleMetadataChange)
|
||||
mux.HandleFunc("/disconnect", s.handleDisconnect)
|
||||
mux.HandleFunc("/exit", s.handleExit)
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
mux.HandleFunc("/rebind", s.handleRebind)
|
||||
mux.HandleFunc("/power-mode", s.handlePowerMode)
|
||||
|
||||
s.server = &http.Server{
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
var err error
|
||||
if s.socketPath != "" {
|
||||
// Use platform-specific socket listener
|
||||
s.listener, err = createSocketListener(s.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create socket listener: %w", err)
|
||||
}
|
||||
logger.Info("Starting HTTP server on socket %s", s.socketPath)
|
||||
} else {
|
||||
// Use TCP listener
|
||||
s.listener, err = net.Listen("tcp", s.addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TCP listener: %w", err)
|
||||
}
|
||||
logger.Info("Starting HTTP server on %s", s.addr)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("HTTP server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the HTTP server
|
||||
func (s *API) Stop() error {
|
||||
logger.Info("Stopping api server")
|
||||
|
||||
// Close the server first, which will also close the listener gracefully
|
||||
if s.server != nil {
|
||||
_ = s.server.Close()
|
||||
}
|
||||
|
||||
// Clean up socket file if using Unix socket
|
||||
if s.socketPath != "" {
|
||||
cleanupSocket(s.socketPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *API) AddPeerStatus(siteID int, siteName string, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.Name = siteName
|
||||
status.Connected = connected
|
||||
status.RTT = rtt
|
||||
status.LastSeen = time.Now()
|
||||
status.Endpoint = endpoint
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.Connected = connected
|
||||
status.RTT = rtt
|
||||
status.LastSeen = time.Now()
|
||||
status.Endpoint = endpoint
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
func (s *API) RemovePeerStatus(siteID int) { // remove the peer from the status map
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
delete(s.peerStatuses, siteID)
|
||||
}
|
||||
|
||||
// SetConnectionStatus sets the overall connection status
|
||||
func (s *API) SetConnectionStatus(isConnected bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
s.isConnected = isConnected
|
||||
|
||||
if isConnected {
|
||||
s.connectedAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *API) SetRegistered(registered bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.isRegistered = registered
|
||||
// Clear any registration error when successfully registered
|
||||
if registered {
|
||||
s.olmError = nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetOlmError sets the registration error
|
||||
func (s *API) SetOlmError(code string, message string) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.olmError = &OlmError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// ClearOlmError clears any registration error
|
||||
func (s *API) ClearOlmError() {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.olmError = nil
|
||||
}
|
||||
|
||||
func (s *API) SetTerminated(terminated bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.isTerminated = terminated
|
||||
}
|
||||
|
||||
// ClearPeerStatuses clears all peer statuses
|
||||
func (s *API) ClearPeerStatuses() {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.peerStatuses = make(map[int]*PeerStatus)
|
||||
}
|
||||
|
||||
// SetVersion sets the olm version
|
||||
func (s *API) SetVersion(version string) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.version = version
|
||||
}
|
||||
|
||||
// SetAgent sets the olm agent
|
||||
func (s *API) SetAgent(agent string) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.agent = agent
|
||||
}
|
||||
|
||||
// SetOrgID sets the organization ID
|
||||
func (s *API) SetOrgID(orgID string) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.orgID = orgID
|
||||
}
|
||||
|
||||
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.Endpoint = endpoint
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
// UpdatePeerHolepunchStatus updates the holepunch connection status of a peer
|
||||
func (s *API) UpdatePeerHolepunchStatus(siteID int, holepunchConnected bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.HolepunchConnected = holepunchConnected
|
||||
}
|
||||
|
||||
// handleConnect handles the /connect endpoint
|
||||
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// if we are already connected, reject new connection requests
|
||||
s.statusMu.RLock()
|
||||
alreadyConnected := s.isConnected
|
||||
s.statusMu.RUnlock()
|
||||
if alreadyConnected {
|
||||
http.Error(w, "Already connected to a server. Disconnect first before connecting again.", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
var req ConnectionRequest
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
|
||||
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Call the connect handler if set
|
||||
if s.onConnect != nil {
|
||||
if err := s.onConnect(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "connection request accepted",
|
||||
})
|
||||
}
|
||||
|
||||
// handleStatus handles the /status endpoint
|
||||
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
s.statusMu.RLock()
|
||||
|
||||
resp := StatusResponse{
|
||||
Connected: s.isConnected,
|
||||
Registered: s.isRegistered,
|
||||
Terminated: s.isTerminated,
|
||||
OlmError: s.olmError,
|
||||
Version: s.version,
|
||||
Agent: s.agent,
|
||||
OrgID: s.orgID,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
NetworkSettings: network.GetSettings(),
|
||||
}
|
||||
|
||||
s.statusMu.RUnlock()
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
// handleHealth handles the /health endpoint
|
||||
func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "ok",
|
||||
})
|
||||
}
|
||||
|
||||
// handleExit handles the /exit endpoint
|
||||
func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received exit request via API")
|
||||
|
||||
// Return a success response first
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "shutdown initiated",
|
||||
})
|
||||
|
||||
// Call the exit handler after responding, in a goroutine with a small delay
|
||||
// to ensure the response is fully sent before shutdown begins
|
||||
if s.onExit != nil {
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if err := s.onExit(); err != nil {
|
||||
logger.Error("Exit handler failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// handleSwitchOrg handles the /switch-org endpoint
|
||||
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req SwitchOrgRequest
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.OrgID == "" {
|
||||
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Call the switch org handler if set
|
||||
if s.onSwitchOrg != nil {
|
||||
if err := s.onSwitchOrg(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "org switch request accepted",
|
||||
})
|
||||
}
|
||||
|
||||
// handleDisconnect handles the /disconnect endpoint
|
||||
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// if we are already disconnected, reject new disconnect requests
|
||||
s.statusMu.RLock()
|
||||
alreadyDisconnected := !s.isConnected
|
||||
s.statusMu.RUnlock()
|
||||
if alreadyDisconnected {
|
||||
http.Error(w, "Not currently connected to a server.", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received disconnect request via API")
|
||||
|
||||
// Call the disconnect handler if set
|
||||
if s.onDisconnect != nil {
|
||||
if err := s.onDisconnect(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "disconnect initiated",
|
||||
})
|
||||
}
|
||||
|
||||
// handleMetadataChange handles the /metadata endpoint
|
||||
func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req MetadataChangeRequest
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received metadata change request via API: %v", req)
|
||||
|
||||
_ = s.onMetadataChange(req)
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "metadata updated",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *API) GetStatus() StatusResponse {
|
||||
return StatusResponse{
|
||||
Connected: s.isConnected,
|
||||
Registered: s.isRegistered,
|
||||
Terminated: s.isTerminated,
|
||||
OlmError: s.olmError,
|
||||
Version: s.version,
|
||||
Agent: s.agent,
|
||||
OrgID: s.orgID,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
NetworkSettings: network.GetSettings(),
|
||||
}
|
||||
}
|
||||
|
||||
// handleRebind handles the /rebind endpoint
|
||||
// This triggers a socket rebind, which is necessary when network connectivity changes
|
||||
// (e.g., WiFi to cellular transition on macOS/iOS) and the old socket becomes stale.
|
||||
func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received rebind request via API")
|
||||
|
||||
// Call the rebind handler if set
|
||||
if s.onRebind != nil {
|
||||
if err := s.onRebind(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Rebind failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "Rebind handler not configured", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "socket rebound successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// handlePowerMode handles the /power-mode endpoint
|
||||
// This allows changing the power mode between "normal" and "low"
|
||||
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req PowerModeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate power mode
|
||||
if req.Mode != "normal" && req.Mode != "low" {
|
||||
http.Error(w, "Invalid power mode: must be 'normal' or 'low'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received power mode change request via API: mode=%s", req.Mode)
|
||||
|
||||
// Call the power mode handler if set
|
||||
if s.onPowerMode != nil {
|
||||
if err := s.onPowerMode(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Power mode change failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "Power mode handler not configured", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": fmt.Sprintf("power mode changed to %s successfully", req.Mode),
|
||||
})
|
||||
}
|
||||
50
api/api_unix.go
Normal file
50
api/api_unix.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// createSocketListener creates a Unix domain socket listener
|
||||
func createSocketListener(socketPath string) (net.Listener, error) {
|
||||
// Ensure the directory exists
|
||||
dir := filepath.Dir(socketPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create socket directory: %w", err)
|
||||
}
|
||||
|
||||
// Remove existing socket file if it exists
|
||||
if err := os.RemoveAll(socketPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
|
||||
}
|
||||
|
||||
// Set socket permissions to allow access
|
||||
if err := os.Chmod(socketPath, 0666); err != nil {
|
||||
listener.Close()
|
||||
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Created Unix socket at %s", socketPath)
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// cleanupSocket removes the Unix socket file
|
||||
func cleanupSocket(socketPath string) {
|
||||
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
|
||||
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
|
||||
} else {
|
||||
logger.Debug("Removed Unix socket at %s", socketPath)
|
||||
}
|
||||
}
|
||||
41
api/api_windows.go
Normal file
41
api/api_windows.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// createSocketListener creates a Windows named pipe listener
|
||||
func createSocketListener(pipePath string) (net.Listener, error) {
|
||||
// Ensure the pipe path has the correct format
|
||||
if pipePath[0] != '\\' {
|
||||
pipePath = `\\.\pipe\` + pipePath
|
||||
}
|
||||
|
||||
// Create a pipe configuration that allows everyone to write
|
||||
config := &winio.PipeConfig{
|
||||
// Set security descriptor to allow everyone full access
|
||||
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
|
||||
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
|
||||
}
|
||||
|
||||
// Create a named pipe listener using go-winio with the configuration
|
||||
listener, err := winio.ListenPipe(pipePath, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
|
||||
func cleanupSocket(pipePath string) {
|
||||
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
|
||||
}
|
||||
650
config.go
Normal file
650
config.go
Normal file
@@ -0,0 +1,650 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OlmConfig holds all configuration options for the Olm client
|
||||
type OlmConfig struct {
|
||||
// Connection settings
|
||||
Endpoint string `json:"endpoint"`
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
OrgID string `json:"org"`
|
||||
UserToken string `json:"userToken"`
|
||||
|
||||
// Network settings
|
||||
MTU int `json:"mtu"`
|
||||
DNS string `json:"dns"`
|
||||
UpstreamDNS []string `json:"upstreamDNS"`
|
||||
InterfaceName string `json:"interface"`
|
||||
|
||||
// Logging
|
||||
LogLevel string `json:"logLevel"`
|
||||
|
||||
// HTTP server
|
||||
EnableAPI bool `json:"enableApi"`
|
||||
HTTPAddr string `json:"httpAddr"`
|
||||
SocketPath string `json:"socketPath"`
|
||||
|
||||
// Ping settings
|
||||
PingInterval string `json:"pingInterval"`
|
||||
PingTimeout string `json:"pingTimeout"`
|
||||
|
||||
// Advanced
|
||||
DisableHolepunch bool `json:"disableHolepunch"`
|
||||
TlsClientCert string `json:"tlsClientCert"`
|
||||
OverrideDNS bool `json:"overrideDNS"`
|
||||
TunnelDNS bool `json:"tunnelDNS"`
|
||||
DisableRelay bool `json:"disableRelay"`
|
||||
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
|
||||
|
||||
// Parsed values (not in JSON)
|
||||
PingIntervalDuration time.Duration `json:"-"`
|
||||
PingTimeoutDuration time.Duration `json:"-"`
|
||||
|
||||
// Source tracking (not in JSON)
|
||||
sources map[string]string `json:"-"`
|
||||
|
||||
Version string
|
||||
}
|
||||
|
||||
// ConfigSource tracks where each config value came from
|
||||
type ConfigSource string
|
||||
|
||||
const (
|
||||
SourceDefault ConfigSource = "default"
|
||||
SourceFile ConfigSource = "file"
|
||||
SourceEnv ConfigSource = "environment"
|
||||
SourceCLI ConfigSource = "cli"
|
||||
)
|
||||
|
||||
// DefaultConfig returns a config with default values
|
||||
func DefaultConfig() *OlmConfig {
|
||||
// Set OS-specific socket path
|
||||
var socketPath string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
socketPath = "olm"
|
||||
default: // darwin, linux, and others
|
||||
socketPath = "/var/run/olm.sock"
|
||||
}
|
||||
|
||||
config := &OlmConfig{
|
||||
MTU: 1280,
|
||||
DNS: "8.8.8.8",
|
||||
UpstreamDNS: []string{"8.8.8.8:53"},
|
||||
LogLevel: "INFO",
|
||||
InterfaceName: "olm",
|
||||
EnableAPI: false,
|
||||
SocketPath: socketPath,
|
||||
PingInterval: "3s",
|
||||
PingTimeout: "5s",
|
||||
DisableHolepunch: false,
|
||||
OverrideDNS: true,
|
||||
TunnelDNS: false,
|
||||
// DoNotCreateNewClient: false,
|
||||
sources: make(map[string]string),
|
||||
}
|
||||
|
||||
// Track default sources
|
||||
config.sources["mtu"] = string(SourceDefault)
|
||||
config.sources["dns"] = string(SourceDefault)
|
||||
config.sources["upstreamDNS"] = string(SourceDefault)
|
||||
config.sources["logLevel"] = string(SourceDefault)
|
||||
config.sources["interface"] = string(SourceDefault)
|
||||
config.sources["enableApi"] = string(SourceDefault)
|
||||
config.sources["httpAddr"] = string(SourceDefault)
|
||||
config.sources["socketPath"] = string(SourceDefault)
|
||||
config.sources["pingInterval"] = string(SourceDefault)
|
||||
config.sources["pingTimeout"] = string(SourceDefault)
|
||||
config.sources["disableHolepunch"] = string(SourceDefault)
|
||||
config.sources["overrideDNS"] = string(SourceDefault)
|
||||
config.sources["tunnelDNS"] = string(SourceDefault)
|
||||
config.sources["disableRelay"] = string(SourceDefault)
|
||||
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// getOlmConfigPath returns the path to the olm config file
|
||||
func getOlmConfigPath() string {
|
||||
configFile := os.Getenv("CONFIG_FILE")
|
||||
if configFile != "" {
|
||||
return configFile
|
||||
}
|
||||
|
||||
var configDir string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
|
||||
case "windows":
|
||||
configDir = filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "olm-client")
|
||||
default: // linux and others
|
||||
configDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
fmt.Printf("Warning: Failed to create config directory: %v\n", err)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, "config.json")
|
||||
}
|
||||
|
||||
// LoadConfig loads configuration from file, env vars, and CLI args
|
||||
// Priority: CLI args > Env vars > Config file > Defaults
|
||||
// Returns: (config, showVersion, showConfig, error)
|
||||
func LoadConfig(args []string) (*OlmConfig, bool, bool, error) {
|
||||
// Start with defaults
|
||||
config := DefaultConfig()
|
||||
|
||||
// Load from config file (if exists)
|
||||
fileConfig, err := loadConfigFromFile()
|
||||
if err != nil {
|
||||
return nil, false, false, fmt.Errorf("failed to load config file: %w", err)
|
||||
}
|
||||
if fileConfig != nil {
|
||||
mergeConfigs(config, fileConfig)
|
||||
}
|
||||
|
||||
// Override with environment variables
|
||||
loadConfigFromEnv(config)
|
||||
|
||||
// Override with CLI arguments
|
||||
showVersion, showConfig, err := loadConfigFromCLI(config, args)
|
||||
if err != nil {
|
||||
return nil, false, false, err
|
||||
}
|
||||
|
||||
// Parse duration strings
|
||||
if err := config.parseDurations(); err != nil {
|
||||
return nil, false, false, err
|
||||
}
|
||||
|
||||
return config, showVersion, showConfig, nil
|
||||
}
|
||||
|
||||
// loadConfigFromFile loads configuration from the JSON config file
|
||||
func loadConfigFromFile() (*OlmConfig, error) {
|
||||
configPath := getOlmConfigPath()
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil // File doesn't exist, not an error
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config OlmConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// loadConfigFromEnv loads configuration from environment variables
|
||||
func loadConfigFromEnv(config *OlmConfig) {
|
||||
if val := os.Getenv("PANGOLIN_ENDPOINT"); val != "" {
|
||||
config.Endpoint = val
|
||||
config.sources["endpoint"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("OLM_ID"); val != "" {
|
||||
config.ID = val
|
||||
config.sources["id"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("OLM_SECRET"); val != "" {
|
||||
config.Secret = val
|
||||
config.sources["secret"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("ORG"); val != "" {
|
||||
config.OrgID = val
|
||||
config.sources["org"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("USER_TOKEN"); val != "" {
|
||||
config.UserToken = val
|
||||
config.sources["userToken"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("MTU"); val != "" {
|
||||
if mtu, err := strconv.Atoi(val); err == nil {
|
||||
config.MTU = mtu
|
||||
config.sources["mtu"] = string(SourceEnv)
|
||||
} else {
|
||||
fmt.Printf("Invalid MTU value: %s, keeping current value\n", val)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv("DNS"); val != "" {
|
||||
config.DNS = val
|
||||
config.sources["dns"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("UPSTREAM_DNS"); val != "" {
|
||||
config.UpstreamDNS = []string{val}
|
||||
config.sources["upstreamDNS"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("LOG_LEVEL"); val != "" {
|
||||
config.LogLevel = val
|
||||
config.sources["logLevel"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("INTERFACE"); val != "" {
|
||||
config.InterfaceName = val
|
||||
config.sources["interface"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("HTTP_ADDR"); val != "" {
|
||||
config.HTTPAddr = val
|
||||
config.sources["httpAddr"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("PING_INTERVAL"); val != "" {
|
||||
config.PingInterval = val
|
||||
config.sources["pingInterval"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("PING_TIMEOUT"); val != "" {
|
||||
config.PingTimeout = val
|
||||
config.sources["pingTimeout"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("ENABLE_API"); val == "true" {
|
||||
config.EnableAPI = true
|
||||
config.sources["enableApi"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("SOCKET_PATH"); val != "" {
|
||||
config.SocketPath = val
|
||||
config.sources["socketPath"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("DISABLE_HOLEPUNCH"); val == "true" {
|
||||
config.DisableHolepunch = true
|
||||
config.sources["disableHolepunch"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("OVERRIDE_DNS"); val == "true" {
|
||||
config.OverrideDNS = true
|
||||
config.sources["overrideDNS"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("DISABLE_RELAY"); val == "true" {
|
||||
config.DisableRelay = true
|
||||
config.sources["disableRelay"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("TUNNEL_DNS"); val == "true" {
|
||||
config.TunnelDNS = true
|
||||
config.sources["tunnelDNS"] = string(SourceEnv)
|
||||
}
|
||||
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
|
||||
// config.DoNotCreateNewClient = true
|
||||
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
|
||||
// }
|
||||
}
|
||||
|
||||
// loadConfigFromCLI loads configuration from command-line arguments
|
||||
func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
||||
|
||||
// Store original values to detect changes
|
||||
origValues := map[string]interface{}{
|
||||
"endpoint": config.Endpoint,
|
||||
"id": config.ID,
|
||||
"secret": config.Secret,
|
||||
"org": config.OrgID,
|
||||
"userToken": config.UserToken,
|
||||
"mtu": config.MTU,
|
||||
"dns": config.DNS,
|
||||
"upstreamDNS": fmt.Sprintf("%v", config.UpstreamDNS),
|
||||
"logLevel": config.LogLevel,
|
||||
"interface": config.InterfaceName,
|
||||
"httpAddr": config.HTTPAddr,
|
||||
"socketPath": config.SocketPath,
|
||||
"pingInterval": config.PingInterval,
|
||||
"pingTimeout": config.PingTimeout,
|
||||
"enableApi": config.EnableAPI,
|
||||
"disableHolepunch": config.DisableHolepunch,
|
||||
"overrideDNS": config.OverrideDNS,
|
||||
"disableRelay": config.DisableRelay,
|
||||
"tunnelDNS": config.TunnelDNS,
|
||||
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||
}
|
||||
|
||||
// Define flags
|
||||
serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server")
|
||||
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
|
||||
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
|
||||
serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID")
|
||||
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
|
||||
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
|
||||
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
|
||||
var upstreamDNSFlag string
|
||||
serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8:53)")
|
||||
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
|
||||
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
|
||||
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
|
||||
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
|
||||
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
||||
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
|
||||
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
|
||||
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "When enabled, the client uses custom DNS servers to resolve internal resources and aliases. This overrides your system's default DNS settings. Queries that cannot be resolved as a Pangolin resource will be forwarded to your configured Upstream DNS Server. (default false)")
|
||||
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
|
||||
serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "When enabled, DNS queries are routed through the tunnel for remote resolution. To ensure queries are tunneled correctly, you must define the DNS server as a Pangolin resource and enter its address as an Upstream DNS Server. (default false)")
|
||||
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
|
||||
|
||||
version := serviceFlags.Bool("version", false, "Print the version")
|
||||
showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit")
|
||||
|
||||
// Parse the arguments
|
||||
if err := serviceFlags.Parse(args); err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
// Parse upstream DNS flag if provided
|
||||
if upstreamDNSFlag != "" {
|
||||
config.UpstreamDNS = []string{}
|
||||
for _, dns := range splitComma(upstreamDNSFlag) {
|
||||
if dns != "" {
|
||||
config.UpstreamDNS = append(config.UpstreamDNS, dns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track which values were changed by CLI args
|
||||
if config.Endpoint != origValues["endpoint"].(string) {
|
||||
config.sources["endpoint"] = string(SourceCLI)
|
||||
}
|
||||
if config.ID != origValues["id"].(string) {
|
||||
config.sources["id"] = string(SourceCLI)
|
||||
}
|
||||
if config.Secret != origValues["secret"].(string) {
|
||||
config.sources["secret"] = string(SourceCLI)
|
||||
}
|
||||
if config.OrgID != origValues["org"].(string) {
|
||||
config.sources["org"] = string(SourceCLI)
|
||||
}
|
||||
if config.UserToken != origValues["userToken"].(string) {
|
||||
config.sources["userToken"] = string(SourceCLI)
|
||||
}
|
||||
if config.MTU != origValues["mtu"].(int) {
|
||||
config.sources["mtu"] = string(SourceCLI)
|
||||
}
|
||||
if config.DNS != origValues["dns"].(string) {
|
||||
config.sources["dns"] = string(SourceCLI)
|
||||
}
|
||||
if fmt.Sprintf("%v", config.UpstreamDNS) != origValues["upstreamDNS"].(string) {
|
||||
config.sources["upstreamDNS"] = string(SourceCLI)
|
||||
}
|
||||
if config.LogLevel != origValues["logLevel"].(string) {
|
||||
config.sources["logLevel"] = string(SourceCLI)
|
||||
}
|
||||
if config.InterfaceName != origValues["interface"].(string) {
|
||||
config.sources["interface"] = string(SourceCLI)
|
||||
}
|
||||
if config.HTTPAddr != origValues["httpAddr"].(string) {
|
||||
config.sources["httpAddr"] = string(SourceCLI)
|
||||
}
|
||||
if config.SocketPath != origValues["socketPath"].(string) {
|
||||
config.sources["socketPath"] = string(SourceCLI)
|
||||
}
|
||||
if config.PingInterval != origValues["pingInterval"].(string) {
|
||||
config.sources["pingInterval"] = string(SourceCLI)
|
||||
}
|
||||
if config.PingTimeout != origValues["pingTimeout"].(string) {
|
||||
config.sources["pingTimeout"] = string(SourceCLI)
|
||||
}
|
||||
if config.EnableAPI != origValues["enableApi"].(bool) {
|
||||
config.sources["enableApi"] = string(SourceCLI)
|
||||
}
|
||||
if config.DisableHolepunch != origValues["disableHolepunch"].(bool) {
|
||||
config.sources["disableHolepunch"] = string(SourceCLI)
|
||||
}
|
||||
if config.OverrideDNS != origValues["overrideDNS"].(bool) {
|
||||
config.sources["overrideDNS"] = string(SourceCLI)
|
||||
}
|
||||
if config.DisableRelay != origValues["disableRelay"].(bool) {
|
||||
config.sources["disableRelay"] = string(SourceCLI)
|
||||
}
|
||||
if config.TunnelDNS != origValues["tunnelDNS"].(bool) {
|
||||
config.sources["tunnelDNS"] = string(SourceCLI)
|
||||
}
|
||||
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
|
||||
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
|
||||
// }
|
||||
|
||||
return *version, *showConfig, nil
|
||||
}
|
||||
|
||||
// parseDurations parses the duration strings into time.Duration
|
||||
func (c *OlmConfig) parseDurations() error {
|
||||
var err error
|
||||
|
||||
// Parse ping interval
|
||||
if c.PingInterval != "" {
|
||||
c.PingIntervalDuration, err = time.ParseDuration(c.PingInterval)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", c.PingInterval)
|
||||
c.PingIntervalDuration = 3 * time.Second
|
||||
c.PingInterval = "3s"
|
||||
}
|
||||
} else {
|
||||
c.PingIntervalDuration = 3 * time.Second
|
||||
c.PingInterval = "3s"
|
||||
}
|
||||
|
||||
// Parse ping timeout
|
||||
if c.PingTimeout != "" {
|
||||
c.PingTimeoutDuration, err = time.ParseDuration(c.PingTimeout)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", c.PingTimeout)
|
||||
c.PingTimeoutDuration = 5 * time.Second
|
||||
c.PingTimeout = "5s"
|
||||
}
|
||||
} else {
|
||||
c.PingTimeoutDuration = 5 * time.Second
|
||||
c.PingTimeout = "5s"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeConfigs merges source config into destination (only non-empty values)
|
||||
// Also tracks that these values came from a file
|
||||
func mergeConfigs(dest, src *OlmConfig) {
|
||||
if src.Endpoint != "" {
|
||||
dest.Endpoint = src.Endpoint
|
||||
dest.sources["endpoint"] = string(SourceFile)
|
||||
}
|
||||
if src.ID != "" {
|
||||
dest.ID = src.ID
|
||||
dest.sources["id"] = string(SourceFile)
|
||||
}
|
||||
if src.Secret != "" {
|
||||
dest.Secret = src.Secret
|
||||
dest.sources["secret"] = string(SourceFile)
|
||||
}
|
||||
if src.OrgID != "" {
|
||||
dest.OrgID = src.OrgID
|
||||
dest.sources["org"] = string(SourceFile)
|
||||
}
|
||||
if src.UserToken != "" {
|
||||
dest.UserToken = src.UserToken
|
||||
dest.sources["userToken"] = string(SourceFile)
|
||||
}
|
||||
if src.MTU != 0 && src.MTU != 1280 {
|
||||
dest.MTU = src.MTU
|
||||
dest.sources["mtu"] = string(SourceFile)
|
||||
}
|
||||
if src.DNS != "" && src.DNS != "8.8.8.8" {
|
||||
dest.DNS = src.DNS
|
||||
dest.sources["dns"] = string(SourceFile)
|
||||
}
|
||||
if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8:53]" {
|
||||
dest.UpstreamDNS = src.UpstreamDNS
|
||||
dest.sources["upstreamDNS"] = string(SourceFile)
|
||||
}
|
||||
if src.LogLevel != "" && src.LogLevel != "INFO" {
|
||||
dest.LogLevel = src.LogLevel
|
||||
dest.sources["logLevel"] = string(SourceFile)
|
||||
}
|
||||
if src.InterfaceName != "" && src.InterfaceName != "olm" {
|
||||
dest.InterfaceName = src.InterfaceName
|
||||
dest.sources["interface"] = string(SourceFile)
|
||||
}
|
||||
if src.HTTPAddr != "" && src.HTTPAddr != ":9452" {
|
||||
dest.HTTPAddr = src.HTTPAddr
|
||||
dest.sources["httpAddr"] = string(SourceFile)
|
||||
}
|
||||
if src.SocketPath != "" {
|
||||
// Check if it's not the default for any OS
|
||||
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
|
||||
if !isDefault {
|
||||
dest.SocketPath = src.SocketPath
|
||||
dest.sources["socketPath"] = string(SourceFile)
|
||||
}
|
||||
}
|
||||
if src.PingInterval != "" && src.PingInterval != "3s" {
|
||||
dest.PingInterval = src.PingInterval
|
||||
dest.sources["pingInterval"] = string(SourceFile)
|
||||
}
|
||||
if src.PingTimeout != "" && src.PingTimeout != "5s" {
|
||||
dest.PingTimeout = src.PingTimeout
|
||||
dest.sources["pingTimeout"] = string(SourceFile)
|
||||
}
|
||||
if src.TlsClientCert != "" {
|
||||
dest.TlsClientCert = src.TlsClientCert
|
||||
dest.sources["tlsClientCert"] = string(SourceFile)
|
||||
}
|
||||
// For booleans, we always take the source value if explicitly set
|
||||
if src.EnableAPI {
|
||||
dest.EnableAPI = src.EnableAPI
|
||||
dest.sources["enableApi"] = string(SourceFile)
|
||||
}
|
||||
if src.DisableHolepunch {
|
||||
dest.DisableHolepunch = src.DisableHolepunch
|
||||
dest.sources["disableHolepunch"] = string(SourceFile)
|
||||
}
|
||||
if src.OverrideDNS {
|
||||
dest.OverrideDNS = src.OverrideDNS
|
||||
dest.sources["overrideDNS"] = string(SourceFile)
|
||||
}
|
||||
if src.DisableRelay {
|
||||
dest.DisableRelay = src.DisableRelay
|
||||
dest.sources["disableRelay"] = string(SourceFile)
|
||||
}
|
||||
// if src.DoNotCreateNewClient {
|
||||
// dest.DoNotCreateNewClient = src.DoNotCreateNewClient
|
||||
// dest.sources["doNotCreateNewClient"] = string(SourceFile)
|
||||
// }
|
||||
}
|
||||
|
||||
// SaveConfig saves the current configuration to the config file
|
||||
func SaveConfig(config *OlmConfig) error {
|
||||
configPath := getOlmConfigPath()
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
return os.WriteFile(configPath, data, 0644)
|
||||
}
|
||||
|
||||
// ShowConfig prints the configuration and the source of each value
|
||||
func (c *OlmConfig) ShowConfig() {
|
||||
configPath := getOlmConfigPath()
|
||||
|
||||
fmt.Print("\n=== Olm Configuration ===\n\n")
|
||||
fmt.Printf("Config File: %s\n", configPath)
|
||||
|
||||
// Check if config file exists
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
fmt.Printf("Config File Status: ✓ exists\n")
|
||||
} else {
|
||||
fmt.Printf("Config File Status: ✗ not found\n")
|
||||
}
|
||||
|
||||
fmt.Println("\n--- Configuration Values ---")
|
||||
fmt.Print("(Format: Setting = Value [source])\n\n")
|
||||
|
||||
// Helper to get source or default
|
||||
getSource := func(key string) string {
|
||||
if source, ok := c.sources[key]; ok {
|
||||
return source
|
||||
}
|
||||
return string(SourceDefault)
|
||||
}
|
||||
|
||||
// Helper to format value (mask secrets)
|
||||
formatValue := func(key, value string) string {
|
||||
if key == "secret" && value != "" {
|
||||
if len(value) > 8 {
|
||||
return value[:4] + "****" + value[len(value)-4:]
|
||||
}
|
||||
return "****"
|
||||
}
|
||||
if value == "" {
|
||||
return "(not set)"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// Connection settings
|
||||
fmt.Println("Connection:")
|
||||
fmt.Printf(" endpoint = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint"))
|
||||
fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
|
||||
fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
|
||||
fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org"))
|
||||
fmt.Printf(" user-token = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken"))
|
||||
|
||||
// Network settings
|
||||
fmt.Println("\nNetwork:")
|
||||
fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu"))
|
||||
fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns"))
|
||||
fmt.Printf(" upstream-dns = %v [%s]\n", c.UpstreamDNS, getSource("upstreamDNS"))
|
||||
fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface"))
|
||||
|
||||
// Logging
|
||||
fmt.Println("\nLogging:")
|
||||
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
|
||||
|
||||
// API server
|
||||
fmt.Println("\nAPI Server:")
|
||||
fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
|
||||
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
|
||||
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
|
||||
|
||||
// Timing
|
||||
fmt.Println("\nTiming:")
|
||||
fmt.Printf(" ping-interval = %s [%s]\n", c.PingInterval, getSource("pingInterval"))
|
||||
fmt.Printf(" ping-timeout = %s [%s]\n", c.PingTimeout, getSource("pingTimeout"))
|
||||
|
||||
// Advanced
|
||||
fmt.Println("\nAdvanced:")
|
||||
fmt.Printf(" disable-holepunch = %v [%s]\n", c.DisableHolepunch, getSource("disableHolepunch"))
|
||||
fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS"))
|
||||
fmt.Printf(" tunnel-dns = %v [%s]\n", c.TunnelDNS, getSource("tunnelDNS"))
|
||||
fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay"))
|
||||
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
|
||||
if c.TlsClientCert != "" {
|
||||
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
|
||||
}
|
||||
|
||||
// Source legend
|
||||
fmt.Println("\n--- Source Legend ---")
|
||||
fmt.Println(" default = Built-in default value")
|
||||
fmt.Println(" file = Loaded from config file")
|
||||
fmt.Println(" environment = Set via environment variable")
|
||||
fmt.Println(" cli = Provided as command-line argument")
|
||||
fmt.Println("\nPriority: cli > environment > file > default")
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// splitComma splits a comma-separated string into a slice of trimmed strings
|
||||
func splitComma(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
43
create_test_creds.py
Normal file
43
create_test_creds.py
Normal file
@@ -0,0 +1,43 @@
|
||||
|
||||
import requests
|
||||
|
||||
def create_olm(base_url, user_token, olm_name, user_id):
|
||||
url = f"{base_url}/api/v1/user/{user_id}/olm"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "pangolin-cli",
|
||||
"X-CSRF-Token": "x-csrf-protection",
|
||||
"Cookie": f"p_session_token={user_token}"
|
||||
}
|
||||
payload = {"name": olm_name}
|
||||
response = requests.put(url, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
print(f"Response Data: {data}")
|
||||
|
||||
def create_client(base_url, user_token, client_name):
|
||||
url = f"{base_url}/api/v1/api/clients"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "pangolin-cli",
|
||||
"X-CSRF-Token": "x-csrf-protection",
|
||||
"Cookie": f"p_session_token={user_token}"
|
||||
}
|
||||
payload = {"name": client_name}
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
print(f"Response Data: {data}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
base_url = input("Enter base URL (e.g., http://localhost:3000): ")
|
||||
user_token = input("Enter user token: ")
|
||||
user_id = input("Enter user ID: ")
|
||||
olm_name = input("Enter OLM name: ")
|
||||
client_name = input("Enter client name: ")
|
||||
|
||||
create_olm(base_url, user_token, olm_name, user_id)
|
||||
# client_id = create_client(base_url, user_token, client_name)
|
||||
663
device/middle_device.go
Normal file
663
device/middle_device.go
Normal file
@@ -0,0 +1,663 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// PacketHandler processes intercepted packets and returns true if packet should be dropped
|
||||
type PacketHandler func(packet []byte) bool
|
||||
|
||||
// FilterRule defines a rule for packet filtering
|
||||
type FilterRule struct {
|
||||
DestIP netip.Addr
|
||||
Handler PacketHandler
|
||||
}
|
||||
|
||||
// closeAwareDevice wraps a tun.Device along with a flag
|
||||
// indicating whether its Close method was called.
|
||||
type closeAwareDevice struct {
|
||||
isClosed atomic.Bool
|
||||
tun.Device
|
||||
closeEventCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice {
|
||||
return &closeAwareDevice{
|
||||
Device: tunDevice,
|
||||
isClosed: atomic.Bool{},
|
||||
closeEventCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// redirectEvents redirects the Events() method of the underlying tun.Device
|
||||
// to the given channel.
|
||||
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-c.Device.Events():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if ev == tun.EventDown {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case out <- ev:
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Close calls the underlying Device's Close method
|
||||
// after setting isClosed to true.
|
||||
func (c *closeAwareDevice) Close() (err error) {
|
||||
c.closeOnce.Do(func() {
|
||||
c.isClosed.Store(true)
|
||||
close(c.closeEventCh)
|
||||
err = c.Device.Close()
|
||||
c.wg.Wait()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *closeAwareDevice) IsClosed() bool {
|
||||
return c.isClosed.Load()
|
||||
}
|
||||
|
||||
type readResult struct {
|
||||
bufs [][]byte
|
||||
sizes []int
|
||||
offset int
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
||||
// and supports swapping the underlying device.
|
||||
type MiddleDevice struct {
|
||||
devices []*closeAwareDevice
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
rules []FilterRule
|
||||
rulesMutex sync.RWMutex
|
||||
readCh chan readResult
|
||||
injectCh chan []byte
|
||||
closed atomic.Bool
|
||||
events chan tun.Event
|
||||
}
|
||||
|
||||
// NewMiddleDevice creates a new filtered TUN device wrapper
|
||||
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||
d := &MiddleDevice{
|
||||
devices: make([]*closeAwareDevice, 0),
|
||||
rules: make([]FilterRule, 0),
|
||||
readCh: make(chan readResult, 16),
|
||||
injectCh: make(chan []byte, 100),
|
||||
events: make(chan tun.Event, 16),
|
||||
}
|
||||
d.cond = sync.NewCond(&d.mu)
|
||||
|
||||
if device != nil {
|
||||
d.AddDevice(device)
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// AddDevice adds a new underlying TUN device, closing any previous one
|
||||
func (d *MiddleDevice) AddDevice(device tun.Device) {
|
||||
d.mu.Lock()
|
||||
if d.closed.Load() {
|
||||
d.mu.Unlock()
|
||||
_ = device.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var toClose *closeAwareDevice
|
||||
if len(d.devices) > 0 {
|
||||
toClose = d.devices[len(d.devices)-1]
|
||||
}
|
||||
|
||||
cad := newCloseAwareDevice(device)
|
||||
cad.redirectEvents(d.events)
|
||||
|
||||
d.devices = []*closeAwareDevice{cad}
|
||||
|
||||
// Start pump for the new device
|
||||
go d.pump(cad)
|
||||
|
||||
d.cond.Broadcast()
|
||||
d.mu.Unlock()
|
||||
|
||||
if toClose != nil {
|
||||
logger.Debug("MiddleDevice: Closing previous device")
|
||||
if err := toClose.Close(); err != nil {
|
||||
logger.Debug("MiddleDevice: Error closing previous device: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *MiddleDevice) pump(dev *closeAwareDevice) {
|
||||
const defaultOffset = 16
|
||||
batchSize := dev.BatchSize()
|
||||
logger.Debug("MiddleDevice: pump started for device")
|
||||
|
||||
// Recover from panic if readCh is closed while we're trying to send
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Debug("MiddleDevice: pump recovered from panic (channel closed)")
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
// Check if this device is closed
|
||||
if dev.IsClosed() {
|
||||
logger.Debug("MiddleDevice: pump exiting, device is closed")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if MiddleDevice itself is closed
|
||||
if d.closed.Load() {
|
||||
logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed")
|
||||
return
|
||||
}
|
||||
|
||||
// Allocate buffers for reading
|
||||
bufs := make([][]byte, batchSize)
|
||||
sizes := make([]int, batchSize)
|
||||
for i := range bufs {
|
||||
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
|
||||
}
|
||||
|
||||
n, err := dev.Read(bufs, sizes, defaultOffset)
|
||||
|
||||
// Check if device was closed during read
|
||||
if dev.IsClosed() {
|
||||
logger.Debug("MiddleDevice: pump exiting, device closed during read")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if MiddleDevice was closed during read
|
||||
if d.closed.Load() {
|
||||
logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read")
|
||||
return
|
||||
}
|
||||
|
||||
// Try to send the result - check closed state first to avoid sending on closed channel
|
||||
if d.closed.Load() {
|
||||
logger.Debug("MiddleDevice: pump exiting, device closed before send")
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||
default:
|
||||
// Channel full, check if we should exit
|
||||
if dev.IsClosed() || d.closed.Load() {
|
||||
return
|
||||
}
|
||||
// Try again with blocking
|
||||
select {
|
||||
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||
case <-dev.closeEventCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Debug("MiddleDevice: pump exiting due to read error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
|
||||
func (d *MiddleDevice) InjectOutbound(packet []byte) {
|
||||
if d.closed.Load() {
|
||||
return
|
||||
}
|
||||
// Use defer/recover to handle panic from sending on closed channel
|
||||
// This can happen during shutdown race conditions
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)")
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case d.injectCh <- packet:
|
||||
default:
|
||||
// Channel full, drop packet
|
||||
logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full")
|
||||
}
|
||||
}
|
||||
|
||||
// AddRule adds a packet filtering rule
|
||||
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||
d.rulesMutex.Lock()
|
||||
defer d.rulesMutex.Unlock()
|
||||
d.rules = append(d.rules, FilterRule{
|
||||
DestIP: destIP,
|
||||
Handler: handler,
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveRule removes all rules for a given destination IP
|
||||
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||
d.rulesMutex.Lock()
|
||||
defer d.rulesMutex.Unlock()
|
||||
newRules := make([]FilterRule, 0, len(d.rules))
|
||||
for _, rule := range d.rules {
|
||||
if rule.DestIP != destIP {
|
||||
newRules = append(newRules, rule)
|
||||
}
|
||||
}
|
||||
d.rules = newRules
|
||||
}
|
||||
|
||||
// Close stops the device
|
||||
func (d *MiddleDevice) Close() error {
|
||||
if !d.closed.CompareAndSwap(false, true) {
|
||||
return nil // already closed
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
devices := d.devices
|
||||
d.devices = nil
|
||||
d.cond.Broadcast()
|
||||
d.mu.Unlock()
|
||||
|
||||
// Close underlying devices first - this causes the pump goroutines to exit
|
||||
// when their read operations return errors
|
||||
var lastErr error
|
||||
logger.Debug("MiddleDevice: Closing %d devices", len(devices))
|
||||
for _, device := range devices {
|
||||
if err := device.Close(); err != nil {
|
||||
logger.Debug("MiddleDevice: Error closing device: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
// Now close channels to unblock any remaining readers
|
||||
// The pump should have exited by now, but close channels to be safe
|
||||
close(d.readCh)
|
||||
close(d.injectCh)
|
||||
close(d.events)
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Events returns the events channel
|
||||
func (d *MiddleDevice) Events() <-chan tun.Event {
|
||||
return d.events
|
||||
}
|
||||
|
||||
// File returns the underlying file descriptor
|
||||
func (d *MiddleDevice) File() *os.File {
|
||||
for {
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
file := dev.File()
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return file
|
||||
}
|
||||
}
|
||||
|
||||
// MTU returns the MTU of the underlying device
|
||||
func (d *MiddleDevice) MTU() (int, error) {
|
||||
for {
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
mtu, err := dev.MTU()
|
||||
if err == nil {
|
||||
return mtu, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the name of the underlying device
|
||||
func (d *MiddleDevice) Name() (string, error) {
|
||||
for {
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return "", io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
name, err := dev.Name()
|
||||
if err == nil {
|
||||
return name, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// BatchSize returns the batch size
|
||||
func (d *MiddleDevice) BatchSize() int {
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
return 1
|
||||
}
|
||||
return dev.BatchSize()
|
||||
}
|
||||
|
||||
// extractDestIP extracts destination IP from packet (fast path)
|
||||
func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||
if len(packet) < 20 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
version := packet[0] >> 4
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
if len(packet) < 20 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
// Destination IP is at bytes 16-19 for IPv4
|
||||
ip := netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
||||
return ip, true
|
||||
case 6:
|
||||
if len(packet) < 40 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
// Destination IP is at bytes 24-39 for IPv6
|
||||
var ip16 [16]byte
|
||||
copy(ip16[:], packet[24:40])
|
||||
ip := netip.AddrFrom16(ip16)
|
||||
return ip, true
|
||||
}
|
||||
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
for {
|
||||
if d.closed.Load() {
|
||||
logger.Debug("MiddleDevice: Read returning io.EOF, device closed")
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Wait for a device to be available
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Now block waiting for data from readCh or injectCh
|
||||
select {
|
||||
case res, ok := <-d.readCh:
|
||||
if !ok {
|
||||
// Channel closed, device is shutting down
|
||||
return 0, io.EOF
|
||||
}
|
||||
if res.err != nil {
|
||||
// Check if device was swapped
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
|
||||
return 0, res.err
|
||||
}
|
||||
|
||||
// Copy packets from result to provided buffers
|
||||
count := 0
|
||||
for i := 0; i < res.n && i < len(bufs); i++ {
|
||||
src := res.bufs[i]
|
||||
srcOffset := res.offset
|
||||
srcSize := res.sizes[i]
|
||||
|
||||
pktData := src[srcOffset : srcOffset+srcSize]
|
||||
|
||||
if len(bufs[i]) < offset+len(pktData) {
|
||||
continue
|
||||
}
|
||||
|
||||
copy(bufs[i][offset:], pktData)
|
||||
sizes[i] = len(pktData)
|
||||
count++
|
||||
}
|
||||
n = count
|
||||
|
||||
case pkt, ok := <-d.injectCh:
|
||||
if !ok {
|
||||
// Channel closed, device is shutting down
|
||||
return 0, io.EOF
|
||||
}
|
||||
if len(bufs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if len(bufs[0]) < offset+len(pkt) {
|
||||
return 0, nil
|
||||
}
|
||||
copy(bufs[0][offset:], pkt)
|
||||
sizes[0] = len(pkt)
|
||||
n = 1
|
||||
}
|
||||
|
||||
// Apply filtering rules
|
||||
d.rulesMutex.RLock()
|
||||
rules := d.rules
|
||||
d.rulesMutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Process packets and filter out handled ones
|
||||
writeIdx := 0
|
||||
for readIdx := 0; readIdx < n; readIdx++ {
|
||||
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
|
||||
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return writeIdx, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
||||
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
for {
|
||||
if d.closed.Load() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
d.rulesMutex.RLock()
|
||||
rules := d.rules
|
||||
d.rulesMutex.RUnlock()
|
||||
|
||||
var filteredBufs [][]byte
|
||||
if len(rules) == 0 {
|
||||
filteredBufs = bufs
|
||||
} else {
|
||||
filteredBufs = make([][]byte, 0, len(bufs))
|
||||
for _, buf := range bufs {
|
||||
if len(buf) <= offset {
|
||||
continue
|
||||
}
|
||||
|
||||
packet := buf[offset:]
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
continue
|
||||
}
|
||||
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredBufs) == 0 {
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
n, err := dev.Write(filteredBufs, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (d *MiddleDevice) waitForDevice() bool {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
for len(d.devices) == 0 && !d.closed.Load() {
|
||||
d.cond.Wait()
|
||||
}
|
||||
return !d.closed.Load()
|
||||
}
|
||||
|
||||
func (d *MiddleDevice) peekLast() *closeAwareDevice {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if len(d.devices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return d.devices[len(d.devices)-1]
|
||||
}
|
||||
|
||||
// WriteToTun writes packets directly to the underlying TUN device,
|
||||
// bypassing WireGuard. This is useful for sending packets that should
|
||||
// appear to come from the TUN interface (e.g., DNS responses from a proxy).
|
||||
// Unlike Write(), this does not go through packet filtering rules.
|
||||
func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) {
|
||||
for {
|
||||
if d.closed.Load() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := dev.Write(bufs, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
102
device/middle_device_test.go
Normal file
102
device/middle_device_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/fosrl/newt/util"
|
||||
)
|
||||
|
||||
func TestExtractDestIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
packet []byte
|
||||
wantIP string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 packet",
|
||||
packet: []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
|
||||
0x0a, 0x1e, 0x1e, 0x1e, // Dest IP: 10.30.30.30
|
||||
},
|
||||
wantIP: "10.30.30.30",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "Too short packet",
|
||||
packet: []byte{0x45, 0x00},
|
||||
wantIP: "",
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotIP, gotOk := extractDestIP(tt.packet)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("extractDestIP() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
return
|
||||
}
|
||||
if tt.wantOk {
|
||||
wantAddr := netip.MustParseAddr(tt.wantIP)
|
||||
if gotIP != wantAddr {
|
||||
t.Errorf("extractDestIP() ip = %v, want %v", gotIP, wantAddr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProtocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
packet []byte
|
||||
wantProto uint8
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "UDP packet",
|
||||
packet: []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, // Protocol: UDP (17) at byte 9
|
||||
0x0a, 0x1e, 0x1e, 0x1e,
|
||||
},
|
||||
wantProto: 17,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
packet: []byte{0x45, 0x00},
|
||||
wantProto: 0,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotProto, gotOk := util.GetProtocol(tt.packet)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
return
|
||||
}
|
||||
if gotProto != tt.wantProto {
|
||||
t.Errorf("GetProtocol() proto = %v, want %v", gotProto, tt.wantProto)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExtractDestIP(b *testing.B) {
|
||||
packet := []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
|
||||
0x0a, 0x1e, 0x1e, 0x1e,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
extractDestIP(packet)
|
||||
}
|
||||
}
|
||||
44
device/tun_darwin.go
Normal file
44
device/tun_darwin.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build darwin
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
dupTunFd, err := unix.Dup(int(tunFd))
|
||||
if err != nil {
|
||||
logger.Error("Unable to dup tun fd: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(dupTunFd, true)
|
||||
if err != nil {
|
||||
unix.Close(dupTunFd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
|
||||
device, err := tun.CreateTUNFromFile(file, 0)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||
return ipc.UAPIOpen(interfaceName)
|
||||
}
|
||||
|
||||
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||
}
|
||||
50
device/tun_linux.go
Normal file
50
device/tun_linux.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build linux
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
if runtime.GOOS == "android" { // otherwise we get a permission denied
|
||||
theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
|
||||
return theTun, err
|
||||
}
|
||||
|
||||
dupTunFd, err := unix.Dup(int(tunFd))
|
||||
if err != nil {
|
||||
logger.Error("Unable to dup tun fd: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(dupTunFd, true)
|
||||
if err != nil {
|
||||
unix.Close(dupTunFd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
|
||||
device, err := tun.CreateTUNFromFile(file, mtuInt)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||
return ipc.UAPIOpen(interfaceName)
|
||||
}
|
||||
|
||||
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -11,15 +11,15 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
return nil, errors.New("CreateTUNFromFile not supported on Windows")
|
||||
}
|
||||
|
||||
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
// On Windows, UAPIListen only takes one parameter
|
||||
return ipc.UAPIListen(interfaceName)
|
||||
}
|
||||
757
dns/dns_proxy.go
Normal file
757
dns/dns_proxy.go
Normal file
@@ -0,0 +1,757 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/device"
|
||||
"github.com/miekg/dns"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
DNSPort = 53
|
||||
)
|
||||
|
||||
// DNSProxy implements a DNS proxy using gvisor netstack
|
||||
type DNSProxy struct {
|
||||
stack *stack.Stack
|
||||
ep *channel.Endpoint
|
||||
proxyIP netip.Addr
|
||||
upstreamDNS []string
|
||||
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
|
||||
mtu int
|
||||
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
|
||||
recordStore *DNSRecordStore // Local DNS records
|
||||
|
||||
// Tunnel DNS fields - for sending queries over WireGuard
|
||||
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
|
||||
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
|
||||
tunnelEp *channel.Endpoint
|
||||
tunnelActivePorts map[uint16]bool
|
||||
tunnelPortsLock sync.Mutex
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewDNSProxy creates a new DNS proxy
|
||||
func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
|
||||
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
||||
}
|
||||
|
||||
if len(upstreamDns) == 0 {
|
||||
return nil, fmt.Errorf("at least one upstream DNS server must be specified")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
proxy := &DNSProxy{
|
||||
proxyIP: proxyIP,
|
||||
mtu: mtu,
|
||||
middleDevice: middleDevice,
|
||||
upstreamDNS: upstreamDns,
|
||||
tunnelDNS: tunnelDns,
|
||||
recordStore: NewDNSRecordStore(),
|
||||
tunnelActivePorts: make(map[uint16]bool),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Parse tunnel IP if provided (needed for tunneled DNS)
|
||||
if tunnelIP != "" {
|
||||
addr, err := netip.ParseAddr(tunnelIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tunnel IP: %v", err)
|
||||
}
|
||||
proxy.tunnelIP = addr
|
||||
}
|
||||
|
||||
// Create gvisor netstack for receiving DNS queries
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
|
||||
proxy.ep = channel.New(256, uint32(mtu), "")
|
||||
proxy.stack = stack.New(stackOpts)
|
||||
|
||||
// Create NIC
|
||||
if err := proxy.stack.CreateNIC(1, proxy.ep); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
// Add IP address
|
||||
// Parse the proxy IP to get the octets
|
||||
ipBytes := proxyIP.As4()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||
}
|
||||
|
||||
if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to add protocol address: %v", err)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
proxy.stack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
// Initialize tunnel netstack if tunnel DNS is enabled
|
||||
if tunnelDns {
|
||||
if !proxy.tunnelIP.IsValid() {
|
||||
return nil, fmt.Errorf("tunnel IP is required when tunnelDNS is enabled")
|
||||
}
|
||||
|
||||
// TODO: DO WE NEED TO ESTABLISH ANOTHER NETSTACK HERE OR CAN WE COMBINE WITH WGTESTER?
|
||||
if err := proxy.initTunnelNetstack(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize tunnel netstack: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// initTunnelNetstack creates a separate netstack for outbound DNS queries through the tunnel
|
||||
func (p *DNSProxy) initTunnelNetstack() error {
|
||||
// Create gvisor netstack for outbound tunnel queries
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
|
||||
p.tunnelEp = channel.New(256, uint32(p.mtu), "")
|
||||
p.tunnelStack = stack.New(stackOpts)
|
||||
|
||||
// Create NIC
|
||||
if err := p.tunnelStack.CreateNIC(1, p.tunnelEp); err != nil {
|
||||
return fmt.Errorf("failed to create tunnel NIC: %v", err)
|
||||
}
|
||||
|
||||
// Add tunnel IP address (WireGuard interface IP)
|
||||
ipBytes := p.tunnelIP.As4()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||
}
|
||||
|
||||
if err := p.tunnelStack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return fmt.Errorf("failed to add tunnel protocol address: %v", err)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
p.tunnelStack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
// Register filter rule on MiddleDevice to intercept responses
|
||||
p.middleDevice.AddRule(p.tunnelIP, p.handleTunnelResponse)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleTunnelResponse handles packets coming back from the tunnel destined for the tunnel IP
|
||||
func (p *DNSProxy) handleTunnelResponse(packet []byte) bool {
|
||||
// Check if it's UDP
|
||||
proto, ok := util.GetProtocol(packet)
|
||||
if !ok || proto != 17 { // UDP
|
||||
return false
|
||||
}
|
||||
|
||||
// Check destination port - should be one of our active outbound ports
|
||||
port, ok := util.GetDestPort(packet)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if we are expecting a response on this port
|
||||
p.tunnelPortsLock.Lock()
|
||||
active := p.tunnelActivePorts[uint16(port)]
|
||||
p.tunnelPortsLock.Unlock()
|
||||
|
||||
if !active {
|
||||
return false
|
||||
}
|
||||
|
||||
// Inject into tunnel netstack
|
||||
version := packet[0] >> 4
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
p.tunnelEp.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||
case 6:
|
||||
p.tunnelEp.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||
default:
|
||||
pkb.DecRef()
|
||||
return false
|
||||
}
|
||||
|
||||
pkb.DecRef()
|
||||
return true // Handled
|
||||
}
|
||||
|
||||
// Start starts the DNS proxy and registers with the filter
|
||||
func (p *DNSProxy) Start() error {
|
||||
// Install packet filter rule
|
||||
p.middleDevice.AddRule(p.proxyIP, p.handlePacket)
|
||||
|
||||
// Start DNS listener
|
||||
p.wg.Add(2)
|
||||
go p.runDNSListener()
|
||||
go p.runPacketSender()
|
||||
|
||||
// Start tunnel packet sender if tunnel DNS is enabled
|
||||
if p.tunnelDNS {
|
||||
p.wg.Add(1)
|
||||
go p.runTunnelPacketSender()
|
||||
}
|
||||
|
||||
logger.Info("DNS proxy started on %s:%d (tunnelDNS=%v)", p.proxyIP.String(), DNSPort, p.tunnelDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the DNS proxy
|
||||
func (p *DNSProxy) Stop() {
|
||||
if p.middleDevice != nil {
|
||||
p.middleDevice.RemoveRule(p.proxyIP)
|
||||
if p.tunnelDNS && p.tunnelIP.IsValid() {
|
||||
p.middleDevice.RemoveRule(p.tunnelIP)
|
||||
}
|
||||
}
|
||||
p.cancel()
|
||||
|
||||
// Close the endpoint first to unblock any pending Read() calls in runPacketSender
|
||||
if p.ep != nil {
|
||||
p.ep.Close()
|
||||
}
|
||||
|
||||
// Close tunnel endpoint if it exists
|
||||
if p.tunnelEp != nil {
|
||||
p.tunnelEp.Close()
|
||||
}
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
if p.stack != nil {
|
||||
p.stack.Close()
|
||||
}
|
||||
|
||||
if p.tunnelStack != nil {
|
||||
p.tunnelStack.Close()
|
||||
}
|
||||
|
||||
logger.Info("DNS proxy stopped")
|
||||
}
|
||||
|
||||
func (p *DNSProxy) GetProxyIP() netip.Addr {
|
||||
return p.proxyIP
|
||||
}
|
||||
|
||||
// handlePacket is called by the filter for packets destined to DNS proxy IP
|
||||
func (p *DNSProxy) handlePacket(packet []byte) bool {
|
||||
if len(packet) < 20 {
|
||||
return false // Don't drop, malformed
|
||||
}
|
||||
|
||||
// Quick check for UDP port 53
|
||||
proto, ok := util.GetProtocol(packet)
|
||||
if !ok || proto != 17 { // 17 = UDP
|
||||
return false // Not UDP, don't handle
|
||||
}
|
||||
|
||||
port, ok := util.GetDestPort(packet)
|
||||
if !ok || port != DNSPort {
|
||||
return false // Not DNS port
|
||||
}
|
||||
|
||||
// Inject packet into our netstack
|
||||
version := packet[0] >> 4
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
p.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||
case 6:
|
||||
p.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||
default:
|
||||
pkb.DecRef()
|
||||
return false
|
||||
}
|
||||
|
||||
pkb.DecRef()
|
||||
return true // Drop packet from normal path
|
||||
}
|
||||
|
||||
// runDNSListener listens for DNS queries on the netstack
|
||||
func (p *DNSProxy) runDNSListener() {
|
||||
defer p.wg.Done()
|
||||
|
||||
// Create UDP listener using gonet
|
||||
// Parse the proxy IP to get the octets
|
||||
ipBytes := p.proxyIP.As4()
|
||||
laddr := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4(ipBytes),
|
||||
Port: DNSPort,
|
||||
}
|
||||
|
||||
udpConn, err := gonet.DialUDP(p.stack, laddr, nil, ipv4.ProtocolNumber)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create DNS listener: %v", err)
|
||||
return
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
logger.Debug("DNS proxy listening on netstack")
|
||||
|
||||
// Handle DNS queries
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
udpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
n, remoteAddr, err := udpConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
logger.Error("DNS read error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
query := make([]byte, n)
|
||||
copy(query, buf[:n])
|
||||
|
||||
// Handle query in background
|
||||
go p.handleDNSQuery(udpConn, query, remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream
|
||||
func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) {
|
||||
// Parse the DNS query
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(queryData); err != nil {
|
||||
logger.Error("Failed to parse DNS query: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(msg.Question) == 0 {
|
||||
logger.Debug("DNS query has no questions")
|
||||
return
|
||||
}
|
||||
|
||||
question := msg.Question[0]
|
||||
logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype])
|
||||
|
||||
// Check if we have local records for this query
|
||||
var response *dns.Msg
|
||||
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypePTR {
|
||||
response = p.checkLocalRecords(msg, question)
|
||||
}
|
||||
|
||||
// If no local records, forward to upstream
|
||||
if response == nil {
|
||||
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
|
||||
response = p.forwardToUpstream(msg)
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
logger.Error("Failed to get DNS response for %s", question.Name)
|
||||
return
|
||||
}
|
||||
|
||||
// Pack and send response
|
||||
responseData, err := response.Pack()
|
||||
if err != nil {
|
||||
logger.Error("Failed to pack DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = udpConn.WriteTo(responseData, clientAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to send DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// checkLocalRecords checks if we have local records for the query
|
||||
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
|
||||
// Handle PTR queries
|
||||
if question.Qtype == dns.TypePTR {
|
||||
if ptrDomain, ok := p.recordStore.GetPTRRecord(question.Name); ok {
|
||||
logger.Debug("Found local PTR record for %s -> %s", question.Name, ptrDomain)
|
||||
|
||||
// Create response message
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(query)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add PTR answer record
|
||||
rr := &dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
Ptr: ptrDomain,
|
||||
}
|
||||
response.Answer = append(response.Answer, rr)
|
||||
|
||||
return response
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle A and AAAA queries
|
||||
var recordType RecordType
|
||||
if question.Qtype == dns.TypeA {
|
||||
recordType = RecordTypeA
|
||||
} else if question.Qtype == dns.TypeAAAA {
|
||||
recordType = RecordTypeAAAA
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
ips := p.recordStore.GetRecords(question.Name, recordType)
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
|
||||
|
||||
// Create response message
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(query)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add answer records
|
||||
for _, ip := range ips {
|
||||
var rr dns.RR
|
||||
if question.Qtype == dns.TypeA {
|
||||
rr = &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
A: ip.To4(),
|
||||
}
|
||||
} else { // TypeAAAA
|
||||
rr = &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
AAAA: ip.To16(),
|
||||
}
|
||||
}
|
||||
response.Answer = append(response.Answer, rr)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// forwardToUpstream forwards a DNS query to upstream DNS servers
|
||||
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
|
||||
// Try primary DNS server
|
||||
response, err := p.queryUpstream(p.upstreamDNS[0], query, 2*time.Second)
|
||||
if err != nil && len(p.upstreamDNS) > 1 {
|
||||
// Try secondary DNS server
|
||||
logger.Debug("Primary DNS failed, trying secondary: %v", err)
|
||||
response, err = p.queryUpstream(p.upstreamDNS[1], query, 2*time.Second)
|
||||
if err != nil {
|
||||
logger.Error("Both DNS servers failed: %v", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
// queryUpstream sends a DNS query to upstream server
|
||||
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
||||
if p.tunnelDNS {
|
||||
return p.queryUpstreamTunnel(server, query, timeout)
|
||||
}
|
||||
return p.queryUpstreamDirect(server, query, timeout)
|
||||
}
|
||||
|
||||
// queryUpstreamDirect sends a DNS query to upstream server using miekg/dns directly (host networking)
|
||||
func (p *DNSProxy) queryUpstreamDirect(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
||||
client := &dns.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
response, _, err := client.Exchange(query, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// queryUpstreamTunnel sends a DNS query through the WireGuard tunnel
|
||||
func (p *DNSProxy) queryUpstreamTunnel(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
||||
// Dial through the tunnel netstack
|
||||
conn, port, err := p.dialTunnel("udp", server)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial tunnel: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
conn.Close()
|
||||
p.removeTunnelPort(port)
|
||||
}()
|
||||
|
||||
// Pack the query
|
||||
queryData, err := query.Pack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pack query: %v", err)
|
||||
}
|
||||
|
||||
// Set deadline
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
|
||||
// Send the query
|
||||
_, err = conn.Write(queryData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send query: %v", err)
|
||||
}
|
||||
|
||||
// Read the response
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
// Parse the response
|
||||
response := new(dns.Msg)
|
||||
if err := response.Unpack(buf[:n]); err != nil {
|
||||
return nil, fmt.Errorf("failed to unpack response: %v", err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// dialTunnel creates a UDP connection through the tunnel netstack
|
||||
func (p *DNSProxy) dialTunnel(network, addr string) (net.Conn, uint16, error) {
|
||||
if p.tunnelStack == nil {
|
||||
return nil, 0, fmt.Errorf("tunnel netstack not initialized")
|
||||
}
|
||||
|
||||
// Parse remote address
|
||||
raddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Use tunnel IP as source
|
||||
ipBytes := p.tunnelIP.As4()
|
||||
|
||||
// Create UDP connection with ephemeral port
|
||||
laddr := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4(ipBytes),
|
||||
Port: 0,
|
||||
}
|
||||
|
||||
raddrTcpip := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
|
||||
Port: uint16(raddr.Port),
|
||||
}
|
||||
|
||||
conn, err := gonet.DialUDP(p.tunnelStack, laddr, raddrTcpip, ipv4.ProtocolNumber)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Get local port
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
port := uint16(localAddr.Port)
|
||||
|
||||
// Register port so we can receive responses
|
||||
p.tunnelPortsLock.Lock()
|
||||
p.tunnelActivePorts[port] = true
|
||||
p.tunnelPortsLock.Unlock()
|
||||
|
||||
return conn, port, nil
|
||||
}
|
||||
|
||||
// removeTunnelPort removes a port from the active ports map
|
||||
func (p *DNSProxy) removeTunnelPort(port uint16) {
|
||||
p.tunnelPortsLock.Lock()
|
||||
delete(p.tunnelActivePorts, port)
|
||||
p.tunnelPortsLock.Unlock()
|
||||
}
|
||||
|
||||
// runTunnelPacketSender reads packets from tunnel netstack and injects them into WireGuard
|
||||
func (p *DNSProxy) runTunnelPacketSender() {
|
||||
defer p.wg.Done()
|
||||
logger.Debug("DNS tunnel packet sender goroutine started")
|
||||
|
||||
for {
|
||||
// Use blocking ReadContext instead of polling - much more CPU efficient
|
||||
// This will block until a packet is available or context is cancelled
|
||||
pkt := p.tunnelEp.ReadContext(p.ctx)
|
||||
if pkt == nil {
|
||||
// Context was cancelled or endpoint closed
|
||||
logger.Debug("DNS tunnel packet sender exiting")
|
||||
// Drain any remaining packets
|
||||
for {
|
||||
pkt := p.tunnelEp.Read()
|
||||
if pkt == nil {
|
||||
break
|
||||
}
|
||||
pkt.DecRef()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Extract packet data
|
||||
slices := pkt.AsSlices()
|
||||
if len(slices) > 0 {
|
||||
var totalSize int
|
||||
for _, slice := range slices {
|
||||
totalSize += len(slice)
|
||||
}
|
||||
|
||||
buf := make([]byte, totalSize)
|
||||
pos := 0
|
||||
for _, slice := range slices {
|
||||
copy(buf[pos:], slice)
|
||||
pos += len(slice)
|
||||
}
|
||||
|
||||
// Inject into MiddleDevice (outbound to WG)
|
||||
p.middleDevice.InjectOutbound(buf)
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
// runPacketSender sends packets from netstack back to TUN
|
||||
func (p *DNSProxy) runPacketSender() {
|
||||
defer p.wg.Done()
|
||||
|
||||
// MessageTransportHeaderSize is the offset used by WireGuard device
|
||||
// for reading/writing packets to the TUN interface
|
||||
const offset = 16
|
||||
|
||||
for {
|
||||
// Use blocking ReadContext instead of polling - much more CPU efficient
|
||||
// This will block until a packet is available or context is cancelled
|
||||
pkt := p.ep.ReadContext(p.ctx)
|
||||
if pkt == nil {
|
||||
// Context was cancelled or endpoint closed
|
||||
return
|
||||
}
|
||||
|
||||
// Extract packet data as slices
|
||||
slices := pkt.AsSlices()
|
||||
if len(slices) > 0 {
|
||||
// Flatten all slices into a single packet buffer
|
||||
var totalSize int
|
||||
for _, slice := range slices {
|
||||
totalSize += len(slice)
|
||||
}
|
||||
|
||||
// Allocate buffer with offset space for WireGuard transport header
|
||||
// The first 'offset' bytes are reserved for the transport header
|
||||
buf := make([]byte, offset+totalSize)
|
||||
|
||||
// Copy packet data after the offset
|
||||
pos := offset
|
||||
for _, slice := range slices {
|
||||
copy(buf[pos:], slice)
|
||||
pos += len(slice)
|
||||
}
|
||||
|
||||
// Write packet to TUN device via MiddleDevice
|
||||
// offset=16 indicates packet data starts at position 16 in the buffer
|
||||
_, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write DNS response to TUN: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
// AddDNSRecord adds a DNS record to the local store
|
||||
// domain should be a domain name (e.g., "example.com" or "example.com.")
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
|
||||
return p.recordStore.AddRecord(domain, ip)
|
||||
}
|
||||
|
||||
// RemoveDNSRecord removes a DNS record from the local store
|
||||
// If ip is nil, removes all records for the domain
|
||||
func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
|
||||
p.recordStore.RemoveRecord(domain, ip)
|
||||
}
|
||||
|
||||
// GetDNSRecords returns all IP addresses for a domain and record type
|
||||
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
|
||||
return p.recordStore.GetRecords(domain, recordType)
|
||||
}
|
||||
|
||||
// ClearDNSRecords removes all DNS records from the local store
|
||||
func (p *DNSProxy) ClearDNSRecords() {
|
||||
p.recordStore.Clear()
|
||||
}
|
||||
|
||||
func PickIPFromSubnet(subnet string) (netip.Addr, error) {
|
||||
// given a subnet in CIDR notation, pick the first usable IP
|
||||
prefix, err := netip.ParsePrefix(subnet)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid subnet: %w", err)
|
||||
}
|
||||
|
||||
// Pick the first usable IP address from the subnet
|
||||
ip := prefix.Addr().Next()
|
||||
if !ip.IsValid() {
|
||||
return netip.Addr{}, fmt.Errorf("no valid IP address found in subnet: %s", subnet)
|
||||
}
|
||||
|
||||
return ip, nil
|
||||
}
|
||||
497
dns/dns_records.go
Normal file
497
dns/dns_records.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// RecordType represents the type of DNS record
|
||||
type RecordType uint16
|
||||
|
||||
const (
|
||||
RecordTypeA RecordType = RecordType(dns.TypeA)
|
||||
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
|
||||
RecordTypePTR RecordType = RecordType(dns.TypePTR)
|
||||
)
|
||||
|
||||
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
|
||||
type DNSRecordStore struct {
|
||||
mu sync.RWMutex
|
||||
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
|
||||
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
|
||||
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
|
||||
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
|
||||
ptrRecords map[string]string // IP address string -> domain name
|
||||
}
|
||||
|
||||
// NewDNSRecordStore creates a new DNS record store
|
||||
func NewDNSRecordStore() *DNSRecordStore {
|
||||
return &DNSRecordStore{
|
||||
aRecords: make(map[string][]net.IP),
|
||||
aaaaRecords: make(map[string][]net.IP),
|
||||
aWildcards: make(map[string][]net.IP),
|
||||
aaaaWildcards: make(map[string][]net.IP),
|
||||
ptrRecords: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRecord adds a DNS record mapping (A or AAAA)
|
||||
// domain should be in FQDN format (e.g., "example.com.")
|
||||
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
// Automatically adds a corresponding PTR record for non-wildcard domains
|
||||
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Check if domain contains wildcards
|
||||
isWildcard := strings.ContainsAny(domain, "*?")
|
||||
|
||||
if ip.To4() != nil {
|
||||
// IPv4 address
|
||||
if isWildcard {
|
||||
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
|
||||
} else {
|
||||
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
||||
// Automatically add PTR record for non-wildcard domains
|
||||
s.ptrRecords[ip.String()] = domain
|
||||
}
|
||||
} else if ip.To16() != nil {
|
||||
// IPv6 address
|
||||
if isWildcard {
|
||||
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
|
||||
} else {
|
||||
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
||||
// Automatically add PTR record for non-wildcard domains
|
||||
s.ptrRecords[ip.String()] = domain
|
||||
}
|
||||
} else {
|
||||
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
// domain should be in FQDN format (e.g., "example.com.")
|
||||
func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Store PTR record using IP string as key
|
||||
s.ptrRecords[ip.String()] = domain
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRecord removes a specific DNS record mapping
|
||||
// If ip is nil, removes all records for the domain (including wildcards)
|
||||
// Automatically removes corresponding PTR records for non-wildcard domains
|
||||
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Check if domain contains wildcards
|
||||
isWildcard := strings.ContainsAny(domain, "*?")
|
||||
|
||||
if ip == nil {
|
||||
// Remove all records for this domain
|
||||
if isWildcard {
|
||||
delete(s.aWildcards, domain)
|
||||
delete(s.aaaaWildcards, domain)
|
||||
} else {
|
||||
// For non-wildcard domains, remove PTR records for all IPs
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
for _, ipAddr := range ips {
|
||||
// Only remove PTR if it points to this domain
|
||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ipAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
for _, ipAddr := range ips {
|
||||
// Only remove PTR if it points to this domain
|
||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ipAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(s.aRecords, domain)
|
||||
delete(s.aaaaRecords, domain)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
// Remove specific IPv4 address
|
||||
if isWildcard {
|
||||
if ips, ok := s.aWildcards[domain]; ok {
|
||||
s.aWildcards[domain] = removeIP(ips, ip)
|
||||
if len(s.aWildcards[domain]) == 0 {
|
||||
delete(s.aWildcards, domain)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
s.aRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aRecords[domain]) == 0 {
|
||||
delete(s.aRecords, domain)
|
||||
}
|
||||
// Automatically remove PTR record if it points to this domain
|
||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if ip.To16() != nil {
|
||||
// Remove specific IPv6 address
|
||||
if isWildcard {
|
||||
if ips, ok := s.aaaaWildcards[domain]; ok {
|
||||
s.aaaaWildcards[domain] = removeIP(ips, ip)
|
||||
if len(s.aaaaWildcards[domain]) == 0 {
|
||||
delete(s.aaaaWildcards, domain)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
s.aaaaRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aaaaRecords[domain]) == 0 {
|
||||
delete(s.aaaaRecords, domain)
|
||||
}
|
||||
// Automatically remove PTR record if it points to this domain
|
||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemovePTRRecord removes a PTR record for an IP address
|
||||
func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
|
||||
// GetRecords returns all IP addresses for a domain and record type
|
||||
// First checks for exact matches, then checks wildcard patterns
|
||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
var records []net.IP
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
// Check exact match first
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
return records
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern, ips := range s.aWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
records = append(records, ips...)
|
||||
}
|
||||
}
|
||||
if len(records) > 0 {
|
||||
// Return a copy
|
||||
result := make([]net.IP, len(records))
|
||||
copy(result, records)
|
||||
return result
|
||||
}
|
||||
|
||||
case RecordTypeAAAA:
|
||||
// Check exact match first
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
return records
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern, ips := range s.aaaaWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
records = append(records, ips...)
|
||||
}
|
||||
}
|
||||
if len(records) > 0 {
|
||||
// Return a copy
|
||||
result := make([]net.IP, len(records))
|
||||
copy(result, records)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// GetPTRRecord returns the domain name for a PTR record query
|
||||
// domain should be in reverse DNS format (e.g., "1.0.0.127.in-addr.arpa.")
|
||||
func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Convert reverse DNS format to IP address
|
||||
ip := reverseDNSToIP(domain)
|
||||
if ip == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Look up the PTR record
|
||||
if ptrDomain, ok := s.ptrRecords[ip.String()]; ok {
|
||||
return ptrDomain, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// HasRecord checks if a domain has any records of the specified type
|
||||
// Checks both exact matches and wildcard patterns
|
||||
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
// Check exact match
|
||||
if _, ok := s.aRecords[domain]; ok {
|
||||
return true
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern := range s.aWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case RecordTypeAAAA:
|
||||
// Check exact match
|
||||
if _, ok := s.aaaaRecords[domain]; ok {
|
||||
return true
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern := range s.aaaaWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HasPTRRecord checks if a PTR record exists for the given reverse DNS domain
|
||||
func (s *DNSRecordStore) HasPTRRecord(domain string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Convert reverse DNS format to IP address
|
||||
ip := reverseDNSToIP(domain)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, ok := s.ptrRecords[ip.String()]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Clear removes all records from the store
|
||||
func (s *DNSRecordStore) Clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.aRecords = make(map[string][]net.IP)
|
||||
s.aaaaRecords = make(map[string][]net.IP)
|
||||
s.aWildcards = make(map[string][]net.IP)
|
||||
s.aaaaWildcards = make(map[string][]net.IP)
|
||||
s.ptrRecords = make(map[string]string)
|
||||
}
|
||||
|
||||
// removeIP is a helper function to remove a specific IP from a slice
|
||||
func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
|
||||
result := make([]net.IP, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
if !ip.Equal(toRemove) {
|
||||
result = append(result, ip)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// matchWildcard checks if a domain matches a wildcard pattern
|
||||
// Pattern supports * (0+ chars) and ? (exactly 1 char)
|
||||
// Special case: *.domain.com does not match domain.com itself
|
||||
func matchWildcard(pattern, domain string) bool {
|
||||
return matchWildcardInternal(pattern, domain, 0, 0)
|
||||
}
|
||||
|
||||
// matchWildcardInternal performs the actual wildcard matching recursively
|
||||
func matchWildcardInternal(pattern, domain string, pi, di int) bool {
|
||||
plen := len(pattern)
|
||||
dlen := len(domain)
|
||||
|
||||
// Base cases
|
||||
if pi == plen && di == dlen {
|
||||
return true
|
||||
}
|
||||
if pi == plen {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle wildcard characters
|
||||
if pattern[pi] == '*' {
|
||||
// Special case: if pattern starts with "*." and we're at the beginning,
|
||||
// ensure we don't match the domain without a prefix
|
||||
// e.g., *.autoco.internal should not match autoco.internal
|
||||
if pi == 0 && pi+1 < plen && pattern[pi+1] == '.' {
|
||||
// The * must match at least one character
|
||||
if di == dlen {
|
||||
return false
|
||||
}
|
||||
// Try matching 1 or more characters before the dot
|
||||
for i := di + 1; i <= dlen; i++ {
|
||||
if matchWildcardInternal(pattern, domain, pi+1, i) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Normal * matching (0 or more characters)
|
||||
// Try matching 0 characters (skip the *)
|
||||
if matchWildcardInternal(pattern, domain, pi+1, di) {
|
||||
return true
|
||||
}
|
||||
// Try matching 1+ characters
|
||||
if di < dlen {
|
||||
return matchWildcardInternal(pattern, domain, pi, di+1)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if pattern[pi] == '?' {
|
||||
// ? matches exactly one character
|
||||
if di >= dlen {
|
||||
return false
|
||||
}
|
||||
return matchWildcardInternal(pattern, domain, pi+1, di+1)
|
||||
}
|
||||
|
||||
// Regular character - must match exactly
|
||||
if di >= dlen || pattern[pi] != domain[di] {
|
||||
return false
|
||||
}
|
||||
|
||||
return matchWildcardInternal(pattern, domain, pi+1, di+1)
|
||||
}
|
||||
|
||||
// reverseDNSToIP converts a reverse DNS query name to an IP address
|
||||
// Supports both IPv4 (in-addr.arpa) and IPv6 (ip6.arpa) formats
|
||||
func reverseDNSToIP(domain string) net.IP {
|
||||
// Normalize to lowercase and ensure FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Check for IPv4 reverse DNS (in-addr.arpa)
|
||||
if strings.HasSuffix(domain, ".in-addr.arpa.") {
|
||||
// Remove the suffix
|
||||
ipPart := strings.TrimSuffix(domain, ".in-addr.arpa.")
|
||||
// Split by dots and reverse
|
||||
parts := strings.Split(ipPart, ".")
|
||||
if len(parts) != 4 {
|
||||
return nil
|
||||
}
|
||||
// Reverse the octets
|
||||
reversed := make([]string, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
reversed[i] = parts[3-i]
|
||||
}
|
||||
// Parse as IP
|
||||
return net.ParseIP(strings.Join(reversed, "."))
|
||||
}
|
||||
|
||||
// Check for IPv6 reverse DNS (ip6.arpa)
|
||||
if strings.HasSuffix(domain, ".ip6.arpa.") {
|
||||
// Remove the suffix
|
||||
ipPart := strings.TrimSuffix(domain, ".ip6.arpa.")
|
||||
// Split by dots and reverse
|
||||
parts := strings.Split(ipPart, ".")
|
||||
if len(parts) != 32 {
|
||||
return nil
|
||||
}
|
||||
// Reverse the nibbles and group into 16-bit hex values
|
||||
reversed := make([]string, 32)
|
||||
for i := 0; i < 32; i++ {
|
||||
reversed[i] = parts[31-i]
|
||||
}
|
||||
// Join into IPv6 format (groups of 4 nibbles separated by colons)
|
||||
var ipv6Parts []string
|
||||
for i := 0; i < 32; i += 4 {
|
||||
ipv6Parts = append(ipv6Parts, reversed[i]+reversed[i+1]+reversed[i+2]+reversed[i+3])
|
||||
}
|
||||
// Parse as IP
|
||||
return net.ParseIP(strings.Join(ipv6Parts, ":"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPToReverseDNS converts an IP address to reverse DNS format
|
||||
// Returns the domain name for PTR queries (e.g., "1.0.0.127.in-addr.arpa.")
|
||||
func IPToReverseDNS(ip net.IP) string {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// IPv4: reverse octets and append .in-addr.arpa.
|
||||
return dns.Fqdn(fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa",
|
||||
ip4[3], ip4[2], ip4[1], ip4[0]))
|
||||
}
|
||||
|
||||
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
|
||||
// IPv6: expand to 32 nibbles, reverse, and append .ip6.arpa.
|
||||
var nibbles []string
|
||||
for i := 15; i >= 0; i-- {
|
||||
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]&0x0f))
|
||||
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]>>4))
|
||||
}
|
||||
return dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
864
dns/dns_records_test.go
Normal file
864
dns/dns_records_test.go
Normal file
@@ -0,0 +1,864 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWildcardMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
domain string
|
||||
expected bool
|
||||
}{
|
||||
// Basic wildcard tests
|
||||
{
|
||||
name: "*.autoco.internal matches host.autoco.internal",
|
||||
pattern: "*.autoco.internal.",
|
||||
domain: "host.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "*.autoco.internal matches longerhost.autoco.internal",
|
||||
pattern: "*.autoco.internal.",
|
||||
domain: "longerhost.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "*.autoco.internal matches sub.host.autoco.internal",
|
||||
pattern: "*.autoco.internal.",
|
||||
domain: "sub.host.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "*.autoco.internal does NOT match autoco.internal",
|
||||
pattern: "*.autoco.internal.",
|
||||
domain: "autoco.internal.",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// Question mark wildcard tests
|
||||
{
|
||||
name: "host-0?.autoco.internal matches host-01.autoco.internal",
|
||||
pattern: "host-0?.autoco.internal.",
|
||||
domain: "host-01.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "host-0?.autoco.internal matches host-0a.autoco.internal",
|
||||
pattern: "host-0?.autoco.internal.",
|
||||
domain: "host-0a.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "host-0?.autoco.internal does NOT match host-0.autoco.internal",
|
||||
pattern: "host-0?.autoco.internal.",
|
||||
domain: "host-0.autoco.internal.",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "host-0?.autoco.internal does NOT match host-012.autoco.internal",
|
||||
pattern: "host-0?.autoco.internal.",
|
||||
domain: "host-012.autoco.internal.",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// Combined wildcard tests
|
||||
{
|
||||
name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal",
|
||||
pattern: "*.host-0?.autoco.internal.",
|
||||
domain: "sub.host-01.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "*.host-0?.autoco.internal matches prefix.host-0a.autoco.internal",
|
||||
pattern: "*.host-0?.autoco.internal.",
|
||||
domain: "prefix.host-0a.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "*.host-0?.autoco.internal does NOT match host-01.autoco.internal",
|
||||
pattern: "*.host-0?.autoco.internal.",
|
||||
domain: "host-01.autoco.internal.",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// Multiple asterisks
|
||||
{
|
||||
name: "*.*. autoco.internal matches any.thing.autoco.internal",
|
||||
pattern: "*.*.autoco.internal.",
|
||||
domain: "any.thing.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "*.*.autoco.internal does NOT match single.autoco.internal",
|
||||
pattern: "*.*.autoco.internal.",
|
||||
domain: "single.autoco.internal.",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// Asterisk in middle
|
||||
{
|
||||
name: "host-*.autoco.internal matches host-anything.autoco.internal",
|
||||
pattern: "host-*.autoco.internal.",
|
||||
domain: "host-anything.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "host-*.autoco.internal matches host-.autoco.internal (empty match)",
|
||||
pattern: "host-*.autoco.internal.",
|
||||
domain: "host-.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Multiple question marks
|
||||
{
|
||||
name: "host-??.autoco.internal matches host-01.autoco.internal",
|
||||
pattern: "host-??.autoco.internal.",
|
||||
domain: "host-01.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "host-??.autoco.internal does NOT match host-1.autoco.internal",
|
||||
pattern: "host-??.autoco.internal.",
|
||||
domain: "host-1.autoco.internal.",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// Exact match (no wildcards)
|
||||
{
|
||||
name: "exact.autoco.internal matches exact.autoco.internal",
|
||||
pattern: "exact.autoco.internal.",
|
||||
domain: "exact.autoco.internal.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "exact.autoco.internal does NOT match other.autoco.internal",
|
||||
pattern: "exact.autoco.internal.",
|
||||
domain: "other.autoco.internal.",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "* matches anything",
|
||||
pattern: "*",
|
||||
domain: "anything.at.all.",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "*.* matches multi.level.",
|
||||
pattern: "*.*",
|
||||
domain: "multi.level.",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchWildcard(tt.pattern, tt.domain)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchWildcard(%q, %q) = %v, want %v", tt.pattern, tt.domain, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSRecordStoreWildcard(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add wildcard records
|
||||
wildcardIP := net.ParseIP("10.0.0.1")
|
||||
err := store.AddRecord("*.autoco.internal", wildcardIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||
}
|
||||
|
||||
// Add exact record
|
||||
exactIP := net.ParseIP("10.0.0.2")
|
||||
err = store.AddRecord("exact.autoco.internal", exactIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add exact record: %v", err)
|
||||
}
|
||||
|
||||
// Test exact match takes precedence
|
||||
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
|
||||
}
|
||||
if !ips[0].Equal(exactIP) {
|
||||
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
|
||||
}
|
||||
|
||||
// Test wildcard match
|
||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
|
||||
}
|
||||
if !ips[0].Equal(wildcardIP) {
|
||||
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
|
||||
}
|
||||
|
||||
// Test non-match (base domain)
|
||||
ips = store.GetRecords("autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add complex wildcard pattern
|
||||
ip1 := net.ParseIP("10.0.0.1")
|
||||
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||
}
|
||||
|
||||
// Test matching domain
|
||||
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
|
||||
}
|
||||
if len(ips) > 0 && !ips[0].Equal(ip1) {
|
||||
t.Errorf("Expected IP %v, got %v", ip1, ips[0])
|
||||
}
|
||||
|
||||
// Test non-matching domain (missing prefix)
|
||||
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test non-matching domain (wrong ? position)
|
||||
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add wildcard record
|
||||
ip := net.ParseIP("10.0.0.1")
|
||||
err := store.AddRecord("*.autoco.internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||
}
|
||||
|
||||
// Verify it exists
|
||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
|
||||
}
|
||||
|
||||
// Remove wildcard record
|
||||
store.RemoveRecord("*.autoco.internal", nil)
|
||||
|
||||
// Verify it's gone
|
||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add multiple wildcard patterns that don't overlap
|
||||
ip1 := net.ParseIP("10.0.0.1")
|
||||
ip2 := net.ParseIP("10.0.0.2")
|
||||
ip3 := net.ParseIP("10.0.0.3")
|
||||
|
||||
err := store.AddRecord("*.prod.autoco.internal", ip1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add first wildcard: %v", err)
|
||||
}
|
||||
|
||||
err = store.AddRecord("*.dev.autoco.internal", ip2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add second wildcard: %v", err)
|
||||
}
|
||||
|
||||
// Add a broader wildcard that matches both
|
||||
err = store.AddRecord("*.autoco.internal", ip3)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add third wildcard: %v", err)
|
||||
}
|
||||
|
||||
// Test domain matching only the prod pattern and the broad pattern
|
||||
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test domain matching only the dev pattern and the broad pattern
|
||||
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test domain matching only the broad pattern
|
||||
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add IPv6 wildcard record
|
||||
ip := net.ParseIP("2001:db8::1")
|
||||
err := store.AddRecord("*.autoco.internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
|
||||
}
|
||||
|
||||
// Test wildcard match for IPv6
|
||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
|
||||
}
|
||||
if len(ips) > 0 && !ips[0].Equal(ip) {
|
||||
t.Errorf("Expected IPv6 %v, got %v", ip, ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasRecordWildcard(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add wildcard record
|
||||
ip := net.ParseIP("10.0.0.1")
|
||||
err := store.AddRecord("*.autoco.internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||
}
|
||||
|
||||
// Test HasRecord with wildcard match
|
||||
if !store.HasRecord("host.autoco.internal.", RecordTypeA) {
|
||||
t.Error("Expected HasRecord to return true for wildcard match")
|
||||
}
|
||||
|
||||
// Test HasRecord with non-match
|
||||
if store.HasRecord("autoco.internal.", RecordTypeA) {
|
||||
t.Error("Expected HasRecord to return false for base domain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add record with mixed case
|
||||
ip := net.ParseIP("10.0.0.1")
|
||||
err := store.AddRecord("MyHost.AutoCo.Internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add mixed case record: %v", err)
|
||||
}
|
||||
|
||||
// Test lookup with different cases
|
||||
testCases := []string{
|
||||
"myhost.autoco.internal.",
|
||||
"MYHOST.AUTOCO.INTERNAL.",
|
||||
"MyHost.AutoCo.Internal.",
|
||||
"mYhOsT.aUtOcO.iNtErNaL.",
|
||||
}
|
||||
|
||||
for _, domain := range testCases {
|
||||
ips := store.GetRecords(domain, RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
|
||||
}
|
||||
if len(ips) > 0 && !ips[0].Equal(ip) {
|
||||
t.Errorf("Expected IP %v for domain %q, got %v", ip, domain, ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Test wildcard with mixed case
|
||||
wildcardIP := net.ParseIP("10.0.0.2")
|
||||
err = store.AddRecord("*.Example.Com", wildcardIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add mixed case wildcard: %v", err)
|
||||
}
|
||||
|
||||
wildcardTestCases := []string{
|
||||
"host.example.com.",
|
||||
"HOST.EXAMPLE.COM.",
|
||||
"Host.Example.Com.",
|
||||
"HoSt.ExAmPlE.CoM.",
|
||||
}
|
||||
|
||||
for _, domain := range wildcardTestCases {
|
||||
ips := store.GetRecords(domain, RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
|
||||
}
|
||||
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
|
||||
t.Errorf("Expected IP %v for wildcard domain %q, got %v", wildcardIP, domain, ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Test removal with different case
|
||||
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
|
||||
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test HasRecord with different case
|
||||
if !store.HasRecord("HOST.EXAMPLE.COM.", RecordTypeA) {
|
||||
t.Error("Expected HasRecord to return true for mixed case wildcard match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPTRRecordIPv4(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add PTR record for IPv4
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
domain := "host.example.com."
|
||||
err := store.AddPTRRecord(ip, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add PTR record: %v", err)
|
||||
}
|
||||
|
||||
// Test reverse DNS lookup
|
||||
reverseDomain := "1.1.168.192.in-addr.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to be found")
|
||||
}
|
||||
if result != domain {
|
||||
t.Errorf("Expected domain %q, got %q", domain, result)
|
||||
}
|
||||
|
||||
// Test HasPTRRecord
|
||||
if !store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected HasPTRRecord to return true")
|
||||
}
|
||||
|
||||
// Test non-existent PTR record
|
||||
_, ok = store.GetPTRRecord("2.1.168.192.in-addr.arpa.")
|
||||
if ok {
|
||||
t.Error("Expected PTR record not to be found for different IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPTRRecordIPv6(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add PTR record for IPv6
|
||||
ip := net.ParseIP("2001:db8::1")
|
||||
domain := "ipv6host.example.com."
|
||||
err := store.AddPTRRecord(ip, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add PTR record: %v", err)
|
||||
}
|
||||
|
||||
// Test reverse DNS lookup
|
||||
// 2001:db8::1 = 2001:0db8:0000:0000:0000:0000:0000:0001
|
||||
// Reverse: 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
|
||||
reverseDomain := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected IPv6 PTR record to be found")
|
||||
}
|
||||
if result != domain {
|
||||
t.Errorf("Expected domain %q, got %q", domain, result)
|
||||
}
|
||||
|
||||
// Test HasPTRRecord
|
||||
if !store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected HasPTRRecord to return true for IPv6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemovePTRRecord(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add PTR record
|
||||
ip := net.ParseIP("10.0.0.1")
|
||||
domain := "test.example.com."
|
||||
err := store.AddPTRRecord(ip, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add PTR record: %v", err)
|
||||
}
|
||||
|
||||
// Verify it exists
|
||||
reverseDomain := "1.0.0.10.in-addr.arpa."
|
||||
_, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to exist before removal")
|
||||
}
|
||||
|
||||
// Remove PTR record
|
||||
store.RemovePTRRecord(ip)
|
||||
|
||||
// Verify it's gone
|
||||
_, ok = store.GetPTRRecord(reverseDomain)
|
||||
if ok {
|
||||
t.Error("Expected PTR record to be removed")
|
||||
}
|
||||
|
||||
// Test HasPTRRecord after removal
|
||||
if store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected HasPTRRecord to return false after removal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPToReverseDNS(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 simple",
|
||||
ip: "192.168.1.1",
|
||||
expected: "1.1.168.192.in-addr.arpa.",
|
||||
},
|
||||
{
|
||||
name: "IPv4 localhost",
|
||||
ip: "127.0.0.1",
|
||||
expected: "1.0.0.127.in-addr.arpa.",
|
||||
},
|
||||
{
|
||||
name: "IPv4 with zeros",
|
||||
ip: "10.0.0.1",
|
||||
expected: "1.0.0.10.in-addr.arpa.",
|
||||
},
|
||||
{
|
||||
name: "IPv6 simple",
|
||||
ip: "2001:db8::1",
|
||||
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost",
|
||||
ip: "::1",
|
||||
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("Failed to parse IP: %s", tt.ip)
|
||||
}
|
||||
result := IPToReverseDNS(ip)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IPToReverseDNS(%s) = %q, want %q", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseDNSToIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reverseDNS string
|
||||
expectedIP string
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 simple",
|
||||
reverseDNS: "1.1.168.192.in-addr.arpa.",
|
||||
expectedIP: "192.168.1.1",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 localhost",
|
||||
reverseDNS: "1.0.0.127.in-addr.arpa.",
|
||||
expectedIP: "127.0.0.1",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 simple",
|
||||
reverseDNS: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
expectedIP: "2001:db8::1",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid IPv4 format",
|
||||
reverseDNS: "1.1.168.in-addr.arpa.",
|
||||
expectedIP: "",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid IPv6 format",
|
||||
reverseDNS: "1.0.0.0.ip6.arpa.",
|
||||
expectedIP: "",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "Not a reverse DNS domain",
|
||||
reverseDNS: "example.com.",
|
||||
expectedIP: "",
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reverseDNSToIP(tt.reverseDNS)
|
||||
if tt.shouldMatch {
|
||||
if result == nil {
|
||||
t.Errorf("reverseDNSToIP(%q) returned nil, expected IP", tt.reverseDNS)
|
||||
return
|
||||
}
|
||||
expectedIP := net.ParseIP(tt.expectedIP)
|
||||
if !result.Equal(expectedIP) {
|
||||
t.Errorf("reverseDNSToIP(%q) = %v, want %v", tt.reverseDNS, result, expectedIP)
|
||||
}
|
||||
} else {
|
||||
if result != nil {
|
||||
t.Errorf("reverseDNSToIP(%q) = %v, expected nil", tt.reverseDNS, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPTRRecordCaseInsensitive(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add PTR record with mixed case domain
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
domain := "MyHost.Example.Com"
|
||||
err := store.AddPTRRecord(ip, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add PTR record: %v", err)
|
||||
}
|
||||
|
||||
// Test lookup with different cases in reverse DNS format
|
||||
reverseDomain := "1.1.168.192.in-addr.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to be found")
|
||||
}
|
||||
// Domain should be normalized to lowercase
|
||||
if result != "myhost.example.com." {
|
||||
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
|
||||
}
|
||||
|
||||
// Test with uppercase reverse DNS
|
||||
reverseDomainUpper := "1.1.168.192.IN-ADDR.ARPA."
|
||||
result, ok = store.GetPTRRecord(reverseDomainUpper)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to be found with uppercase reverse DNS")
|
||||
}
|
||||
if result != "myhost.example.com." {
|
||||
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearPTRRecords(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add some PTR records
|
||||
ip1 := net.ParseIP("192.168.1.1")
|
||||
ip2 := net.ParseIP("192.168.1.2")
|
||||
store.AddPTRRecord(ip1, "host1.example.com.")
|
||||
store.AddPTRRecord(ip2, "host2.example.com.")
|
||||
|
||||
// Add some A records too
|
||||
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))
|
||||
|
||||
// Verify PTR records exist
|
||||
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
|
||||
t.Error("Expected PTR record to exist before clear")
|
||||
}
|
||||
|
||||
// Clear all records
|
||||
store.Clear()
|
||||
|
||||
// Verify PTR records are gone
|
||||
if store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
|
||||
t.Error("Expected PTR record to be cleared")
|
||||
}
|
||||
if store.HasPTRRecord("2.1.168.192.in-addr.arpa.") {
|
||||
t.Error("Expected PTR record to be cleared")
|
||||
}
|
||||
|
||||
// Verify A records are also gone
|
||||
if store.HasRecord("test.example.com.", RecordTypeA) {
|
||||
t.Error("Expected A record to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutomaticPTRRecordOnAdd(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add an A record - should automatically add PTR record
|
||||
domain := "host.example.com."
|
||||
ip := net.ParseIP("192.168.1.100")
|
||||
err := store.AddRecord(domain, ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add A record: %v", err)
|
||||
}
|
||||
|
||||
// Verify PTR record was automatically created
|
||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to be automatically created")
|
||||
}
|
||||
if result != domain {
|
||||
t.Errorf("Expected PTR to point to %q, got %q", domain, result)
|
||||
}
|
||||
|
||||
// Add AAAA record - should also automatically add PTR record
|
||||
domain6 := "ipv6host.example.com."
|
||||
ip6 := net.ParseIP("2001:db8::1")
|
||||
err = store.AddRecord(domain6, ip6)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add AAAA record: %v", err)
|
||||
}
|
||||
|
||||
// Verify IPv6 PTR record was automatically created
|
||||
reverseDomain6 := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
|
||||
result6, ok := store.GetPTRRecord(reverseDomain6)
|
||||
if !ok {
|
||||
t.Error("Expected IPv6 PTR record to be automatically created")
|
||||
}
|
||||
if result6 != domain6 {
|
||||
t.Errorf("Expected PTR to point to %q, got %q", domain6, result6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutomaticPTRRecordOnRemove(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add an A record (with automatic PTR)
|
||||
domain := "host.example.com."
|
||||
ip := net.ParseIP("192.168.1.100")
|
||||
store.AddRecord(domain, ip)
|
||||
|
||||
// Verify PTR exists
|
||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||
if !store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected PTR record to exist after adding A record")
|
||||
}
|
||||
|
||||
// Remove the A record
|
||||
store.RemoveRecord(domain, ip)
|
||||
|
||||
// Verify PTR was automatically removed
|
||||
if store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected PTR record to be automatically removed")
|
||||
}
|
||||
|
||||
// Verify A record is also gone
|
||||
ips := store.GetRecords(domain, RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected A record to be removed, got %d records", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add multiple IPs for the same domain
|
||||
domain := "host.example.com."
|
||||
ip1 := net.ParseIP("192.168.1.100")
|
||||
ip2 := net.ParseIP("192.168.1.101")
|
||||
store.AddRecord(domain, ip1)
|
||||
store.AddRecord(domain, ip2)
|
||||
|
||||
// Verify both PTR records exist
|
||||
reverseDomain1 := "100.1.168.192.in-addr.arpa."
|
||||
reverseDomain2 := "101.1.168.192.in-addr.arpa."
|
||||
if !store.HasPTRRecord(reverseDomain1) {
|
||||
t.Error("Expected first PTR record to exist")
|
||||
}
|
||||
if !store.HasPTRRecord(reverseDomain2) {
|
||||
t.Error("Expected second PTR record to exist")
|
||||
}
|
||||
|
||||
// Remove all records for the domain
|
||||
store.RemoveRecord(domain, nil)
|
||||
|
||||
// Verify both PTR records were removed
|
||||
if store.HasPTRRecord(reverseDomain1) {
|
||||
t.Error("Expected first PTR record to be removed")
|
||||
}
|
||||
if store.HasPTRRecord(reverseDomain2) {
|
||||
t.Error("Expected second PTR record to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoPTRForWildcardRecords(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add wildcard record - should NOT create PTR record
|
||||
domain := "*.example.com."
|
||||
ip := net.ParseIP("192.168.1.100")
|
||||
err := store.AddRecord(domain, ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||
}
|
||||
|
||||
// Verify no PTR record was created
|
||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||
_, ok := store.GetPTRRecord(reverseDomain)
|
||||
if ok {
|
||||
t.Error("Expected no PTR record for wildcard domain")
|
||||
}
|
||||
|
||||
// Verify wildcard A record exists
|
||||
if !store.HasRecord("host.example.com.", RecordTypeA) {
|
||||
t.Error("Expected wildcard A record to exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPTRRecordOverwrite(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add first domain with IP
|
||||
domain1 := "host1.example.com."
|
||||
ip := net.ParseIP("192.168.1.100")
|
||||
store.AddRecord(domain1, ip)
|
||||
|
||||
// Verify PTR points to first domain
|
||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Fatal("Expected PTR record to exist")
|
||||
}
|
||||
if result != domain1 {
|
||||
t.Errorf("Expected PTR to point to %q, got %q", domain1, result)
|
||||
}
|
||||
|
||||
// Add second domain with same IP - should overwrite PTR
|
||||
domain2 := "host2.example.com."
|
||||
store.AddRecord(domain2, ip)
|
||||
|
||||
// Verify PTR now points to second domain (last one added)
|
||||
result, ok = store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Fatal("Expected PTR record to still exist")
|
||||
}
|
||||
if result != domain2 {
|
||||
t.Errorf("Expected PTR to point to %q (overwritten), got %q", domain2, result)
|
||||
}
|
||||
|
||||
// Remove first domain - PTR should remain pointing to second domain
|
||||
store.RemoveRecord(domain1, ip)
|
||||
result, ok = store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to still exist after removing first domain")
|
||||
}
|
||||
if result != domain2 {
|
||||
t.Errorf("Expected PTR to still point to %q, got %q", domain2, result)
|
||||
}
|
||||
|
||||
// Remove second domain - PTR should now be gone
|
||||
store.RemoveRecord(domain2, ip)
|
||||
_, ok = store.GetPTRRecord(reverseDomain)
|
||||
if ok {
|
||||
t.Error("Expected PTR record to be removed after removing second domain")
|
||||
}
|
||||
}
|
||||
16
dns/override/dns_override_android.go
Normal file
16
dns/override/dns_override_android.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build android
|
||||
|
||||
package olm
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// SetupDNSOverride is a no-op on Android
|
||||
// Android handles DNS through the VpnService API at the Java/Kotlin layer
|
||||
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride is a no-op on Android
|
||||
func RestoreDNSOverride() error {
|
||||
return nil
|
||||
}
|
||||
63
dns/override/dns_override_darwin.go
Normal file
63
dns/override/dns_override_darwin.go
Normal file
@@ -0,0 +1,63 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
|
||||
// Uses scutil for DNS configuration
|
||||
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||
var err error
|
||||
configurator, err = platform.NewDarwinDNSConfigurator()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Darwin DNS configurator: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Using Darwin scutil DNS configurator")
|
||||
|
||||
// Get current DNS servers before changing
|
||||
currentDNS, err := configurator.GetCurrentDNS()
|
||||
if err != nil {
|
||||
logger.Warn("Could not get current DNS: %v", err)
|
||||
} else {
|
||||
logger.Info("Current DNS servers: %v", currentDNS)
|
||||
}
|
||||
|
||||
// Set new DNS servers to point to our proxy
|
||||
newDNS := []netip.Addr{
|
||||
proxyIp,
|
||||
}
|
||||
|
||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||
originalDNS, err := configurator.SetDNS(newDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Original DNS servers backed up: %v", originalDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride restores the original DNS configuration
|
||||
func RestoreDNSOverride() error {
|
||||
if configurator == nil {
|
||||
logger.Debug("No DNS configurator to restore")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Restoring original DNS configuration")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
return fmt.Errorf("failed to restore DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("DNS configuration restored successfully")
|
||||
return nil
|
||||
}
|
||||
15
dns/override/dns_override_ios.go
Normal file
15
dns/override/dns_override_ios.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build ios
|
||||
|
||||
package olm
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system
|
||||
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system
|
||||
func RestoreDNSOverride() error {
|
||||
return nil
|
||||
}
|
||||
100
dns/override/dns_override_unix.go
Normal file
100
dns/override/dns_override_unix.go
Normal file
@@ -0,0 +1,100 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
|
||||
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
|
||||
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||
var err error
|
||||
|
||||
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
|
||||
managerType := platform.DetectDNSManager(interfaceName)
|
||||
logger.Info("Detected DNS manager: %s", managerType.String())
|
||||
|
||||
// Create configurator based on detected manager
|
||||
switch managerType {
|
||||
case platform.SystemdResolvedManager:
|
||||
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
|
||||
if err == nil {
|
||||
logger.Info("Using systemd-resolved DNS configurator")
|
||||
return setDNS(proxyIp, configurator)
|
||||
}
|
||||
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
|
||||
|
||||
case platform.NetworkManagerManager:
|
||||
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
|
||||
if err == nil {
|
||||
logger.Info("Using NetworkManager DNS configurator")
|
||||
return setDNS(proxyIp, configurator)
|
||||
}
|
||||
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
|
||||
|
||||
case platform.ResolvconfManager:
|
||||
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
|
||||
if err == nil {
|
||||
logger.Info("Using resolvconf DNS configurator")
|
||||
return setDNS(proxyIp, configurator)
|
||||
}
|
||||
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
|
||||
}
|
||||
|
||||
// Fall back to direct file manipulation
|
||||
configurator, err = platform.NewFileDNSConfigurator()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file DNS configurator: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Using file-based DNS configurator")
|
||||
return setDNS(proxyIp, configurator)
|
||||
}
|
||||
|
||||
// setDNS is a helper function to set DNS and log the results
|
||||
func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error {
|
||||
// Get current DNS servers before changing
|
||||
currentDNS, err := conf.GetCurrentDNS()
|
||||
if err != nil {
|
||||
logger.Warn("Could not get current DNS: %v", err)
|
||||
} else {
|
||||
logger.Info("Current DNS servers: %v", currentDNS)
|
||||
}
|
||||
|
||||
// Set new DNS servers to point to our proxy
|
||||
newDNS := []netip.Addr{
|
||||
proxyIp,
|
||||
}
|
||||
|
||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||
originalDNS, err := conf.SetDNS(newDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Original DNS servers backed up: %v", originalDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride restores the original DNS configuration
|
||||
func RestoreDNSOverride() error {
|
||||
if configurator == nil {
|
||||
logger.Debug("No DNS configurator to restore")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Restoring original DNS configuration")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
return fmt.Errorf("failed to restore DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("DNS configuration restored successfully")
|
||||
return nil
|
||||
}
|
||||
63
dns/override/dns_override_windows.go
Normal file
63
dns/override/dns_override_windows.go
Normal file
@@ -0,0 +1,63 @@
|
||||
//go:build windows
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
||||
// Uses registry-based configuration (automatically extracts interface GUID)
|
||||
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||
var err error
|
||||
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Windows DNS configurator: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Using Windows registry DNS configurator for interface: %s", interfaceName)
|
||||
|
||||
// Get current DNS servers before changing
|
||||
currentDNS, err := configurator.GetCurrentDNS()
|
||||
if err != nil {
|
||||
logger.Warn("Could not get current DNS: %v", err)
|
||||
} else {
|
||||
logger.Info("Current DNS servers: %v", currentDNS)
|
||||
}
|
||||
|
||||
// Set new DNS servers to point to our proxy
|
||||
newDNS := []netip.Addr{
|
||||
proxyIp,
|
||||
}
|
||||
|
||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||
originalDNS, err := configurator.SetDNS(newDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Original DNS servers backed up: %v", originalDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride restores the original DNS configuration
|
||||
func RestoreDNSOverride() error {
|
||||
if configurator == nil {
|
||||
logger.Debug("No DNS configurator to restore")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Restoring original DNS configuration")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
return fmt.Errorf("failed to restore DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("DNS configuration restored successfully")
|
||||
return nil
|
||||
}
|
||||
419
dns/platform/darwin.go
Normal file
419
dns/platform/darwin.go
Normal file
@@ -0,0 +1,419 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
scutilPath = "/usr/sbin/scutil"
|
||||
dscacheutilPath = "/usr/bin/dscacheutil"
|
||||
|
||||
dnsStateKeyFormat = "State:/Network/Service/Olm-%s/DNS"
|
||||
globalIPv4State = "State:/Network/Global/IPv4"
|
||||
primaryServiceFormat = "State:/Network/Service/%s/DNS"
|
||||
|
||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
|
||||
keyServerAddresses = "ServerAddresses"
|
||||
keyServerPort = "ServerPort"
|
||||
arraySymbol = "* "
|
||||
digitSymbol = "# "
|
||||
|
||||
// State file name for crash recovery
|
||||
dnsStateFileName = "dns_state.json"
|
||||
)
|
||||
|
||||
// DNSPersistentState represents the state saved to disk for crash recovery
|
||||
type DNSPersistentState struct {
|
||||
CreatedKeys []string `json:"created_keys"`
|
||||
}
|
||||
|
||||
// DarwinDNSConfigurator manages DNS settings on macOS using scutil
|
||||
type DarwinDNSConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
originalState *DNSState
|
||||
stateFilePath string
|
||||
}
|
||||
|
||||
// NewDarwinDNSConfigurator creates a new macOS DNS configurator
|
||||
func NewDarwinDNSConfigurator() (*DarwinDNSConfigurator, error) {
|
||||
stateFilePath := getDNSStateFilePath()
|
||||
|
||||
configurator := &DarwinDNSConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
stateFilePath: stateFilePath,
|
||||
}
|
||||
|
||||
// Clean up any leftover state from a previous crash
|
||||
if err := configurator.CleanupUncleanShutdown(); err != nil {
|
||||
logger.Warn("Failed to cleanup previous DNS state: %v", err)
|
||||
}
|
||||
|
||||
return configurator, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (d *DarwinDNSConfigurator) Name() string {
|
||||
return "darwin-scutil"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (d *DarwinDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := d.GetCurrentDNS()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current DNS: %w", err)
|
||||
}
|
||||
|
||||
// Store original state
|
||||
d.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: d.Name(),
|
||||
}
|
||||
|
||||
// Set new DNS servers
|
||||
if err := d.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Persist state to disk for crash recovery
|
||||
if err := d.saveState(); err != nil {
|
||||
logger.Warn("Failed to save DNS state for crash recovery: %v", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := d.flushDNSCache(); err != nil {
|
||||
// Non-fatal, just log
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (d *DarwinDNSConfigurator) RestoreDNS() error {
|
||||
// Remove all created keys
|
||||
for key := range d.createdKeys {
|
||||
if err := d.removeKey(key); err != nil {
|
||||
return fmt.Errorf("remove key %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear state file after successful restoration
|
||||
if err := d.clearState(); err != nil {
|
||||
logger.Warn("Failed to clear DNS state file: %v", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := d.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (d *DarwinDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
primaryServiceKey, err := d.getPrimaryServiceKey()
|
||||
if err != nil || primaryServiceKey == "" {
|
||||
return nil, fmt.Errorf("get primary service: %w", err)
|
||||
}
|
||||
|
||||
dnsKey := fmt.Sprintf(primaryServiceFormat, primaryServiceKey)
|
||||
cmd := fmt.Sprintf("show %s\n", dnsKey)
|
||||
|
||||
output, err := d.runScutil(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("run scutil: %w", err)
|
||||
}
|
||||
|
||||
servers := d.parseServerAddresses(output)
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// CleanupUncleanShutdown removes any DNS keys left over from a previous crash
|
||||
func (d *DarwinDNSConfigurator) CleanupUncleanShutdown() error {
|
||||
state, err := d.loadState()
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// No state file, nothing to clean up
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("load state: %w", err)
|
||||
}
|
||||
|
||||
if len(state.CreatedKeys) == 0 {
|
||||
// No keys to clean up
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Found DNS state from previous session, cleaning up %d keys", len(state.CreatedKeys))
|
||||
|
||||
// Remove all keys from previous session
|
||||
var lastErr error
|
||||
for _, key := range state.CreatedKeys {
|
||||
logger.Debug("Removing leftover DNS key: %s", key)
|
||||
if err := d.removeKeyDirect(key); err != nil {
|
||||
logger.Warn("Failed to remove DNS key %s: %v", key, err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
// Clear state file
|
||||
if err := d.clearState(); err != nil {
|
||||
logger.Warn("Failed to clear DNS state file: %v", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache after cleanup
|
||||
if err := d.flushDNSCache(); err != nil {
|
||||
logger.Warn("Failed to flush DNS cache after cleanup: %v", err)
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// applyDNSServers applies the DNS server configuration
|
||||
func (d *DarwinDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
key := fmt.Sprintf(dnsStateKeyFormat, "Override")
|
||||
|
||||
// Use SupplementalMatchDomains with empty string to match ALL domains
|
||||
// This is the key to making DNS override work on macOS
|
||||
// Setting SupplementalMatchDomainsNoSearch to 0 enables search domain behavior
|
||||
err := d.addDNSState(key, "\"\"", servers[0], 53, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set DNS servers: %w", err)
|
||||
}
|
||||
|
||||
d.createdKeys[key] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addDNSState adds a DNS state entry with the specified configuration
|
||||
func (d *DarwinDNSConfigurator) addDNSState(state, domains string, dnsServer netip.Addr, port int, enableSearch bool) error {
|
||||
noSearch := "1"
|
||||
if enableSearch {
|
||||
noSearch = "0"
|
||||
}
|
||||
|
||||
// Build the scutil command following NetBird's approach
|
||||
var commands strings.Builder
|
||||
commands.WriteString("d.init\n")
|
||||
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomains, arraySymbol, domains))
|
||||
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomainsNoSearch, digitSymbol, noSearch))
|
||||
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerAddresses, arraySymbol, dnsServer.String()))
|
||||
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerPort, digitSymbol, strconv.Itoa(port)))
|
||||
commands.WriteString(fmt.Sprintf("set %s\n", state))
|
||||
|
||||
if _, err := d.runScutil(commands.String()); err != nil {
|
||||
return fmt.Errorf("applying state for domains %s, error: %w", domains, err)
|
||||
}
|
||||
|
||||
logger.Info("Added DNS override with server %s:%d for domains: %s", dnsServer.String(), port, domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeKey removes a DNS configuration key and updates internal state
|
||||
func (d *DarwinDNSConfigurator) removeKey(key string) error {
|
||||
if err := d.removeKeyDirect(key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(d.createdKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeKeyDirect removes a DNS configuration key without updating internal state
|
||||
// Used for cleanup operations
|
||||
func (d *DarwinDNSConfigurator) removeKeyDirect(key string) error {
|
||||
cmd := fmt.Sprintf("remove %s\n", key)
|
||||
|
||||
if _, err := d.runScutil(cmd); err != nil {
|
||||
return fmt.Errorf("remove key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPrimaryServiceKey gets the primary network service key
|
||||
func (d *DarwinDNSConfigurator) getPrimaryServiceKey() (string, error) {
|
||||
cmd := fmt.Sprintf("show %s\n", globalIPv4State)
|
||||
|
||||
output, err := d.runScutil(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("run scutil: %w", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, "PrimaryService") {
|
||||
parts := strings.Split(line, ":")
|
||||
if len(parts) >= 2 {
|
||||
return strings.TrimSpace(parts[1]), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("scan output: %w", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("primary service not found")
|
||||
}
|
||||
|
||||
// parseServerAddresses parses DNS server addresses from scutil output
|
||||
func (d *DarwinDNSConfigurator) parseServerAddresses(output []byte) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
inServerArray := false
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if strings.HasPrefix(line, "ServerAddresses : <array> {") {
|
||||
inServerArray = true
|
||||
continue
|
||||
}
|
||||
|
||||
if line == "}" {
|
||||
inServerArray = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inServerArray {
|
||||
// Line format: "0 : 8.8.8.8"
|
||||
parts := strings.Split(line, " : ")
|
||||
if len(parts) >= 2 {
|
||||
if addr, err := netip.ParseAddr(parts[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// flushDNSCache flushes the system DNS cache
|
||||
func (d *DarwinDNSConfigurator) flushDNSCache() error {
|
||||
logger.Debug("Flushing dscacheutil cache")
|
||||
cmd := exec.Command(dscacheutilPath, "-flushcache")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("flush cache: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Flushing mDNSResponder cache")
|
||||
|
||||
cmd = exec.Command("killall", "-HUP", "mDNSResponder")
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Non-fatal, mDNSResponder might not be running
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runScutil executes an scutil command
|
||||
func (d *DarwinDNSConfigurator) runScutil(commands string) ([]byte, error) {
|
||||
// Wrap commands with open/quit
|
||||
wrapped := fmt.Sprintf("open\n%squit\n", commands)
|
||||
|
||||
logger.Debug("Running scutil with commands:\n%s\n", wrapped)
|
||||
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader(wrapped)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scutil command failed: %w, output: %s", err, output)
|
||||
}
|
||||
|
||||
logger.Debug("scutil output:\n%s\n", output)
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// getDNSStateFilePath returns the path to the DNS state file
|
||||
func getDNSStateFilePath() string {
|
||||
var stateDir string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
stateDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
|
||||
default:
|
||||
stateDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(stateDir, 0755); err != nil {
|
||||
logger.Warn("Failed to create state directory: %v", err)
|
||||
}
|
||||
|
||||
return filepath.Join(stateDir, dnsStateFileName)
|
||||
}
|
||||
|
||||
// saveState persists the current DNS state to disk
|
||||
func (d *DarwinDNSConfigurator) saveState() error {
|
||||
keys := make([]string, 0, len(d.createdKeys))
|
||||
for key := range d.createdKeys {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
state := DNSPersistentState{
|
||||
CreatedKeys: keys,
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(state, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal state: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(d.stateFilePath, data, 0644); err != nil {
|
||||
return fmt.Errorf("write state file: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Saved DNS state to %s", d.stateFilePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadState loads the DNS state from disk
|
||||
func (d *DarwinDNSConfigurator) loadState() (*DNSPersistentState, error) {
|
||||
data, err := os.ReadFile(d.stateFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var state DNSPersistentState
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal state: %w", err)
|
||||
}
|
||||
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
// clearState removes the DNS state file
|
||||
func (d *DarwinDNSConfigurator) clearState() error {
|
||||
err := os.Remove(d.stateFilePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove state file: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Cleared DNS state file")
|
||||
return nil
|
||||
}
|
||||
158
dns/platform/detect_unix.go
Normal file
158
dns/platform/detect_unix.go
Normal file
@@ -0,0 +1,158 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
const defaultResolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// DNSManagerType represents the type of DNS manager detected
|
||||
type DNSManagerType int
|
||||
|
||||
const (
|
||||
// UnknownManager indicates we couldn't determine the DNS manager
|
||||
UnknownManager DNSManagerType = iota
|
||||
// SystemdResolvedManager indicates systemd-resolved is managing DNS
|
||||
SystemdResolvedManager
|
||||
// NetworkManagerManager indicates NetworkManager is managing DNS
|
||||
NetworkManagerManager
|
||||
// ResolvconfManager indicates resolvconf is managing DNS
|
||||
ResolvconfManager
|
||||
// FileManager indicates direct file management (no DNS manager)
|
||||
FileManager
|
||||
)
|
||||
|
||||
// DetectDNSManagerFromFile reads /etc/resolv.conf to determine which DNS manager is in use
|
||||
// This provides a hint based on comments in the file, similar to Netbird's approach
|
||||
func DetectDNSManagerFromFile() DNSManagerType {
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return UnknownManager
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
if len(text) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If we hit a non-comment line, default to file-based
|
||||
if text[0] != '#' {
|
||||
return FileManager
|
||||
}
|
||||
|
||||
// Check for DNS manager signatures in comments
|
||||
if strings.Contains(text, "NetworkManager") {
|
||||
return NetworkManagerManager
|
||||
}
|
||||
|
||||
if strings.Contains(text, "systemd-resolved") {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
return ResolvconfManager
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return UnknownManager
|
||||
}
|
||||
|
||||
// No indicators found, assume file-based management
|
||||
return FileManager
|
||||
}
|
||||
|
||||
// String returns a human-readable name for the DNS manager type
|
||||
func (d DNSManagerType) String() string {
|
||||
switch d {
|
||||
case SystemdResolvedManager:
|
||||
return "systemd-resolved"
|
||||
case NetworkManagerManager:
|
||||
return "NetworkManager"
|
||||
case ResolvconfManager:
|
||||
return "resolvconf"
|
||||
case FileManager:
|
||||
return "file"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// DetectDNSManager combines file detection with runtime availability checks
|
||||
// to determine the best DNS configurator to use
|
||||
func DetectDNSManager(interfaceName string) DNSManagerType {
|
||||
// First check what the file suggests
|
||||
fileHint := DetectDNSManagerFromFile()
|
||||
|
||||
// Verify the hint with runtime checks
|
||||
switch fileHint {
|
||||
case SystemdResolvedManager:
|
||||
// Verify systemd-resolved is actually running
|
||||
if IsSystemdResolvedAvailable() {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
logger.Warn("dns platform: Found systemd-resolved but it is not running. Falling back to file...")
|
||||
os.Exit(0)
|
||||
return FileManager
|
||||
|
||||
case NetworkManagerManager:
|
||||
// Verify NetworkManager is actually running
|
||||
if IsNetworkManagerAvailable() {
|
||||
// Check if NetworkManager is delegating to systemd-resolved
|
||||
if !IsNetworkManagerDNSModeSupported() {
|
||||
logger.Info("NetworkManager is delegating DNS to systemd-resolved, using systemd-resolved configurator")
|
||||
if IsSystemdResolvedAvailable() {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
}
|
||||
return NetworkManagerManager
|
||||
}
|
||||
logger.Warn("dns platform: Found network manager but it is not running. Falling back to file...")
|
||||
return FileManager
|
||||
|
||||
case ResolvconfManager:
|
||||
// Verify resolvconf is available
|
||||
if IsResolvconfAvailable() {
|
||||
return ResolvconfManager
|
||||
}
|
||||
// If resolvconf is mentioned but not available, fall back to file
|
||||
return FileManager
|
||||
|
||||
case FileManager:
|
||||
// File suggests direct file management
|
||||
// But we should still check if a manager is available that wasn't mentioned
|
||||
if IsSystemdResolvedAvailable() && interfaceName != "" {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
if IsNetworkManagerAvailable() && interfaceName != "" {
|
||||
return NetworkManagerManager
|
||||
}
|
||||
if IsResolvconfAvailable() && interfaceName != "" {
|
||||
return ResolvconfManager
|
||||
}
|
||||
return FileManager
|
||||
|
||||
default:
|
||||
// Unknown - do runtime detection
|
||||
if IsSystemdResolvedAvailable() && interfaceName != "" {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
if IsNetworkManagerAvailable() && interfaceName != "" {
|
||||
return NetworkManagerManager
|
||||
}
|
||||
if IsResolvconfAvailable() && interfaceName != "" {
|
||||
return ResolvconfManager
|
||||
}
|
||||
return FileManager
|
||||
}
|
||||
}
|
||||
220
dns/platform/file.go
Normal file
220
dns/platform/file.go
Normal file
@@ -0,0 +1,220 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
resolvConfPath = "/etc/resolv.conf"
|
||||
resolvConfBackupPath = "/etc/resolv.conf.olm.backup"
|
||||
resolvConfHeader = "# Generated by Olm DNS Manager\n# Original file backed up to " + resolvConfBackupPath + "\n\n"
|
||||
)
|
||||
|
||||
// FileDNSConfigurator manages DNS settings by directly modifying /etc/resolv.conf
|
||||
type FileDNSConfigurator struct {
|
||||
originalState *DNSState
|
||||
}
|
||||
|
||||
// NewFileDNSConfigurator creates a new file-based DNS configurator
|
||||
func NewFileDNSConfigurator() (*FileDNSConfigurator, error) {
|
||||
f := &FileDNSConfigurator{}
|
||||
if err := f.CleanupUncleanShutdown(); err != nil {
|
||||
return nil, fmt.Errorf("cleanup unclean shutdown: %w", err)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (f *FileDNSConfigurator) Name() string {
|
||||
return "file-resolv.conf"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (f *FileDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := f.GetCurrentDNS()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current DNS: %w", err)
|
||||
}
|
||||
|
||||
// Backup original resolv.conf if not already backed up
|
||||
if !f.isBackupExists() {
|
||||
if err := f.backupResolvConf(); err != nil {
|
||||
return nil, fmt.Errorf("backup resolv.conf: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
f.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: f.Name(),
|
||||
}
|
||||
|
||||
// Write new resolv.conf
|
||||
if err := f.writeResolvConf(servers); err != nil {
|
||||
return nil, fmt.Errorf("write resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (f *FileDNSConfigurator) RestoreDNS() error {
|
||||
if !f.isBackupExists() {
|
||||
return fmt.Errorf("no backup file exists")
|
||||
}
|
||||
|
||||
// Copy backup back to original location
|
||||
if err := copyFile(resolvConfBackupPath, resolvConfPath); err != nil {
|
||||
return fmt.Errorf("restore from backup: %w", err)
|
||||
}
|
||||
|
||||
// Remove backup file
|
||||
if err := os.Remove(resolvConfBackupPath); err != nil {
|
||||
return fmt.Errorf("remove backup file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
|
||||
// For the file-based configurator, we check if a backup file exists (indicating a crash
|
||||
// happened while DNS was configured) and restore from it if so.
|
||||
func (f *FileDNSConfigurator) CleanupUncleanShutdown() error {
|
||||
// Check if backup file exists from a previous session
|
||||
if !f.isBackupExists() {
|
||||
// No backup file, nothing to clean up
|
||||
return nil
|
||||
}
|
||||
|
||||
// A backup exists, which means we crashed while DNS was configured
|
||||
// Restore the original resolv.conf
|
||||
if err := copyFile(resolvConfBackupPath, resolvConfPath); err != nil {
|
||||
return fmt.Errorf("restore from backup during cleanup: %w", err)
|
||||
}
|
||||
|
||||
// Remove backup file
|
||||
if err := os.Remove(resolvConfBackupPath); err != nil {
|
||||
return fmt.Errorf("remove backup file during cleanup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (f *FileDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
content, err := os.ReadFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return f.parseNameservers(string(content)), nil
|
||||
}
|
||||
|
||||
// backupResolvConf creates a backup of the current resolv.conf
|
||||
func (f *FileDNSConfigurator) backupResolvConf() error {
|
||||
// Get file info for permissions
|
||||
info, err := os.Stat(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
if err := copyFile(resolvConfPath, resolvConfBackupPath); err != nil {
|
||||
return fmt.Errorf("copy file: %w", err)
|
||||
}
|
||||
|
||||
// Preserve permissions
|
||||
if err := os.Chmod(resolvConfBackupPath, info.Mode()); err != nil {
|
||||
return fmt.Errorf("chmod backup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeResolvConf writes a new resolv.conf with the specified DNS servers
|
||||
func (f *FileDNSConfigurator) writeResolvConf(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Get file info for permissions
|
||||
info, err := os.Stat(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
var content strings.Builder
|
||||
content.WriteString(resolvConfHeader)
|
||||
|
||||
// Write nameservers
|
||||
for _, server := range servers {
|
||||
content.WriteString("nameserver ")
|
||||
content.WriteString(server.String())
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(resolvConfPath, []byte(content.String()), info.Mode()); err != nil {
|
||||
return fmt.Errorf("write resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isBackupExists checks if a backup file exists
|
||||
func (f *FileDNSConfigurator) isBackupExists() bool {
|
||||
_, err := os.Stat(resolvConfBackupPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// parseNameservers extracts nameserver entries from resolv.conf content
|
||||
func (f *FileDNSConfigurator) parseNameservers(content string) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// Skip comments and empty lines
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Look for nameserver lines
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
if addr, err := netip.ParseAddr(fields[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// copyFile copies a file from src to dst
|
||||
func copyFile(src, dst string) error {
|
||||
content, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read source: %w", err)
|
||||
}
|
||||
|
||||
// Get source file permissions
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat source: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(dst, content, info.Mode()); err != nil {
|
||||
return fmt.Errorf("write destination: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
325
dns/platform/network_manager.go
Normal file
325
dns/platform/network_manager.go
Normal file
@@ -0,0 +1,325 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbus "github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
// NetworkManager D-Bus constants
|
||||
networkManagerDest = "org.freedesktop.NetworkManager"
|
||||
networkManagerDbusObjectNode = "/org/freedesktop/NetworkManager"
|
||||
networkManagerDbusDNSManagerInterface = "org.freedesktop.NetworkManager.DnsManager"
|
||||
networkManagerDbusDNSManagerObjectNode = networkManagerDbusObjectNode + "/DnsManager"
|
||||
networkManagerDbusDNSManagerModeProperty = networkManagerDbusDNSManagerInterface + ".Mode"
|
||||
networkManagerDbusVersionProperty = "org.freedesktop.NetworkManager.Version"
|
||||
|
||||
// NetworkManager dispatcher script path
|
||||
networkManagerDispatcherDir = "/etc/NetworkManager/dispatcher.d"
|
||||
networkManagerConfDir = "/etc/NetworkManager/conf.d"
|
||||
networkManagerDNSConfFile = "olm-dns.conf"
|
||||
networkManagerDispatcherFile = "01-olm-dns"
|
||||
)
|
||||
|
||||
// NetworkManagerDNSConfigurator manages DNS settings using NetworkManager configuration files
|
||||
// This approach works with unmanaged interfaces by modifying NetworkManager's global DNS settings
|
||||
type NetworkManagerDNSConfigurator struct {
|
||||
ifaceName string
|
||||
originalState *DNSState
|
||||
confPath string
|
||||
dispatchPath string
|
||||
}
|
||||
|
||||
// NewNetworkManagerDNSConfigurator creates a new NetworkManager DNS configurator
|
||||
func NewNetworkManagerDNSConfigurator(ifaceName string) (*NetworkManagerDNSConfigurator, error) {
|
||||
if ifaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
// Check that NetworkManager conf.d directory exists
|
||||
if _, err := os.Stat(networkManagerConfDir); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("NetworkManager conf.d directory not found: %s", networkManagerConfDir)
|
||||
}
|
||||
|
||||
configurator := &NetworkManagerDNSConfigurator{
|
||||
ifaceName: ifaceName,
|
||||
confPath: networkManagerConfDir + "/" + networkManagerDNSConfFile,
|
||||
dispatchPath: networkManagerDispatcherDir + "/" + networkManagerDispatcherFile,
|
||||
}
|
||||
|
||||
// Clean up any stale configuration from a previous unclean shutdown
|
||||
if err := configurator.CleanupUncleanShutdown(); err != nil {
|
||||
return nil, fmt.Errorf("cleanup unclean shutdown: %w", err)
|
||||
}
|
||||
|
||||
return configurator, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (n *NetworkManagerDNSConfigurator) Name() string {
|
||||
return "network-manager"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (n *NetworkManagerDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := n.GetCurrentDNS()
|
||||
if err != nil {
|
||||
// If we can't get current DNS, proceed anyway
|
||||
originalServers = []netip.Addr{}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
n.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: n.Name(),
|
||||
}
|
||||
|
||||
// Apply new DNS servers
|
||||
if err := n.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (n *NetworkManagerDNSConfigurator) RestoreDNS() error {
|
||||
// Remove our configuration file
|
||||
if err := os.Remove(n.confPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove DNS config file: %w", err)
|
||||
}
|
||||
|
||||
// Reload NetworkManager to apply the change
|
||||
if err := n.reloadNetworkManager(); err != nil {
|
||||
return fmt.Errorf("reload NetworkManager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
|
||||
// For NetworkManager, we check if our config file exists and remove it if so.
|
||||
// This ensures that if the process crashed while DNS was configured, the stale
|
||||
// configuration is removed on the next startup.
|
||||
func (n *NetworkManagerDNSConfigurator) CleanupUncleanShutdown() error {
|
||||
// Check if our config file exists from a previous session
|
||||
if _, err := os.Stat(n.confPath); os.IsNotExist(err) {
|
||||
// No config file, nothing to clean up
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove the stale configuration file
|
||||
if err := os.Remove(n.confPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove stale DNS config file: %w", err)
|
||||
}
|
||||
|
||||
// Reload NetworkManager to apply the change
|
||||
if err := n.reloadNetworkManager(); err != nil {
|
||||
return fmt.Errorf("reload NetworkManager after cleanup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers by reading /etc/resolv.conf
|
||||
func (n *NetworkManagerDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
content, err := os.ReadFile("/etc/resolv.conf")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
var servers []netip.Addr
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
if addr, err := netip.ParseAddr(fields[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// applyDNSServers applies DNS server configuration via NetworkManager config file
|
||||
func (n *NetworkManagerDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Build DNS server list
|
||||
var dnsServers []string
|
||||
for _, server := range servers {
|
||||
dnsServers = append(dnsServers, server.String())
|
||||
}
|
||||
|
||||
// Create NetworkManager configuration file that sets global DNS
|
||||
// This overrides DNS for all connections
|
||||
configContent := fmt.Sprintf(`# Generated by Olm DNS Manager - DO NOT EDIT
|
||||
# This file configures NetworkManager to use Olm's DNS proxy
|
||||
|
||||
[global-dns-domain-*]
|
||||
servers=%s
|
||||
`, strings.Join(dnsServers, ","))
|
||||
|
||||
// Write the configuration file
|
||||
if err := os.WriteFile(n.confPath, []byte(configContent), 0644); err != nil {
|
||||
return fmt.Errorf("write DNS config file: %w", err)
|
||||
}
|
||||
|
||||
// Reload NetworkManager to apply the new configuration
|
||||
if err := n.reloadNetworkManager(); err != nil {
|
||||
// Try to clean up
|
||||
os.Remove(n.confPath)
|
||||
return fmt.Errorf("reload NetworkManager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadNetworkManager tells NetworkManager to reload its configuration
|
||||
func (n *NetworkManagerDNSConfigurator) reloadNetworkManager() error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Call Reload method with flags=0 (reload everything)
|
||||
// See: https://networkmanager.dev/docs/api/latest/gdbus-org.freedesktop.NetworkManager.html#gdbus-method-org-freedesktop-NetworkManager.Reload
|
||||
err = obj.CallWithContext(ctx, networkManagerDest+".Reload", 0, uint32(0)).Store()
|
||||
if err != nil {
|
||||
return fmt.Errorf("call Reload: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsNetworkManagerAvailable checks if NetworkManager is available and responsive
|
||||
func IsNetworkManagerAvailable() bool {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Try to ping NetworkManager
|
||||
if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsNetworkManagerDNSModeSupported checks if NetworkManager's DNS mode is one we can work with
|
||||
// Some DNS modes delegate to other systems (like systemd-resolved) which we should use directly
|
||||
func IsNetworkManagerDNSModeSupported() bool {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
||||
|
||||
modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty)
|
||||
if err != nil {
|
||||
// If we can't get the mode, assume it's not supported
|
||||
return false
|
||||
}
|
||||
|
||||
mode, ok := modeVariant.Value().(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// If NetworkManager is delegating DNS to systemd-resolved, we should use
|
||||
// systemd-resolved directly for better control
|
||||
switch mode {
|
||||
case "systemd-resolved":
|
||||
// NetworkManager is delegating to systemd-resolved
|
||||
// We should use systemd-resolved configurator instead
|
||||
return false
|
||||
case "dnsmasq", "unbound":
|
||||
// NetworkManager is using a local resolver that it controls
|
||||
// We can configure DNS through NetworkManager
|
||||
return true
|
||||
case "default", "none", "":
|
||||
// NetworkManager is managing DNS directly or not at all
|
||||
// We can configure DNS through NetworkManager
|
||||
return true
|
||||
default:
|
||||
// Unknown mode, try to use it
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// GetNetworkManagerDNSMode returns the current DNS mode of NetworkManager
|
||||
func GetNetworkManagerDNSMode() (string, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
||||
|
||||
modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get DNS mode property: %w", err)
|
||||
}
|
||||
|
||||
mode, ok := modeVariant.Value().(string)
|
||||
if !ok {
|
||||
return "", errors.New("DNS mode is not a string")
|
||||
}
|
||||
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
// GetNetworkManagerVersion returns the version of NetworkManager
|
||||
func GetNetworkManagerVersion() (string, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
|
||||
|
||||
versionVariant, err := obj.GetProperty(networkManagerDbusVersionProperty)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get version property: %w", err)
|
||||
}
|
||||
|
||||
version, ok := versionVariant.Value().(string)
|
||||
if !ok {
|
||||
return "", errors.New("version is not a string")
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
221
dns/platform/resolvconf.go
Normal file
221
dns/platform/resolvconf.go
Normal file
@@ -0,0 +1,221 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const resolvconfCommand = "resolvconf"
|
||||
|
||||
// ResolvconfDNSConfigurator manages DNS settings using the resolvconf utility
|
||||
type ResolvconfDNSConfigurator struct {
|
||||
ifaceName string
|
||||
implType string
|
||||
originalState *DNSState
|
||||
}
|
||||
|
||||
// NewResolvconfDNSConfigurator creates a new resolvconf DNS configurator
|
||||
func NewResolvconfDNSConfigurator(ifaceName string) (*ResolvconfDNSConfigurator, error) {
|
||||
if ifaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
// Detect resolvconf implementation type
|
||||
implType, err := detectResolvconfType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("detect resolvconf type: %w", err)
|
||||
}
|
||||
|
||||
configurator := &ResolvconfDNSConfigurator{
|
||||
ifaceName: ifaceName,
|
||||
implType: implType,
|
||||
}
|
||||
|
||||
// Call cleanup function to remove any stale DNS config for this interface
|
||||
if err := configurator.CleanupUncleanShutdown(); err != nil {
|
||||
return nil, fmt.Errorf("cleanup unclean shutdown: %w", err)
|
||||
}
|
||||
|
||||
return configurator, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (r *ResolvconfDNSConfigurator) Name() string {
|
||||
return fmt.Sprintf("resolvconf-%s", r.implType)
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (r *ResolvconfDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := r.GetCurrentDNS()
|
||||
if err != nil {
|
||||
// If we can't get current DNS, proceed anyway
|
||||
originalServers = []netip.Addr{}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
r.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: r.Name(),
|
||||
}
|
||||
|
||||
// Apply new DNS servers
|
||||
if err := r.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (r *ResolvconfDNSConfigurator) RestoreDNS() error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch r.implType {
|
||||
case "openresolv":
|
||||
// Force delete with -f
|
||||
cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||
default:
|
||||
cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
|
||||
}
|
||||
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("delete resolvconf config: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
|
||||
// For resolvconf, we attempt to delete any entry for the interface name.
|
||||
// This ensures that if the process crashed while DNS was configured, the stale
|
||||
// entry is removed on the next startup.
|
||||
func (r *ResolvconfDNSConfigurator) CleanupUncleanShutdown() error {
|
||||
// Try to delete any existing entry for this interface
|
||||
// This is idempotent - if no entry exists, resolvconf will just return success
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch r.implType {
|
||||
case "openresolv":
|
||||
cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||
default:
|
||||
cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
|
||||
}
|
||||
|
||||
// Ignore errors - the entry may not exist, which is fine
|
||||
_ = cmd.Run()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (r *ResolvconfDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
// resolvconf doesn't provide a direct way to query per-interface DNS
|
||||
// We can try to read /etc/resolv.conf but it's merged from all sources
|
||||
content, err := exec.Command(resolvconfCommand, "-l").CombinedOutput()
|
||||
if err != nil {
|
||||
// Fall back to reading resolv.conf
|
||||
return readResolvConfServers()
|
||||
}
|
||||
|
||||
// Parse the output (format varies by implementation)
|
||||
return parseResolvconfOutput(string(content)), nil
|
||||
}
|
||||
|
||||
// applyDNSServers applies DNS server configuration via resolvconf
|
||||
func (r *ResolvconfDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Build resolv.conf content
|
||||
var content bytes.Buffer
|
||||
content.WriteString("# Generated by Olm DNS Manager\n\n")
|
||||
|
||||
for _, server := range servers {
|
||||
content.WriteString("nameserver ")
|
||||
content.WriteString(server.String())
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// Apply via resolvconf
|
||||
var cmd *exec.Cmd
|
||||
switch r.implType {
|
||||
case "openresolv":
|
||||
// OpenResolv supports exclusive mode with -x
|
||||
cmd = exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||
default:
|
||||
cmd = exec.Command(resolvconfCommand, "-a", r.ifaceName)
|
||||
}
|
||||
|
||||
cmd.Stdin = &content
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("apply resolvconf config: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectResolvconfType detects which resolvconf implementation is being used
|
||||
func detectResolvconfType() (string, error) {
|
||||
cmd := exec.Command(resolvconfCommand, "--version")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("detect resolvconf type: %w", err)
|
||||
}
|
||||
|
||||
if strings.Contains(string(out), "openresolv") {
|
||||
return "openresolv", nil
|
||||
}
|
||||
|
||||
return "resolvconf", nil
|
||||
}
|
||||
|
||||
// parseResolvconfOutput parses resolvconf -l output for DNS servers
|
||||
func parseResolvconfOutput(output string) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
|
||||
lines := strings.Split(output, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// Skip comments and empty lines
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Look for nameserver lines
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
if addr, err := netip.ParseAddr(fields[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// readResolvConfServers reads DNS servers from /etc/resolv.conf
|
||||
func readResolvConfServers() ([]netip.Addr, error) {
|
||||
cmd := exec.Command("cat", "/etc/resolv.conf")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return parseResolvconfOutput(string(out)), nil
|
||||
}
|
||||
|
||||
// IsResolvconfAvailable checks if resolvconf is available
|
||||
func IsResolvconfAvailable() bool {
|
||||
cmd := exec.Command(resolvconfCommand, "--version")
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
304
dns/platform/systemd.go
Normal file
304
dns/platform/systemd.go
Normal file
@@ -0,0 +1,304 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
dbus "github.com/godbus/dbus/v5"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
systemdResolvedDest = "org.freedesktop.resolve1"
|
||||
systemdDbusObjectNode = "/org/freedesktop/resolve1"
|
||||
systemdDbusManagerIface = "org.freedesktop.resolve1.Manager"
|
||||
systemdDbusGetLinkMethod = systemdDbusManagerIface + ".GetLink"
|
||||
systemdDbusFlushCachesMethod = systemdDbusManagerIface + ".FlushCaches"
|
||||
systemdDbusLinkInterface = "org.freedesktop.resolve1.Link"
|
||||
systemdDbusSetDNSMethod = systemdDbusLinkInterface + ".SetDNS"
|
||||
systemdDbusSetDefaultRouteMethod = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethod = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusSetDNSSECMethod = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||
systemdDbusSetDNSOverTLSMethod = systemdDbusLinkInterface + ".SetDNSOverTLS"
|
||||
systemdDbusRevertMethod = systemdDbusLinkInterface + ".Revert"
|
||||
|
||||
// RootZone is the root DNS zone that matches all queries
|
||||
RootZone = "."
|
||||
)
|
||||
|
||||
// systemdDbusDNSInput maps to (iay) dbus input for SetDNS method
|
||||
type systemdDbusDNSInput struct {
|
||||
Family int32
|
||||
Address []byte
|
||||
}
|
||||
|
||||
// systemdDbusDomainsInput maps to (sb) dbus input for SetDomains method
|
||||
type systemdDbusDomainsInput struct {
|
||||
Domain string
|
||||
MatchOnly bool
|
||||
}
|
||||
|
||||
// SystemdResolvedDNSConfigurator manages DNS settings using systemd-resolved D-Bus API
|
||||
type SystemdResolvedDNSConfigurator struct {
|
||||
ifaceName string
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
originalState *DNSState
|
||||
}
|
||||
|
||||
// NewSystemdResolvedDNSConfigurator creates a new systemd-resolved DNS configurator
|
||||
func NewSystemdResolvedDNSConfigurator(ifaceName string) (*SystemdResolvedDNSConfigurator, error) {
|
||||
// Get network interface
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get interface: %w", err)
|
||||
}
|
||||
|
||||
// Connect to D-Bus
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
|
||||
|
||||
// Get the link object for this interface
|
||||
var linkPath string
|
||||
if err := obj.Call(systemdDbusGetLinkMethod, 0, iface.Index).Store(&linkPath); err != nil {
|
||||
return nil, fmt.Errorf("get link: %w", err)
|
||||
}
|
||||
|
||||
config := &SystemdResolvedDNSConfigurator{
|
||||
ifaceName: ifaceName,
|
||||
dbusLinkObject: dbus.ObjectPath(linkPath),
|
||||
}
|
||||
|
||||
// Call cleanup function here
|
||||
if err := config.CleanupUncleanShutdown(); err != nil {
|
||||
fmt.Printf("warning: cleanup unclean shutdown failed: %v\n", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (s *SystemdResolvedDNSConfigurator) Name() string {
|
||||
return "systemd-resolved"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (s *SystemdResolvedDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := s.GetCurrentDNS()
|
||||
if err != nil {
|
||||
// If we can't get current DNS, proceed anyway
|
||||
originalServers = []netip.Addr{}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
s.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: s.Name(),
|
||||
}
|
||||
|
||||
// Apply new DNS servers
|
||||
if err := s.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (s *SystemdResolvedDNSConfigurator) RestoreDNS() error {
|
||||
// Call Revert method to restore systemd-resolved defaults
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := obj.CallWithContext(ctx, systemdDbusRevertMethod, 0).Store(); err != nil {
|
||||
return fmt.Errorf("revert DNS settings: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache after reverting
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
|
||||
// For systemd-resolved, the DNS configuration is tied to the network interface.
|
||||
// When the interface is destroyed and recreated, systemd-resolved automatically
|
||||
// clears the per-link DNS settings, so there's nothing to clean up.
|
||||
func (s *SystemdResolvedDNSConfigurator) CleanupUncleanShutdown() error {
|
||||
// systemd-resolved DNS configuration is per-link and automatically cleared
|
||||
// when the link (interface) is destroyed. Since the WireGuard interface is
|
||||
// recreated on restart, there's no leftover state to clean up.
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
// Note: systemd-resolved doesn't easily expose current per-link DNS servers via D-Bus
|
||||
// This is a placeholder that returns an empty list
|
||||
func (s *SystemdResolvedDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
// systemd-resolved's D-Bus API doesn't have a simple way to query current DNS servers
|
||||
// We would need to parse resolvectl status output or read from /run/systemd/resolve/
|
||||
// For now, return empty list
|
||||
return []netip.Addr{}, nil
|
||||
}
|
||||
|
||||
// applyDNSServers applies DNS server configuration via systemd-resolved
|
||||
func (s *SystemdResolvedDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Convert servers to systemd-resolved format
|
||||
var dnsInputs []systemdDbusDNSInput
|
||||
for _, server := range servers {
|
||||
family := unix.AF_INET
|
||||
if server.Is6() {
|
||||
family = unix.AF_INET6
|
||||
}
|
||||
|
||||
dnsInputs = append(dnsInputs, systemdDbusDNSInput{
|
||||
Family: int32(family),
|
||||
Address: server.AsSlice(),
|
||||
})
|
||||
}
|
||||
|
||||
// Connect to D-Bus
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Call SetDNS method to set the DNS servers
|
||||
if err := obj.CallWithContext(ctx, systemdDbusSetDNSMethod, 0, dnsInputs).Store(); err != nil {
|
||||
return fmt.Errorf("set DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Set this interface as the default route for DNS
|
||||
// This ensures all DNS queries prefer this interface
|
||||
if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethod, true); err != nil {
|
||||
return fmt.Errorf("set default route: %w", err)
|
||||
}
|
||||
|
||||
// Set the root zone "." as a match-only domain
|
||||
// This captures ALL DNS queries and routes them through this interface
|
||||
domainsInput := []systemdDbusDomainsInput{
|
||||
{
|
||||
Domain: RootZone,
|
||||
MatchOnly: true,
|
||||
},
|
||||
}
|
||||
if err := s.callLinkMethod(systemdDbusSetDomainsMethod, domainsInput); err != nil {
|
||||
return fmt.Errorf("set domains: %w", err)
|
||||
}
|
||||
|
||||
// Disable DNSSEC - we don't support it and it may be enabled by default
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSSECMethod, "no"); err != nil {
|
||||
// Log warning but don't fail - this is optional
|
||||
fmt.Printf("warning: failed to disable DNSSEC: %v\n", err)
|
||||
}
|
||||
|
||||
// Disable DNSOverTLS - we don't support it and it may be enabled by default
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethod, "no"); err != nil {
|
||||
// Log warning but don't fail - this is optional
|
||||
fmt.Printf("warning: failed to disable DNSOverTLS: %v\n", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache to ensure new settings take effect immediately
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callLinkMethod is a helper to call methods on the link object
|
||||
func (s *SystemdResolvedDNSConfigurator) callLinkMethod(method string, value any) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if value != nil {
|
||||
if err := obj.CallWithContext(ctx, method, 0, value).Store(); err != nil {
|
||||
return fmt.Errorf("call %s: %w", method, err)
|
||||
}
|
||||
} else {
|
||||
if err := obj.CallWithContext(ctx, method, 0).Store(); err != nil {
|
||||
return fmt.Errorf("call %s: %w", method, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// flushDNSCache flushes the systemd-resolved DNS cache
|
||||
func (s *SystemdResolvedDNSConfigurator) flushDNSCache() error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, 0).Store(); err != nil {
|
||||
return fmt.Errorf("flush caches: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsSystemdResolvedAvailable checks if systemd-resolved is available and responsive
|
||||
func IsSystemdResolvedAvailable() bool {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Try to ping systemd-resolved
|
||||
if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
45
dns/platform/types.go
Normal file
45
dns/platform/types.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package dns
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// DNSConfigurator provides an interface for managing system DNS settings
|
||||
// across different platforms and implementations
|
||||
type DNSConfigurator interface {
|
||||
// SetDNS overrides the system DNS servers with the specified ones
|
||||
// Returns the original DNS servers that were replaced
|
||||
SetDNS(servers []netip.Addr) ([]netip.Addr, error)
|
||||
|
||||
// RestoreDNS restores the original DNS servers
|
||||
RestoreDNS() error
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
GetCurrentDNS() ([]netip.Addr, error)
|
||||
|
||||
// Name returns the name of this configurator implementation
|
||||
Name() string
|
||||
|
||||
// CleanupUncleanShutdown removes any DNS configuration left over from
|
||||
// a previous crash or unclean shutdown. This should be called on startup.
|
||||
CleanupUncleanShutdown() error
|
||||
}
|
||||
|
||||
// DNSConfig contains the configuration for DNS override
|
||||
type DNSConfig struct {
|
||||
// Servers is the list of DNS servers to use
|
||||
Servers []netip.Addr
|
||||
|
||||
// SearchDomains is an optional list of search domains
|
||||
SearchDomains []string
|
||||
}
|
||||
|
||||
// DNSState represents the saved state of DNS configuration
|
||||
type DNSState struct {
|
||||
// OriginalServers are the DNS servers before override
|
||||
OriginalServers []netip.Addr
|
||||
|
||||
// OriginalSearchDomains are the search domains before override
|
||||
OriginalSearchDomains []string
|
||||
|
||||
// ConfiguratorName is the name of the configurator that saved this state
|
||||
ConfiguratorName string
|
||||
}
|
||||
355
dns/platform/windows.go
Normal file
355
dns/platform/windows.go
Normal file
@@ -0,0 +1,355 @@
|
||||
//go:build windows
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
|
||||
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
|
||||
)
|
||||
|
||||
const (
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServer = "NameServer"
|
||||
interfaceConfigDhcpNameServer = "DhcpNameServer"
|
||||
)
|
||||
|
||||
// WindowsDNSConfigurator manages DNS settings on Windows using the registry
|
||||
type WindowsDNSConfigurator struct {
|
||||
guid string
|
||||
originalState *DNSState
|
||||
}
|
||||
|
||||
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
|
||||
// Accepts an interface name and extracts the GUID internally
|
||||
func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) {
|
||||
if interfaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
guid, err := getInterfaceGUIDString(interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
|
||||
}
|
||||
|
||||
return &WindowsDNSConfigurator{
|
||||
guid: guid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newWindowsDNSConfiguratorFromGUID creates a configurator from a GUID string
|
||||
// This is an internal function for use by DetectBestConfigurator
|
||||
func newWindowsDNSConfiguratorFromGUID(guid string) (*WindowsDNSConfigurator, error) {
|
||||
if guid == "" {
|
||||
return nil, fmt.Errorf("interface GUID is required")
|
||||
}
|
||||
|
||||
return &WindowsDNSConfigurator{
|
||||
guid: guid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (w *WindowsDNSConfigurator) Name() string {
|
||||
return "windows-registry"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (w *WindowsDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := w.GetCurrentDNS()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current DNS: %w", err)
|
||||
}
|
||||
|
||||
// Store original state
|
||||
w.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: w.Name(),
|
||||
}
|
||||
|
||||
// Set new DNS servers
|
||||
if err := w.setDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("set DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := w.flushDNSCache(); err != nil {
|
||||
// Non-fatal, just log
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (w *WindowsDNSConfigurator) RestoreDNS() error {
|
||||
if w.originalState == nil {
|
||||
return fmt.Errorf("no original state to restore")
|
||||
}
|
||||
|
||||
// Clear the static DNS setting
|
||||
if err := w.clearDNSServers(); err != nil {
|
||||
return fmt.Errorf("clear DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := w.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
|
||||
// On Windows, we rely on the registry-based approach which doesn't leave orphaned state
|
||||
// in the same way as macOS scutil. The DNS settings are tied to the interface which
|
||||
// gets recreated on restart.
|
||||
func (w *WindowsDNSConfigurator) CleanupUncleanShutdown() error {
|
||||
// Windows DNS configuration via registry is interface-specific.
|
||||
// When the WireGuard interface is recreated, it gets a new GUID,
|
||||
// so there's no leftover state to clean up from previous sessions.
|
||||
// The old interface's registry keys are effectively orphaned but harmless.
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (w *WindowsDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
regKey, err := w.getInterfaceRegistryKey(registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closeKey(regKey)
|
||||
|
||||
// Try to get static DNS first
|
||||
nameServer, _, err := regKey.GetStringValue(interfaceConfigNameServer)
|
||||
if err == nil && nameServer != "" {
|
||||
return w.parseServerList(nameServer), nil
|
||||
}
|
||||
|
||||
// Fall back to DHCP DNS
|
||||
dhcpNameServer, _, err := regKey.GetStringValue(interfaceConfigDhcpNameServer)
|
||||
if err == nil && dhcpNameServer != "" {
|
||||
return w.parseServerList(dhcpNameServer), nil
|
||||
}
|
||||
|
||||
return []netip.Addr{}, nil
|
||||
}
|
||||
|
||||
// setDNSServers sets the DNS servers in the registry
|
||||
func (w *WindowsDNSConfigurator) setDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closeKey(regKey)
|
||||
|
||||
// Build comma-separated or space-separated list of servers
|
||||
var serverList string
|
||||
for i, server := range servers {
|
||||
if i > 0 {
|
||||
serverList += ","
|
||||
}
|
||||
serverList += server.String()
|
||||
}
|
||||
|
||||
if err := regKey.SetStringValue(interfaceConfigNameServer, serverList); err != nil {
|
||||
return fmt.Errorf("set NameServer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearDNSServers clears the static DNS server setting
|
||||
func (w *WindowsDNSConfigurator) clearDNSServers() error {
|
||||
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closeKey(regKey)
|
||||
|
||||
// Set empty string to revert to DHCP
|
||||
if err := regKey.SetStringValue(interfaceConfigNameServer, ""); err != nil {
|
||||
return fmt.Errorf("clear NameServer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInterfaceRegistryKey opens the registry key for the network interface
|
||||
func (w *WindowsDNSConfigurator) getInterfaceRegistryKey(access uint32) (registry.Key, error) {
|
||||
regKeyPath := interfaceConfigPath + `\` + w.guid
|
||||
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, access)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
|
||||
}
|
||||
|
||||
return regKey, nil
|
||||
}
|
||||
|
||||
// parseServerList parses a comma or space-separated list of DNS servers
|
||||
func (w *WindowsDNSConfigurator) parseServerList(serverList string) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
|
||||
// Split by comma or space
|
||||
parts := splitByDelimiters(serverList, []rune{',', ' '})
|
||||
|
||||
for _, part := range parts {
|
||||
if addr, err := netip.ParseAddr(part); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// flushDNSCache flushes the Windows DNS resolver cache
|
||||
func (w *WindowsDNSConfigurator) flushDNSCache() error {
|
||||
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
fmt.Printf("warning: DnsFlushResolverCache panicked: %v\n", rec)
|
||||
}
|
||||
}()
|
||||
|
||||
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||
if ret == 0 {
|
||||
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("DnsFlushResolverCache failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitByDelimiters splits a string by multiple delimiters
|
||||
func splitByDelimiters(s string, delimiters []rune) []string {
|
||||
var result []string
|
||||
var current []rune
|
||||
|
||||
for _, char := range s {
|
||||
isDelimiter := false
|
||||
for _, delim := range delimiters {
|
||||
if char == delim {
|
||||
isDelimiter = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isDelimiter {
|
||||
if len(current) > 0 {
|
||||
result = append(result, string(current))
|
||||
current = []rune{}
|
||||
}
|
||||
} else {
|
||||
current = append(current, char)
|
||||
}
|
||||
}
|
||||
|
||||
if len(current) > 0 {
|
||||
result = append(result, string(current))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// closeKey closes a registry key and logs errors
|
||||
func closeKey(closer io.Closer) {
|
||||
if err := closer.Close(); err != nil {
|
||||
fmt.Printf("warning: failed to close registry key: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
|
||||
// This is required for registry-based DNS configuration on Windows
|
||||
func getInterfaceGUIDString(interfaceName string) (string, error) {
|
||||
if interfaceName == "" {
|
||||
return "", fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
|
||||
}
|
||||
|
||||
luid, err := indexToLUID(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to convert index to LUID: %w", err)
|
||||
}
|
||||
|
||||
// Convert LUID to GUID using Windows API
|
||||
guid, err := luidToGUID(luid)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to convert LUID to GUID: %w", err)
|
||||
}
|
||||
|
||||
return guid, nil
|
||||
}
|
||||
|
||||
// indexToLUID converts a Windows interface index to a LUID
|
||||
func indexToLUID(index uint32) (uint64, error) {
|
||||
var luid uint64
|
||||
|
||||
// Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function
|
||||
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid")
|
||||
|
||||
// Call the Windows API
|
||||
ret, _, err := convertInterfaceIndexToLuid.Call(
|
||||
uintptr(index),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err)
|
||||
}
|
||||
|
||||
return luid, nil
|
||||
}
|
||||
|
||||
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
|
||||
// using the Windows ConvertInterface* APIs
|
||||
func luidToGUID(luid uint64) (string, error) {
|
||||
var guid windows.GUID
|
||||
|
||||
// Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function
|
||||
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid")
|
||||
|
||||
// Call the Windows API
|
||||
// NET_LUID is a 64-bit value on Windows
|
||||
ret, _, err := convertLuidToGuid.Call(
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
uintptr(unsafe.Pointer(&guid)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err)
|
||||
}
|
||||
|
||||
// Format the GUID as a string with curly braces
|
||||
guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}",
|
||||
guid.Data1, guid.Data2, guid.Data3,
|
||||
guid.Data4[0], guid.Data4[1], guid.Data4[2], guid.Data4[3],
|
||||
guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7])
|
||||
|
||||
return guidStr, nil
|
||||
}
|
||||
@@ -5,6 +5,11 @@ services:
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- PANGOLIN_ENDPOINT=https://example.com
|
||||
- OLM_ID=2ix2t8xk22ubpfy
|
||||
- OLM_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||
- LOG_LEVEL=DEBUG
|
||||
- OLM_ID=vdqnz8rwgb95cnp
|
||||
- OLM_SECRET=1sw05qv1tkfdb1k81zpw05nahnnjvmhxjvf746umwagddmdg
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
- SYS_MODULE
|
||||
devices:
|
||||
- /dev/net/tun:/dev/net/tun
|
||||
network_mode: host
|
||||
279
get-olm.sh
Normal file
279
get-olm.sh
Normal file
@@ -0,0 +1,279 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Get Olm - Cross-platform installation script
|
||||
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/olm/refs/heads/main/get-olm.sh | bash
|
||||
|
||||
set -e
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# GitHub repository info
|
||||
REPO="fosrl/olm"
|
||||
GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
|
||||
|
||||
# Function to print colored output
|
||||
print_status() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
print_warning() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
print_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Function to get latest version from GitHub API
|
||||
get_latest_version() {
|
||||
local latest_info
|
||||
|
||||
if command -v curl >/dev/null 2>&1; then
|
||||
latest_info=$(curl -fsSL "$GITHUB_API_URL" 2>/dev/null)
|
||||
elif command -v wget >/dev/null 2>&1; then
|
||||
latest_info=$(wget -qO- "$GITHUB_API_URL" 2>/dev/null)
|
||||
else
|
||||
print_error "Neither curl nor wget is available. Please install one of them." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$latest_info" ]; then
|
||||
print_error "Failed to fetch latest version information" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract version from JSON response (works without jq)
|
||||
local version=$(echo "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/')
|
||||
|
||||
if [ -z "$version" ]; then
|
||||
print_error "Could not parse version from GitHub API response" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Remove 'v' prefix if present
|
||||
version=$(echo "$version" | sed 's/^v//')
|
||||
|
||||
echo "$version"
|
||||
}
|
||||
|
||||
# Detect OS and architecture
|
||||
detect_platform() {
|
||||
local os arch
|
||||
|
||||
# Detect OS
|
||||
case "$(uname -s)" in
|
||||
Linux*) os="linux" ;;
|
||||
Darwin*) os="darwin" ;;
|
||||
MINGW*|MSYS*|CYGWIN*) os="windows" ;;
|
||||
FreeBSD*) os="freebsd" ;;
|
||||
*)
|
||||
print_error "Unsupported operating system: $(uname -s)"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
# Detect architecture
|
||||
case "$(uname -m)" in
|
||||
x86_64|amd64) arch="amd64" ;;
|
||||
arm64|aarch64) arch="arm64" ;;
|
||||
armv7l|armv6l)
|
||||
if [ "$os" = "linux" ]; then
|
||||
if [ "$(uname -m)" = "armv6l" ]; then
|
||||
arch="arm32v6"
|
||||
else
|
||||
arch="arm32"
|
||||
fi
|
||||
else
|
||||
arch="arm64" # Default for non-Linux ARM
|
||||
fi
|
||||
;;
|
||||
riscv64)
|
||||
if [ "$os" = "linux" ]; then
|
||||
arch="riscv64"
|
||||
else
|
||||
print_error "RISC-V architecture only supported on Linux"
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
print_error "Unsupported architecture: $(uname -m)"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "${os}_${arch}"
|
||||
}
|
||||
|
||||
# Get installation directory
|
||||
get_install_dir() {
|
||||
local platform="$1"
|
||||
|
||||
if [[ "$platform" == *"windows"* ]]; then
|
||||
echo "$HOME/bin"
|
||||
else
|
||||
# For Unix-like systems, prioritize system-wide directories for sudo access
|
||||
# Check in order of preference: /usr/local/bin, /usr/bin, ~/.local/bin
|
||||
if [ -d "/usr/local/bin" ]; then
|
||||
echo "/usr/local/bin"
|
||||
elif [ -d "/usr/bin" ]; then
|
||||
echo "/usr/bin"
|
||||
else
|
||||
# Fallback to user directory if system directories don't exist
|
||||
echo "$HOME/.local/bin"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Check if we need sudo for installation
|
||||
need_sudo() {
|
||||
local install_dir="$1"
|
||||
|
||||
# If installing to system directory and we don't have write permission, need sudo
|
||||
if [[ "$install_dir" == "/usr/local/bin" || "$install_dir" == "/usr/bin" ]]; then
|
||||
if [ ! -w "$install_dir" ] 2>/dev/null; then
|
||||
return 0 # Need sudo
|
||||
fi
|
||||
fi
|
||||
return 1 # Don't need sudo
|
||||
}
|
||||
|
||||
# Download and install olm
|
||||
install_olm() {
|
||||
local platform="$1"
|
||||
local install_dir="$2"
|
||||
local binary_name="olm_${platform}"
|
||||
local exe_suffix=""
|
||||
|
||||
# Add .exe suffix for Windows
|
||||
if [[ "$platform" == *"windows"* ]]; then
|
||||
binary_name="${binary_name}.exe"
|
||||
exe_suffix=".exe"
|
||||
fi
|
||||
|
||||
local download_url="${BASE_URL}/${binary_name}"
|
||||
local temp_file="/tmp/olm${exe_suffix}"
|
||||
local final_path="${install_dir}/olm${exe_suffix}"
|
||||
|
||||
print_status "Downloading olm from ${download_url}"
|
||||
|
||||
# Download the binary
|
||||
if command -v curl >/dev/null 2>&1; then
|
||||
curl -fsSL "$download_url" -o "$temp_file"
|
||||
elif command -v wget >/dev/null 2>&1; then
|
||||
wget -q "$download_url" -O "$temp_file"
|
||||
else
|
||||
print_error "Neither curl nor wget is available. Please install one of them."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if we need sudo for installation
|
||||
local use_sudo=""
|
||||
if need_sudo "$install_dir"; then
|
||||
print_status "Administrator privileges required for system-wide installation"
|
||||
if command -v sudo >/dev/null 2>&1; then
|
||||
use_sudo="sudo"
|
||||
else
|
||||
print_error "sudo is required for system-wide installation but not available"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Create install directory if it doesn't exist
|
||||
if [ -n "$use_sudo" ]; then
|
||||
$use_sudo mkdir -p "$install_dir"
|
||||
else
|
||||
mkdir -p "$install_dir"
|
||||
fi
|
||||
|
||||
# Move binary to install directory
|
||||
if [ -n "$use_sudo" ]; then
|
||||
$use_sudo mv "$temp_file" "$final_path"
|
||||
$use_sudo chmod +x "$final_path"
|
||||
else
|
||||
mv "$temp_file" "$final_path"
|
||||
chmod +x "$final_path"
|
||||
fi
|
||||
|
||||
print_status "olm installed to ${final_path}"
|
||||
|
||||
# Check if install directory is in PATH (only warn for non-system directories)
|
||||
if [[ "$install_dir" != "/usr/local/bin" && "$install_dir" != "/usr/bin" ]]; then
|
||||
if ! echo "$PATH" | grep -q "$install_dir"; then
|
||||
print_warning "Install directory ${install_dir} is not in your PATH."
|
||||
print_warning "Add it to your PATH by adding this line to your shell profile:"
|
||||
print_warning " export PATH=\"${install_dir}:\$PATH\""
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Verify installation
|
||||
verify_installation() {
|
||||
local install_dir="$1"
|
||||
local exe_suffix=""
|
||||
|
||||
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||
exe_suffix=".exe"
|
||||
fi
|
||||
|
||||
local olm_path="${install_dir}/olm${exe_suffix}"
|
||||
|
||||
if [ -f "$olm_path" ] && [ -x "$olm_path" ]; then
|
||||
print_status "Installation successful!"
|
||||
print_status "olm version: $("$olm_path" --version 2>/dev/null || echo "unknown")"
|
||||
return 0
|
||||
else
|
||||
print_error "Installation failed. Binary not found or not executable."
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Main installation process
|
||||
main() {
|
||||
print_status "Installing latest version of olm..."
|
||||
|
||||
# Get latest version
|
||||
print_status "Fetching latest version from GitHub..."
|
||||
VERSION=$(get_latest_version)
|
||||
print_status "Latest version: v${VERSION}"
|
||||
|
||||
# Set base URL with the fetched version
|
||||
BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"
|
||||
|
||||
# Detect platform
|
||||
PLATFORM=$(detect_platform)
|
||||
print_status "Detected platform: ${PLATFORM}"
|
||||
|
||||
# Get install directory
|
||||
INSTALL_DIR=$(get_install_dir "$PLATFORM")
|
||||
print_status "Install directory: ${INSTALL_DIR}"
|
||||
|
||||
# Inform user about system-wide installation
|
||||
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
|
||||
print_status "Installing system-wide for sudo access"
|
||||
fi
|
||||
|
||||
# Install olm
|
||||
install_olm "$PLATFORM" "$INSTALL_DIR"
|
||||
|
||||
# Verify installation
|
||||
if verify_installation "$INSTALL_DIR"; then
|
||||
print_status "olm is ready to use!"
|
||||
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
|
||||
print_status "olm is installed system-wide and accessible via sudo"
|
||||
fi
|
||||
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||
print_status "Run 'olm --help' to get started"
|
||||
else
|
||||
print_status "Run 'olm --help' or 'sudo olm --help' to get started"
|
||||
fi
|
||||
else
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Run main function
|
||||
main "$@"
|
||||
60
go.mod
60
go.mod
@@ -1,55 +1,35 @@
|
||||
module github.com/fosrl/olm
|
||||
|
||||
go 1.23.1
|
||||
|
||||
toolchain go1.23.2
|
||||
go 1.25
|
||||
|
||||
require (
|
||||
github.com/fosrl/newt v0.0.0-20250724194014-008be54c554a
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.40.0
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/sys v0.34.0
|
||||
github.com/Microsoft/go-winio v0.6.2
|
||||
github.com/fosrl/newt v1.9.0
|
||||
github.com/godbus/dbus/v5 v5.2.2
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/miekg/dns v1.1.70
|
||||
golang.org/x/sys v0.40.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||
software.sslmate.com/src/go-pkcs12 v0.7.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v28.3.2+incompatible // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/gopacket v1.1.19 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/vishvananda/netlink v1.3.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect
|
||||
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
golang.org/x/mod v0.26.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/crypto v0.46.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
golang.org/x/tools v0.35.0 // indirect
|
||||
golang.org/x/tools v0.40.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
|
||||
software.sslmate.com/src/go-pkcs12 v0.5.0 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
)
|
||||
|
||||
// To be used ONLY for local development
|
||||
// replace github.com/fosrl/newt => ../newt
|
||||
|
||||
131
go.sum
131
go.sum
@@ -1,127 +1,48 @@
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v28.3.2+incompatible h1:wn66NJ6pWB1vBZIilP8G3qQPqHy5XymfYn5vsqeA5oA=
|
||||
github.com/docker/docker v28.3.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/fosrl/newt v0.0.0-20250717220102-cd86e6b6de83 h1:jI6tP2sJNNb70Y+Ixq+oI06fDPnGUbarz/r67g7KvB8=
|
||||
github.com/fosrl/newt v0.0.0-20250717220102-cd86e6b6de83/go.mod h1:oqHsCx1rsEc8hAVGVXemfeolIwlr19biJSQiLYi7Mvo=
|
||||
github.com/fosrl/newt v0.0.0-20250718235538-510e78437ca4 h1:bK/MQyTOLGthrXZ7ExvOCdW0EH0o9b5vwk/+UKnNdg0=
|
||||
github.com/fosrl/newt v0.0.0-20250718235538-510e78437ca4/go.mod h1:oqHsCx1rsEc8hAVGVXemfeolIwlr19biJSQiLYi7Mvo=
|
||||
github.com/fosrl/newt v0.0.0-20250724190153-64c22a94a47a h1:Jgd60yfFJxb5z6L3LcoraaosHjiRgKLnMz6T3mv3D4Q=
|
||||
github.com/fosrl/newt v0.0.0-20250724190153-64c22a94a47a/go.mod h1:oqHsCx1rsEc8hAVGVXemfeolIwlr19biJSQiLYi7Mvo=
|
||||
github.com/fosrl/newt v0.0.0-20250724194014-008be54c554a h1:17r/Uhef6aIxpO0xYGI3771LJx7cTyc1WziDOgghc54=
|
||||
github.com/fosrl/newt v0.0.0-20250724194014-008be54c554a/go.mod h1:oqHsCx1rsEc8hAVGVXemfeolIwlr19biJSQiLYi7Mvo=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
|
||||
github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
|
||||
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
|
||||
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA=
|
||||
github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY=
|
||||
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
|
||||
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
|
||||
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c=
|
||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs=
|
||||
software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M=
|
||||
software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
software.sslmate.com/src/go-pkcs12 v0.7.0 h1:Db8W44cB54TWD7stUFFSWxdfpdn6fZVcDl0w3R4RVM0=
|
||||
software.sslmate.com/src/go-pkcs12 v0.7.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// ConnectionRequest defines the structure for an incoming connection request
|
||||
type ConnectionRequest struct {
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
// PeerStatus represents the status of a peer connection
|
||||
type PeerStatus struct {
|
||||
SiteID int `json:"siteId"`
|
||||
Connected bool `json:"connected"`
|
||||
RTT time.Duration `json:"rtt"`
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
}
|
||||
|
||||
// StatusResponse is returned by the status endpoint
|
||||
type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Connected bool `json:"connected"`
|
||||
TunnelIP string `json:"tunnelIP,omitempty"`
|
||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPServer represents the HTTP server and its state
|
||||
type HTTPServer struct {
|
||||
addr string
|
||||
server *http.Server
|
||||
connectionChan chan ConnectionRequest
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
connectedAt time.Time
|
||||
isConnected bool
|
||||
}
|
||||
|
||||
// NewHTTPServer creates a new HTTP server
|
||||
func NewHTTPServer(addr string) *HTTPServer {
|
||||
s := &HTTPServer{
|
||||
addr: addr,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
func (s *HTTPServer) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/connect", s.handleConnect)
|
||||
mux.HandleFunc("/status", s.handleStatus)
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: s.addr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
logger.Info("Starting HTTP server on %s", s.addr)
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("HTTP server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the HTTP server
|
||||
func (s *HTTPServer) Stop() error {
|
||||
logger.Info("Stopping HTTP server")
|
||||
return s.server.Close()
|
||||
}
|
||||
|
||||
// GetConnectionChannel returns the channel for receiving connection requests
|
||||
func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest {
|
||||
return s.connectionChan
|
||||
}
|
||||
|
||||
// UpdatePeerStatus updates the status of a peer
|
||||
func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.Connected = connected
|
||||
status.RTT = rtt
|
||||
status.LastSeen = time.Now()
|
||||
}
|
||||
|
||||
// SetConnectionStatus sets the overall connection status
|
||||
func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
s.isConnected = isConnected
|
||||
|
||||
if isConnected {
|
||||
s.connectedAt = time.Now()
|
||||
} else {
|
||||
// Clear peer statuses when disconnected
|
||||
s.peerStatuses = make(map[int]*PeerStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnect handles the /connect endpoint
|
||||
func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req ConnectionRequest
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
|
||||
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send the request to the main goroutine
|
||||
s.connectionChan <- req
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "connection request accepted",
|
||||
})
|
||||
}
|
||||
|
||||
// handleStatus handles the /status endpoint
|
||||
func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
s.statusMu.RLock()
|
||||
defer s.statusMu.RUnlock()
|
||||
|
||||
resp := StatusResponse{
|
||||
Connected: s.isConnected,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
}
|
||||
|
||||
if s.isConnected {
|
||||
resp.Status = "connected"
|
||||
} else {
|
||||
resp.Status = "disconnected"
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
883
main.go
883
main.go
@@ -2,27 +2,15 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/olm/httpserver"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/wgtester"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"github.com/fosrl/newt/updates"
|
||||
olmpkg "github.com/fosrl/olm/olm"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -34,8 +22,15 @@ func main() {
|
||||
}
|
||||
|
||||
// Handle service management commands on Windows
|
||||
if runtime.GOOS == "windows" && len(os.Args) > 1 {
|
||||
switch os.Args[1] {
|
||||
if runtime.GOOS == "windows" {
|
||||
var command string
|
||||
if len(os.Args) > 1 {
|
||||
command = os.Args[1]
|
||||
} else {
|
||||
command = "default"
|
||||
}
|
||||
|
||||
switch command {
|
||||
case "install":
|
||||
err := installService()
|
||||
if err != nil {
|
||||
@@ -109,15 +104,27 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
case "config":
|
||||
if runtime.GOOS == "windows" {
|
||||
showServiceConfig()
|
||||
} else {
|
||||
fmt.Println("Service configuration is only available on Windows")
|
||||
}
|
||||
return
|
||||
case "help", "--help", "-h":
|
||||
fmt.Println("Olm WireGuard VPN Client")
|
||||
fmt.Println("\nWindows Service Management:")
|
||||
fmt.Println(" install Install the service")
|
||||
fmt.Println(" remove Remove the service")
|
||||
fmt.Println(" start Start the service")
|
||||
fmt.Println(" start [args] Start the service with optional arguments")
|
||||
fmt.Println(" stop Stop the service")
|
||||
fmt.Println(" status Show service status")
|
||||
fmt.Println(" debug Run service in debug mode")
|
||||
fmt.Println(" debug [args] Run service in debug mode with optional arguments")
|
||||
fmt.Println(" logs Tail the service log file")
|
||||
fmt.Println(" config Show current service configuration")
|
||||
fmt.Println("\nExamples:")
|
||||
fmt.Println(" olm start --enable-http --http-addr :9452")
|
||||
fmt.Println(" olm debug --endpoint https://example.com --id myid --secret mysecret")
|
||||
fmt.Println("\nFor console mode, run without arguments or with standard flags.")
|
||||
return
|
||||
default:
|
||||
@@ -147,782 +154,116 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// Create a context that will be cancelled on interrupt signals
|
||||
signalCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
// Create a separate context for programmatic shutdown (e.g., via API exit)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Run in console mode
|
||||
runOlmMain(context.Background())
|
||||
runOlmMainWithArgs(ctx, cancel, signalCtx, os.Args[1:])
|
||||
}
|
||||
|
||||
func runOlmMain(ctx context.Context) {
|
||||
runOlmMainWithArgs(ctx, os.Args[1:])
|
||||
}
|
||||
|
||||
func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
// Log that we've entered the main function
|
||||
// fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
|
||||
|
||||
// Create a new FlagSet for parsing service arguments
|
||||
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
||||
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
mtu string
|
||||
mtuInt int
|
||||
dns string
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
interfaceName string
|
||||
enableHTTP bool
|
||||
httpAddr string
|
||||
testMode bool // Add this var for the test flag
|
||||
testTarget string // Add this var for test target
|
||||
pingInterval time.Duration
|
||||
pingTimeout time.Duration
|
||||
doHolepunch bool
|
||||
connected bool
|
||||
)
|
||||
|
||||
stopHolepunch = make(chan struct{})
|
||||
stopPing = make(chan struct{})
|
||||
|
||||
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
|
||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
||||
id = os.Getenv("OLM_ID")
|
||||
secret = os.Getenv("OLM_SECRET")
|
||||
mtu = os.Getenv("MTU")
|
||||
dns = os.Getenv("DNS")
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
httpAddr = os.Getenv("HTTP_ADDR")
|
||||
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
||||
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
|
||||
doHolepunch = os.Getenv("HOLEPUNCH") == "true" // Default to true, can be overridden by flag
|
||||
|
||||
if endpoint == "" {
|
||||
serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
|
||||
}
|
||||
if id == "" {
|
||||
serviceFlags.StringVar(&id, "id", "", "Olm ID")
|
||||
}
|
||||
if secret == "" {
|
||||
serviceFlags.StringVar(&secret, "secret", "", "Olm secret")
|
||||
}
|
||||
if mtu == "" {
|
||||
serviceFlags.StringVar(&mtu, "mtu", "1280", "MTU to use")
|
||||
}
|
||||
if dns == "" {
|
||||
serviceFlags.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
|
||||
}
|
||||
if logLevel == "" {
|
||||
serviceFlags.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
if interfaceName == "" {
|
||||
serviceFlags.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface")
|
||||
}
|
||||
if httpAddr == "" {
|
||||
serviceFlags.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')")
|
||||
}
|
||||
if pingIntervalStr == "" {
|
||||
serviceFlags.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
|
||||
}
|
||||
if pingTimeoutStr == "" {
|
||||
serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)")
|
||||
}
|
||||
serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests")
|
||||
serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)")
|
||||
|
||||
// Parse the service arguments
|
||||
if err := serviceFlags.Parse(args); err != nil {
|
||||
fmt.Printf("Error parsing service arguments: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Debug: Print final values after flag parsing
|
||||
// fmt.Printf("After flag parsing: endpoint='%s', id='%s', secret='%s'\n", endpoint, id, secret)
|
||||
|
||||
// Parse ping intervals
|
||||
if pingIntervalStr != "" {
|
||||
pingInterval, err = time.ParseDuration(pingIntervalStr)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
|
||||
pingInterval = 3 * time.Second
|
||||
}
|
||||
} else {
|
||||
pingInterval = 3 * time.Second
|
||||
}
|
||||
|
||||
if pingTimeoutStr != "" {
|
||||
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
|
||||
pingTimeout = 5 * time.Second
|
||||
}
|
||||
} else {
|
||||
pingTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCtx context.Context, args []string) {
|
||||
// Setup Windows event logging if on Windows
|
||||
if runtime.GOOS == "windows" {
|
||||
setupWindowsEventLog()
|
||||
} else {
|
||||
// Initialize logger for non-Windows platforms
|
||||
logger.Init()
|
||||
}
|
||||
loggerLevel := parseLogLevel(logLevel)
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
// Log startup information
|
||||
logger.Debug("Olm service starting...")
|
||||
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
|
||||
logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr)
|
||||
|
||||
if doHolepunch {
|
||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||
logger.Init(nil)
|
||||
}
|
||||
|
||||
// Handle test mode
|
||||
if testMode {
|
||||
if testTarget == "" {
|
||||
logger.Fatal("Test mode requires -test-target to be set to a server:port")
|
||||
}
|
||||
|
||||
logger.Info("Running in test mode, connecting to %s", testTarget)
|
||||
|
||||
// Create a new tester client
|
||||
tester, err := wgtester.NewClient(testTarget)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create tester client: %v", err)
|
||||
}
|
||||
defer tester.Close()
|
||||
|
||||
// Test connection with a 2-second timeout
|
||||
connected, rtt := tester.TestConnectionWithTimeout(2 * time.Second)
|
||||
|
||||
if connected {
|
||||
logger.Info("Connection test successful! RTT: %v", rtt)
|
||||
fmt.Printf("Connection test successful! RTT: %v\n", rtt)
|
||||
os.Exit(0)
|
||||
} else {
|
||||
logger.Error("Connection test failed - no response received")
|
||||
fmt.Println("Connection test failed - no response received")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
var httpServer *httpserver.HTTPServer
|
||||
if enableHTTP {
|
||||
httpServer = httpserver.NewHTTPServer(httpAddr)
|
||||
if err := httpServer.Start(); err != nil {
|
||||
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||
}
|
||||
|
||||
// Use a goroutine to handle connection requests
|
||||
go func() {
|
||||
for req := range httpServer.GetConnectionChannel() {
|
||||
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
||||
|
||||
// Set the connection parameters
|
||||
id = req.ID
|
||||
secret = req.Secret
|
||||
endpoint = req.Endpoint
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// // Check if required parameters are missing and provide helpful guidance
|
||||
// missingParams := []string{}
|
||||
// if id == "" {
|
||||
// missingParams = append(missingParams, "id (use -id flag or OLM_ID env var)")
|
||||
// }
|
||||
// if secret == "" {
|
||||
// missingParams = append(missingParams, "secret (use -secret flag or OLM_SECRET env var)")
|
||||
// }
|
||||
// if endpoint == "" {
|
||||
// missingParams = append(missingParams, "endpoint (use -endpoint flag or PANGOLIN_ENDPOINT env var)")
|
||||
// }
|
||||
|
||||
// if len(missingParams) > 0 {
|
||||
// logger.Error("Missing required parameters: %v", missingParams)
|
||||
// logger.Error("Either provide them as command line flags or set as environment variables")
|
||||
// fmt.Printf("ERROR: Missing required parameters: %v\n", missingParams)
|
||||
// fmt.Printf("Please provide them as command line flags or set as environment variables\n")
|
||||
// if !enableHTTP {
|
||||
// logger.Error("HTTP server is disabled, cannot receive parameters via API")
|
||||
// fmt.Printf("HTTP server is disabled, cannot receive parameters via API\n")
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
|
||||
// // wait until we have a client id and secret and endpoint
|
||||
// waitCount := 0
|
||||
// for id == "" || secret == "" || endpoint == "" {
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// logger.Info("Context cancelled while waiting for credentials")
|
||||
// return
|
||||
// default:
|
||||
// missing := []string{}
|
||||
// if id == "" {
|
||||
// missing = append(missing, "id")
|
||||
// }
|
||||
// if secret == "" {
|
||||
// missing = append(missing, "secret")
|
||||
// }
|
||||
// if endpoint == "" {
|
||||
// missing = append(missing, "endpoint")
|
||||
// }
|
||||
// waitCount++
|
||||
// if waitCount%10 == 1 { // Log every 10 seconds instead of every second
|
||||
// logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount)
|
||||
// }
|
||||
// time.Sleep(1 * time.Second)
|
||||
// }
|
||||
// }
|
||||
|
||||
// parse the mtu string into an int
|
||||
mtuInt, err = strconv.Atoi(mtu)
|
||||
// Load configuration from file, env vars, and CLI args
|
||||
// Priority: CLI args > Env vars > Config file > Defaults
|
||||
// Use the passed args parameter instead of os.Args[1:] to support Windows service mode
|
||||
config, showVersion, showConfig, err := LoadConfig(args)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse MTU: %v", err)
|
||||
fmt.Printf("Failed to load configuration: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||
// Handle --show-config flag
|
||||
if showConfig {
|
||||
config.ShowConfig()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
olmVersion := "version_replaceme"
|
||||
if showVersion {
|
||||
fmt.Println("Olm version " + olmVersion)
|
||||
os.Exit(0)
|
||||
}
|
||||
logger.Info("Olm version %s", olmVersion)
|
||||
|
||||
config.Version = olmVersion
|
||||
|
||||
if err := SaveConfig(config); err != nil {
|
||||
logger.Error("Failed to save full olm config: %v", err)
|
||||
} else {
|
||||
logger.Debug("Saved full olm config with all options")
|
||||
}
|
||||
|
||||
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
|
||||
logger.Debug("Failed to check for updates: %v", err)
|
||||
}
|
||||
|
||||
// Create a new olm.Config struct and copy values from the main config
|
||||
olmConfig := olmpkg.OlmConfig{
|
||||
LogLevel: config.LogLevel,
|
||||
EnableAPI: config.EnableAPI,
|
||||
HTTPAddr: config.HTTPAddr,
|
||||
SocketPath: config.SocketPath,
|
||||
Version: config.Version,
|
||||
Agent: "Olm CLI",
|
||||
OnExit: cancel, // Pass cancel function directly to trigger shutdown
|
||||
OnTerminated: cancel,
|
||||
PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
|
||||
}
|
||||
|
||||
olm, err := olmpkg.Init(ctx, olmConfig)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to generate private key: %v", err)
|
||||
logger.Fatal("Failed to initialize olm: %v", err)
|
||||
}
|
||||
|
||||
// Create a new olm
|
||||
olm, err := websocket.NewClient(
|
||||
"olm",
|
||||
id, // CLI arg takes precedence
|
||||
secret, // CLI arg takes precedence
|
||||
endpoint,
|
||||
pingInterval,
|
||||
pingTimeout,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create olm: %v", err)
|
||||
}
|
||||
endpoint = olm.GetConfig().Endpoint // Update endpoint from config
|
||||
id = olm.GetConfig().ID // Update ID from config
|
||||
|
||||
// Create TUN device and network stack
|
||||
var dev *device.Device
|
||||
var wgData WgData
|
||||
var holePunchData HolePunchData
|
||||
var uapiListener net.Listener
|
||||
var tdev tun.Device
|
||||
|
||||
sourcePort, err := FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
fmt.Printf("Error finding available port: %v\n", err)
|
||||
os.Exit(1)
|
||||
if err := olm.StartApi(); err != nil {
|
||||
logger.Fatal("Failed to start API server: %v", err)
|
||||
}
|
||||
|
||||
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
|
||||
tunnelConfig := olmpkg.TunnelConfig{
|
||||
Endpoint: config.Endpoint,
|
||||
ID: config.ID,
|
||||
Secret: config.Secret,
|
||||
UserToken: config.UserToken,
|
||||
MTU: config.MTU,
|
||||
DNS: config.DNS,
|
||||
UpstreamDNS: config.UpstreamDNS,
|
||||
InterfaceName: config.InterfaceName,
|
||||
Holepunch: !config.DisableHolepunch,
|
||||
TlsClientCert: config.TlsClientCert,
|
||||
PingIntervalDuration: config.PingIntervalDuration,
|
||||
PingTimeoutDuration: config.PingTimeoutDuration,
|
||||
OrgID: config.OrgID,
|
||||
OverrideDNS: config.OverrideDNS,
|
||||
DisableRelay: config.DisableRelay,
|
||||
EnableUAPI: true,
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
gerbilServerPubKey = holePunchData.ServerPubKey
|
||||
|
||||
go keepSendingUDPHolePunch(holePunchData.Endpoint, id, sourcePort)
|
||||
})
|
||||
|
||||
// Register handlers for different message types
|
||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
if connected {
|
||||
logger.Info("Already connected. Ignoring new connection request.")
|
||||
return
|
||||
}
|
||||
|
||||
if stopRegister != nil {
|
||||
stopRegister()
|
||||
stopRegister = nil
|
||||
}
|
||||
|
||||
close(stopHolepunch)
|
||||
|
||||
// wait 10 milliseconds to ensure the previous connection is closed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// if there is an existing tunnel then close it
|
||||
if dev != nil {
|
||||
logger.Info("Got new message. Closing existing tunnel!")
|
||||
dev.Close()
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
tdev, err = func() (tun.Device, error) {
|
||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||
|
||||
// if on macOS, call findUnusedUTUN to get a new utun device
|
||||
if runtime.GOOS == "darwin" {
|
||||
interfaceName, err := findUnusedUTUN()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tun.CreateTUN(interfaceName, mtuInt)
|
||||
}
|
||||
|
||||
if tunFdStr == "" {
|
||||
return tun.CreateTUN(interfaceName, mtuInt)
|
||||
}
|
||||
|
||||
return createTUNFromFD(tunFdStr, mtuInt)
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Failed to create TUN device: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
realInterfaceName, err2 := tdev.Name()
|
||||
if err2 == nil {
|
||||
interfaceName = realInterfaceName
|
||||
}
|
||||
|
||||
// open UAPI file (or use supplied fd)
|
||||
fileUAPI, err := func() (*os.File, error) {
|
||||
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
|
||||
if uapiFdStr == "" {
|
||||
return uapiOpen(interfaceName)
|
||||
}
|
||||
|
||||
// use supplied fd
|
||||
|
||||
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(fd), ""), nil
|
||||
}()
|
||||
if err != nil {
|
||||
logger.Error("UAPI listen error: %v", err)
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(
|
||||
mapToWireGuardLogLevel(loggerLevel),
|
||||
"wireguard: ",
|
||||
))
|
||||
|
||||
errs := make(chan error)
|
||||
|
||||
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||
if err != nil {
|
||||
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := uapiListener.Accept()
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
go dev.IpcHandle(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("UAPI listener started")
|
||||
|
||||
// Bring up the device
|
||||
err = dev.Up()
|
||||
if err != nil {
|
||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// configure the interface
|
||||
err = ConfigureInterface(realInterfaceName, wgData)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure interface: %v", err)
|
||||
}
|
||||
|
||||
peerMonitor = peermonitor.NewPeerMonitor(
|
||||
func(siteID int, connected bool, rtt time.Duration) {
|
||||
if httpServer != nil {
|
||||
httpServer.UpdatePeerStatus(siteID, connected, rtt)
|
||||
}
|
||||
if connected {
|
||||
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
|
||||
} else {
|
||||
logger.Warn("Peer %d is disconnected", siteID)
|
||||
}
|
||||
},
|
||||
fixKey(privateKey.String()),
|
||||
olm,
|
||||
dev,
|
||||
doHolepunch,
|
||||
)
|
||||
|
||||
// loop over the sites and call ConfigurePeer for each one
|
||||
for _, site := range wgData.Sites {
|
||||
if httpServer != nil {
|
||||
httpServer.UpdatePeerStatus(site.SiteId, false, 0)
|
||||
}
|
||||
err = ConfigurePeer(dev, site, privateKey, endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = addRouteForServerIP(site.ServerIP, interfaceName)
|
||||
if err != nil {
|
||||
logger.Error("Failed to add route for peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Add routes for remote subnets
|
||||
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Configured peer %s", site.PublicKey)
|
||||
}
|
||||
|
||||
peerMonitor.Start()
|
||||
|
||||
connected = true
|
||||
|
||||
logger.Info("WireGuard device created.")
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received update-peer message: %v", msg.Data)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var updateData UpdatePeerData
|
||||
if err := json.Unmarshal(jsonData, &updateData); err != nil {
|
||||
logger.Error("Error unmarshaling update data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to SiteConfig
|
||||
siteConfig := SiteConfig{
|
||||
SiteId: updateData.SiteId,
|
||||
Endpoint: updateData.Endpoint,
|
||||
PublicKey: updateData.PublicKey,
|
||||
ServerIP: updateData.ServerIP,
|
||||
ServerPort: updateData.ServerPort,
|
||||
RemoteSubnets: updateData.RemoteSubnets,
|
||||
}
|
||||
|
||||
// Update the peer in WireGuard
|
||||
if dev != nil {
|
||||
// Find the existing peer to get old RemoteSubnets
|
||||
var oldRemoteSubnets string
|
||||
for _, site := range wgData.Sites {
|
||||
if site.SiteId == updateData.SiteId {
|
||||
oldRemoteSubnets = site.RemoteSubnets
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
logger.Error("Failed to update peer: %v", err)
|
||||
// Send error response if needed
|
||||
return
|
||||
}
|
||||
|
||||
// Remove old remote subnet routes if they changed
|
||||
if oldRemoteSubnets != siteConfig.RemoteSubnets {
|
||||
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
|
||||
logger.Error("Failed to remove old remote subnet routes: %v", err)
|
||||
// Continue anyway to add new routes
|
||||
}
|
||||
|
||||
// Add new remote subnet routes
|
||||
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||
logger.Error("Failed to add new remote subnet routes: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update successful
|
||||
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||
// If this is part of a WgData structure, update it
|
||||
for i, site := range wgData.Sites {
|
||||
if site.SiteId == updateData.SiteId {
|
||||
wgData.Sites[i] = siteConfig
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.Error("WireGuard device not initialized")
|
||||
}
|
||||
})
|
||||
|
||||
// Handler for adding a new peer
|
||||
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var addData AddPeerData
|
||||
if err := json.Unmarshal(jsonData, &addData); err != nil {
|
||||
logger.Error("Error unmarshaling add data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to SiteConfig
|
||||
siteConfig := SiteConfig{
|
||||
SiteId: addData.SiteId,
|
||||
Endpoint: addData.Endpoint,
|
||||
PublicKey: addData.PublicKey,
|
||||
ServerIP: addData.ServerIP,
|
||||
ServerPort: addData.ServerPort,
|
||||
RemoteSubnets: addData.RemoteSubnets,
|
||||
}
|
||||
|
||||
// Add the peer to WireGuard
|
||||
if dev != nil {
|
||||
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
logger.Error("Failed to add peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Add route for the new peer
|
||||
err = addRouteForServerIP(siteConfig.ServerIP, interfaceName)
|
||||
if err != nil {
|
||||
logger.Error("Failed to add route for new peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Add routes for remote subnets
|
||||
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Add successful
|
||||
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
||||
|
||||
// Update WgData with the new peer
|
||||
wgData.Sites = append(wgData.Sites, siteConfig)
|
||||
} else {
|
||||
logger.Error("WireGuard device not initialized")
|
||||
}
|
||||
})
|
||||
|
||||
// Handler for removing a peer
|
||||
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received remove-peer message: %v", msg.Data)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var removeData RemovePeerData
|
||||
if err := json.Unmarshal(jsonData, &removeData); err != nil {
|
||||
logger.Error("Error unmarshaling remove data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Find the peer to remove
|
||||
var peerToRemove *SiteConfig
|
||||
var newSites []SiteConfig
|
||||
|
||||
for _, site := range wgData.Sites {
|
||||
if site.SiteId == removeData.SiteId {
|
||||
peerToRemove = &site
|
||||
} else {
|
||||
newSites = append(newSites, site)
|
||||
}
|
||||
}
|
||||
|
||||
if peerToRemove == nil {
|
||||
logger.Error("Peer with site ID %d not found", removeData.SiteId)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove the peer from WireGuard
|
||||
if dev != nil {
|
||||
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||
logger.Error("Failed to remove peer: %v", err)
|
||||
// Send error response if needed
|
||||
return
|
||||
}
|
||||
|
||||
// Remove route for the peer
|
||||
err = removeRouteForServerIP(peerToRemove.ServerIP)
|
||||
if err != nil {
|
||||
logger.Error("Failed to remove route for peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove routes for remote subnets
|
||||
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
|
||||
logger.Error("Failed to remove routes for remote subnets: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove successful
|
||||
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
|
||||
|
||||
// Update WgData to remove the peer
|
||||
wgData.Sites = newSites
|
||||
} else {
|
||||
logger.Error("WireGuard device not initialized")
|
||||
}
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received relay-peer message: %v", msg.Data)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var removeData RelayPeerData
|
||||
if err := json.Unmarshal(jsonData, &removeData); err != nil {
|
||||
logger.Error("Error unmarshaling remove data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
primaryRelay, err := resolveDomain(removeData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||
}
|
||||
|
||||
peerMonitor.HandleFailover(removeData.SiteId, primaryRelay)
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received no-sites message - no sites available for connection")
|
||||
|
||||
// if stopRegister != nil {
|
||||
// stopRegister()
|
||||
// stopRegister = nil
|
||||
// }
|
||||
|
||||
// select {
|
||||
// case <-stopHolepunch:
|
||||
// // Channel already closed, do nothing
|
||||
// default:
|
||||
// close(stopHolepunch)
|
||||
// }
|
||||
|
||||
logger.Info("No sites available - stopped registration and holepunch processes")
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
olm.Close()
|
||||
})
|
||||
|
||||
olm.OnConnect(func() error {
|
||||
logger.Info("Websocket Connected")
|
||||
|
||||
if httpServer != nil {
|
||||
httpServer.SetConnectionStatus(true)
|
||||
}
|
||||
|
||||
if connected {
|
||||
logger.Debug("Already connected, skipping registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
publicKey := privateKey.PublicKey()
|
||||
|
||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
||||
|
||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !doHolepunch,
|
||||
}, 1*time.Second)
|
||||
|
||||
go keepSendingPing(olm)
|
||||
|
||||
logger.Info("Sent registration message")
|
||||
return nil
|
||||
})
|
||||
|
||||
olm.OnTokenUpdate(func(token string) {
|
||||
olmToken = token
|
||||
})
|
||||
|
||||
// Connect to the WebSocket server
|
||||
if err := olm.Connect(); err != nil {
|
||||
logger.Fatal("Failed to connect to server: %v", err)
|
||||
go olm.StartTunnel(tunnelConfig)
|
||||
} else {
|
||||
logger.Info("Incomplete tunnel configuration, not starting tunnel")
|
||||
}
|
||||
defer olm.Close()
|
||||
|
||||
// Wait for interrupt signal or context cancellation
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Wait for either signal or programmatic shutdown
|
||||
select {
|
||||
case <-sigCh:
|
||||
logger.Info("Received interrupt signal")
|
||||
case <-signalCtx.Done():
|
||||
logger.Info("Shutdown signal received, cleaning up...")
|
||||
case <-ctx.Done():
|
||||
logger.Info("Context cancelled")
|
||||
logger.Info("Shutdown requested via API, cleaning up...")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
// Channel already closed, do nothing
|
||||
default:
|
||||
close(stopHolepunch)
|
||||
}
|
||||
|
||||
if stopRegister != nil {
|
||||
stopRegister()
|
||||
stopRegister = nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stopPing:
|
||||
// Channel already closed
|
||||
default:
|
||||
close(stopPing)
|
||||
}
|
||||
|
||||
if uapiListener != nil {
|
||||
uapiListener.Close()
|
||||
}
|
||||
if dev != nil {
|
||||
dev.Close()
|
||||
}
|
||||
|
||||
logger.Info("runOlmMain() exiting")
|
||||
fmt.Printf("runOlmMain() exiting\n")
|
||||
// Clean up resources
|
||||
olm.Close()
|
||||
logger.Info("Shutdown complete")
|
||||
}
|
||||
|
||||
126
namespace.sh
Normal file
126
namespace.sh
Normal file
@@ -0,0 +1,126 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Configuration
|
||||
NS_NAME="isolated_ns" # Name of the namespace
|
||||
VETH_HOST="veth_host" # Interface name on host side
|
||||
VETH_NS="veth_ns" # Interface name inside namespace
|
||||
HOST_IP="192.168.15.1" # Gateway IP for the namespace (host side)
|
||||
NS_IP="192.168.15.2" # IP address for the namespace
|
||||
SUBNET_CIDR="24" # Subnet mask
|
||||
DNS_SERVER="8.8.8.8" # DNS to use inside namespace
|
||||
|
||||
# Detect the main physical interface (gateway to internet)
|
||||
PHY_IFACE=$(ip route get 8.8.8.8 | awk -- '{printf $5}')
|
||||
|
||||
# Helper function to check for root
|
||||
check_root() {
|
||||
if [ "$EUID" -ne 0 ]; then
|
||||
echo "Error: This script must be run as root."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
setup_ns() {
|
||||
echo "Bringing up namespace '$NS_NAME'..."
|
||||
|
||||
# 1. Create the network namespace
|
||||
if ip netns list | grep -q "$NS_NAME"; then
|
||||
echo "Namespace $NS_NAME already exists. Run 'down' first."
|
||||
exit 1
|
||||
fi
|
||||
ip netns add "$NS_NAME"
|
||||
|
||||
# 2. Create veth pair
|
||||
ip link add "$VETH_HOST" type veth peer name "$VETH_NS"
|
||||
|
||||
# 3. Move peer interface to namespace
|
||||
ip link set "$VETH_NS" netns "$NS_NAME"
|
||||
|
||||
# 4. Configure Host Side Interface
|
||||
ip addr add "${HOST_IP}/${SUBNET_CIDR}" dev "$VETH_HOST"
|
||||
ip link set "$VETH_HOST" up
|
||||
|
||||
# 5. Configure Namespace Side Interface
|
||||
ip netns exec "$NS_NAME" ip addr add "${NS_IP}/${SUBNET_CIDR}" dev "$VETH_NS"
|
||||
ip netns exec "$NS_NAME" ip link set "$VETH_NS" up
|
||||
|
||||
# 6. Bring up loopback inside namespace (crucial for many apps)
|
||||
ip netns exec "$NS_NAME" ip link set lo up
|
||||
|
||||
# 7. Routing: Add default gateway inside namespace pointing to host
|
||||
ip netns exec "$NS_NAME" ip route add default via "$HOST_IP"
|
||||
|
||||
# 8. Enable IP forwarding on host
|
||||
echo 1 > /proc/sys/net/ipv4/ip_forward
|
||||
|
||||
# 9. NAT/Masquerade: Allow traffic from namespace to go out physical interface
|
||||
# We verify rule doesn't exist first to avoid duplicates
|
||||
iptables -t nat -C POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null || \
|
||||
iptables -t nat -A POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE
|
||||
|
||||
# Allow forwarding from host veth to WAN and back
|
||||
iptables -C FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null || \
|
||||
iptables -A FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT
|
||||
|
||||
iptables -C FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null || \
|
||||
iptables -A FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT
|
||||
|
||||
# 10. DNS Setup
|
||||
# Netns uses /etc/netns/<name>/resolv.conf if it exists
|
||||
mkdir -p "/etc/netns/$NS_NAME"
|
||||
echo "nameserver $DNS_SERVER" > "/etc/netns/$NS_NAME/resolv.conf"
|
||||
|
||||
echo "Namespace $NS_NAME is UP."
|
||||
echo "To enter shell: sudo ip netns exec $NS_NAME bash"
|
||||
}
|
||||
|
||||
teardown_ns() {
|
||||
echo "Tearing down namespace '$NS_NAME'..."
|
||||
|
||||
# 1. Remove Namespace (this automatically deletes the veth pair inside it)
|
||||
# The host side veth usually disappears when the peer is destroyed.
|
||||
if ip netns list | grep -q "$NS_NAME"; then
|
||||
ip netns del "$NS_NAME"
|
||||
else
|
||||
echo "Namespace $NS_NAME does not exist."
|
||||
fi
|
||||
|
||||
# 2. Clean up veth host side if it still lingers
|
||||
if ip link show "$VETH_HOST" > /dev/null 2>&1; then
|
||||
ip link delete "$VETH_HOST"
|
||||
fi
|
||||
|
||||
# 3. Remove iptables rules
|
||||
# We use -D to delete the specific rules we added
|
||||
iptables -t nat -D POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null
|
||||
iptables -D FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null
|
||||
iptables -D FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null
|
||||
|
||||
# 4. Remove DNS config
|
||||
rm -rf "/etc/netns/$NS_NAME"
|
||||
|
||||
echo "Namespace $NS_NAME is DOWN."
|
||||
}
|
||||
|
||||
test_connectivity() {
|
||||
echo "Testing connectivity inside $NS_NAME..."
|
||||
ip netns exec "$NS_NAME" ping -c 3 8.8.8.8
|
||||
}
|
||||
|
||||
# Main execution logic
|
||||
check_root
|
||||
|
||||
case "$1" in
|
||||
up)
|
||||
setup_ns
|
||||
;;
|
||||
down)
|
||||
teardown_ns
|
||||
;;
|
||||
test)
|
||||
test_connectivity
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {up|down|test}"
|
||||
exit 1
|
||||
esac
|
||||
152
olm.iss
Normal file
152
olm.iss
Normal file
@@ -0,0 +1,152 @@
|
||||
; Script generated by the Inno Setup Script Wizard.
|
||||
; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES!
|
||||
|
||||
#define MyAppName "olm"
|
||||
#define MyAppVersion "1.0.0"
|
||||
#define MyAppPublisher "Fossorial Inc."
|
||||
#define MyAppURL "https://pangolin.net"
|
||||
#define MyAppExeName "olm.exe"
|
||||
|
||||
[Setup]
|
||||
; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications.
|
||||
; (To generate a new GUID, click Tools | Generate GUID inside the IDE.)
|
||||
AppId={{44A24E4C-B616-476F-ADE7-8D56B930959E}
|
||||
AppName={#MyAppName}
|
||||
AppVersion={#MyAppVersion}
|
||||
;AppVerName={#MyAppName} {#MyAppVersion}
|
||||
AppPublisher={#MyAppPublisher}
|
||||
AppPublisherURL={#MyAppURL}
|
||||
AppSupportURL={#MyAppURL}
|
||||
AppUpdatesURL={#MyAppURL}
|
||||
DefaultDirName={autopf}\{#MyAppName}
|
||||
UninstallDisplayIcon={app}\{#MyAppExeName}
|
||||
; "ArchitecturesAllowed=x64compatible" specifies that Setup cannot run
|
||||
; on anything but x64 and Windows 11 on Arm.
|
||||
ArchitecturesAllowed=x64compatible
|
||||
; "ArchitecturesInstallIn64BitMode=x64compatible" requests that the
|
||||
; install be done in "64-bit mode" on x64 or Windows 11 on Arm,
|
||||
; meaning it should use the native 64-bit Program Files directory and
|
||||
; the 64-bit view of the registry.
|
||||
ArchitecturesInstallIn64BitMode=x64compatible
|
||||
DefaultGroupName={#MyAppName}
|
||||
DisableProgramGroupPage=yes
|
||||
; Uncomment the following line to run in non administrative install mode (install for current user only).
|
||||
;PrivilegesRequired=lowest
|
||||
OutputBaseFilename=mysetup
|
||||
SolidCompression=yes
|
||||
WizardStyle=modern
|
||||
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
|
||||
RestartIfNeededByRun=no
|
||||
ChangesEnvironment=true
|
||||
|
||||
[Languages]
|
||||
Name: "english"; MessagesFile: "compiler:Default.isl"
|
||||
|
||||
[Files]
|
||||
; The 'DestName' flag ensures that 'olm_windows_amd64.exe' is installed as 'olm.exe'
|
||||
Source: "Z:\olm_windows_amd64.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}"; Flags: ignoreversion
|
||||
Source: "Z:\wintun.dll"; DestDir: "{app}"; Flags: ignoreversion
|
||||
; NOTE: Don't use "Flags: ignoreversion" on any shared system files
|
||||
|
||||
[Icons]
|
||||
Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"
|
||||
|
||||
[Registry]
|
||||
; Add the application's installation directory to the system PATH environment variable.
|
||||
; HKLM (HKEY_LOCAL_MACHINE) is used for system-wide changes.
|
||||
; The 'Path' variable is located under 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'.
|
||||
; ValueType: expandsz allows for environment variables (like %ProgramFiles%) in the path.
|
||||
; ValueData: "{olddata};{app}" appends the current application directory to the existing PATH.
|
||||
; Note: Removal during uninstallation is handled by CurUninstallStepChanged procedure in [Code] section.
|
||||
; Check: NeedsAddPath ensures this is applied only if the path is not already present.
|
||||
[Registry]
|
||||
; Add the application's installation directory to the system PATH.
|
||||
Root: HKLM; Subkey: "SYSTEM\CurrentControlSet\Control\Session Manager\Environment"; \
|
||||
ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \
|
||||
Check: NeedsAddPath(ExpandConstant('{app}'))
|
||||
|
||||
[Code]
|
||||
function NeedsAddPath(Path: string): boolean;
|
||||
var
|
||||
OrigPath: string;
|
||||
begin
|
||||
if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
|
||||
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||
'Path', OrigPath)
|
||||
then begin
|
||||
// Path variable doesn't exist at all, so we definitely need to add it.
|
||||
Result := True;
|
||||
exit;
|
||||
end;
|
||||
|
||||
// Perform a case-insensitive check to see if the path is already present.
|
||||
// We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
|
||||
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
|
||||
Result := False
|
||||
else
|
||||
Result := True;
|
||||
end;
|
||||
|
||||
procedure RemovePathEntry(PathToRemove: string);
|
||||
var
|
||||
OrigPath: string;
|
||||
NewPath: string;
|
||||
PathList: TStringList;
|
||||
I: Integer;
|
||||
begin
|
||||
if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
|
||||
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||
'Path', OrigPath)
|
||||
then begin
|
||||
// Path variable doesn't exist, nothing to remove
|
||||
exit;
|
||||
end;
|
||||
|
||||
// Create a string list to parse the PATH entries
|
||||
PathList := TStringList.Create;
|
||||
try
|
||||
// Split the PATH by semicolons
|
||||
PathList.Delimiter := ';';
|
||||
PathList.StrictDelimiter := True;
|
||||
PathList.DelimitedText := OrigPath;
|
||||
|
||||
// Find and remove the matching entry (case-insensitive)
|
||||
for I := PathList.Count - 1 downto 0 do
|
||||
begin
|
||||
if CompareText(Trim(PathList[I]), Trim(PathToRemove)) = 0 then
|
||||
begin
|
||||
Log('Found and removing PATH entry: ' + PathList[I]);
|
||||
PathList.Delete(I);
|
||||
end;
|
||||
end;
|
||||
|
||||
// Reconstruct the PATH
|
||||
NewPath := PathList.DelimitedText;
|
||||
|
||||
// Write the new PATH back to the registry
|
||||
if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE,
|
||||
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||
'Path', NewPath)
|
||||
then
|
||||
Log('Successfully removed path entry: ' + PathToRemove)
|
||||
else
|
||||
Log('Failed to write modified PATH to registry');
|
||||
finally
|
||||
PathList.Free;
|
||||
end;
|
||||
end;
|
||||
|
||||
procedure CurUninstallStepChanged(CurUninstallStep: TUninstallStep);
|
||||
var
|
||||
AppPath: string;
|
||||
begin
|
||||
if CurUninstallStep = usUninstall then
|
||||
begin
|
||||
// Get the application installation path
|
||||
AppPath := ExpandConstant('{app}');
|
||||
Log('Removing PATH entry for: ' + AppPath);
|
||||
|
||||
// Remove only our path entry from the system PATH
|
||||
RemovePathEntry(AppPath);
|
||||
end;
|
||||
end;
|
||||
299
olm/connect.go
Normal file
299
olm/connect.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
olmDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/dns"
|
||||
dnsOverride "github.com/fosrl/olm/dns/override"
|
||||
"github.com/fosrl/olm/peers"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// OlmErrorData represents the error data sent from the server
|
||||
type OlmErrorData struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring connect message")
|
||||
return
|
||||
}
|
||||
|
||||
var wgData WgData
|
||||
|
||||
if o.registered {
|
||||
logger.Info("Already connected. Ignoring new connection request.")
|
||||
return
|
||||
}
|
||||
|
||||
if o.stopRegister != nil {
|
||||
o.stopRegister()
|
||||
o.stopRegister = nil
|
||||
}
|
||||
|
||||
if o.updateRegister != nil {
|
||||
o.updateRegister = nil
|
||||
}
|
||||
|
||||
// if there is an existing tunnel then close it
|
||||
if o.dev != nil {
|
||||
logger.Info("Got new message. Closing existing tunnel!")
|
||||
o.dev.Close()
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
o.tdev, err = func() (tun.Device, error) {
|
||||
if o.tunnelConfig.FileDescriptorTun != 0 {
|
||||
return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU)
|
||||
}
|
||||
ifName := o.tunnelConfig.InterfaceName
|
||||
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
|
||||
ifName, err = network.FindUnusedUTUN()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return tun.CreateTUN(ifName, o.tunnelConfig.MTU)
|
||||
}()
|
||||
if err != nil {
|
||||
logger.Error("Failed to create TUN device: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// if config.FileDescriptorTun == 0 {
|
||||
if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
|
||||
o.tunnelConfig.InterfaceName = realInterfaceName
|
||||
}
|
||||
// }
|
||||
|
||||
// Wrap TUN device with packet filter for DNS proxy
|
||||
o.middleDev = olmDevice.NewMiddleDevice(o.tdev)
|
||||
|
||||
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
||||
// Use filtered device instead of raw TUN device
|
||||
o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger))
|
||||
|
||||
if o.tunnelConfig.EnableUAPI {
|
||||
fileUAPI, err := func() (*os.File, error) {
|
||||
if o.tunnelConfig.FileDescriptorUAPI != 0 {
|
||||
fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
|
||||
}
|
||||
return os.NewFile(uintptr(fd), ""), nil
|
||||
}
|
||||
return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName)
|
||||
}()
|
||||
if err != nil {
|
||||
logger.Error("UAPI listen error: %v", err)
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI)
|
||||
if err != nil {
|
||||
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := o.uapiListener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go o.dev.IpcHandle(conn)
|
||||
}
|
||||
}()
|
||||
logger.Info("UAPI listener started")
|
||||
}
|
||||
|
||||
if err = o.dev.Up(); err != nil {
|
||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// Extract interface IP (strip CIDR notation if present)
|
||||
interfaceIP := wgData.TunnelIP
|
||||
if strings.Contains(interfaceIP, "/") {
|
||||
interfaceIP = strings.Split(interfaceIP, "/")[0]
|
||||
}
|
||||
|
||||
// Create and start DNS proxy
|
||||
o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create DNS proxy: %v", err)
|
||||
}
|
||||
|
||||
if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil {
|
||||
logger.Error("Failed to o.tunnelConfigure interface: %v", err)
|
||||
}
|
||||
|
||||
if network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet
|
||||
logger.Error("Failed to add route for utility subnet: %v", err)
|
||||
}
|
||||
|
||||
// Create peer manager with integrated peer monitoring
|
||||
o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
|
||||
Device: o.dev,
|
||||
DNSProxy: o.dnsProxy,
|
||||
InterfaceName: o.tunnelConfig.InterfaceName,
|
||||
PrivateKey: o.privateKey,
|
||||
MiddleDev: o.middleDev,
|
||||
LocalIP: interfaceIP,
|
||||
SharedBind: o.sharedBind,
|
||||
WSClient: o.websocket,
|
||||
APIServer: o.apiServer,
|
||||
})
|
||||
|
||||
for i := range wgData.Sites {
|
||||
site := wgData.Sites[i]
|
||||
var siteEndpoint string
|
||||
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
|
||||
if site.RelayEndpoint != "" {
|
||||
siteEndpoint = site.RelayEndpoint
|
||||
} else {
|
||||
siteEndpoint = site.Endpoint
|
||||
}
|
||||
|
||||
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
|
||||
|
||||
if err := o.peerManager.AddPeer(site); err != nil {
|
||||
logger.Error("Failed to add peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Configured peer %s", site.PublicKey)
|
||||
}
|
||||
|
||||
o.peerManager.Start()
|
||||
|
||||
if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime
|
||||
logger.Error("Failed to start DNS proxy: %v", err)
|
||||
}
|
||||
|
||||
if o.tunnelConfig.OverrideDNS {
|
||||
// Set up DNS override to use our DNS proxy
|
||||
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
|
||||
logger.Error("Failed to setup DNS override: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()})
|
||||
}
|
||||
|
||||
o.apiServer.SetRegistered(true)
|
||||
|
||||
o.registered = true
|
||||
|
||||
// Start ping monitor now that we are registered and connected
|
||||
o.websocket.StartPingMonitor()
|
||||
|
||||
// Invoke onConnected callback if configured
|
||||
if o.olmConfig.OnConnected != nil {
|
||||
go o.olmConfig.OnConnected()
|
||||
}
|
||||
|
||||
logger.Info("WireGuard device created.")
|
||||
}
|
||||
|
||||
func (o *Olm) handleOlmError(msg websocket.WSMessage) {
|
||||
logger.Debug("Received olm error message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring olm error message")
|
||||
return
|
||||
}
|
||||
|
||||
var errorData OlmErrorData
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling olm error data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &errorData); err != nil {
|
||||
logger.Error("Error unmarshaling olm error data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message)
|
||||
|
||||
// Set the olm error in the API server so it can be exposed via status
|
||||
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
|
||||
|
||||
// Invoke onOlmError callback if configured
|
||||
if o.olmConfig.OnOlmError != nil {
|
||||
go o.olmConfig.OnOlmError(errorData.Code, errorData.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring terminate message")
|
||||
return
|
||||
}
|
||||
|
||||
var errorData OlmErrorData
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling terminate error data: %v", err)
|
||||
} else {
|
||||
if err := json.Unmarshal(jsonData, &errorData); err != nil {
|
||||
logger.Error("Error unmarshaling terminate error data: %v", err)
|
||||
} else {
|
||||
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
|
||||
|
||||
if errorData.Code == "TERMINATED_INACTIVITY" {
|
||||
logger.Info("Ignoring...")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the olm error in the API server so it can be exposed via status
|
||||
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
|
||||
}
|
||||
}
|
||||
|
||||
o.apiServer.SetTerminated(true)
|
||||
o.apiServer.SetConnectionStatus(false)
|
||||
o.apiServer.SetRegistered(false)
|
||||
o.apiServer.ClearPeerStatuses()
|
||||
|
||||
network.ClearNetworkSettings()
|
||||
|
||||
o.Close()
|
||||
|
||||
if o.olmConfig.OnTerminated != nil {
|
||||
go o.olmConfig.OnTerminated()
|
||||
}
|
||||
}
|
||||
365
olm/data.go
Normal file
365
olm/data.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/peers"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
)
|
||||
|
||||
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
|
||||
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring add-remote-subnets-aliases message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var addSubnetsData peers.PeerAdd
|
||||
if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil {
|
||||
logger.Error("Error unmarshaling add-remote-subnets data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists {
|
||||
logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId)
|
||||
return
|
||||
}
|
||||
|
||||
// Add new subnets
|
||||
for _, subnet := range addSubnetsData.RemoteSubnets {
|
||||
if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil {
|
||||
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add new aliases
|
||||
for _, alias := range addSubnetsData.Aliases {
|
||||
if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil {
|
||||
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
|
||||
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring remove-remote-subnets-aliases message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var removeSubnetsData peers.RemovePeerData
|
||||
if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil {
|
||||
logger.Error("Error unmarshaling remove-remote-subnets data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists {
|
||||
logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove subnets
|
||||
for _, subnet := range removeSubnetsData.RemoteSubnets {
|
||||
if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil {
|
||||
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove aliases
|
||||
for _, alias := range removeSubnetsData.Aliases {
|
||||
if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil {
|
||||
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
|
||||
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring update-remote-subnets-aliases message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var updateSubnetsData peers.UpdatePeerData
|
||||
if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil {
|
||||
logger.Error("Error unmarshaling update-remote-subnets data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists {
|
||||
logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId)
|
||||
return
|
||||
}
|
||||
|
||||
// Add new subnets BEFORE removing old ones to preserve shared subnets
|
||||
// This ensures that if an old and new subnet are the same on different peers,
|
||||
// the route won't be temporarily removed
|
||||
for _, subnet := range updateSubnetsData.NewRemoteSubnets {
|
||||
if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
|
||||
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old subnets after new ones are added
|
||||
for _, subnet := range updateSubnetsData.OldRemoteSubnets {
|
||||
if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
|
||||
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add new aliases BEFORE removing old ones to preserve shared IP addresses
|
||||
// This ensures that if an old and new alias share the same IP, the IP won't be
|
||||
// temporarily removed from the allowed IPs list
|
||||
for _, alias := range updateSubnetsData.NewAliases {
|
||||
if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil {
|
||||
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old aliases after new ones are added
|
||||
for _, alias := range updateSubnetsData.OldAliases {
|
||||
if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil {
|
||||
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
|
||||
}
|
||||
|
||||
// Handler for syncing peer configuration - reconciles expected state with actual state
|
||||
func (o *Olm) handleSync(msg websocket.WSMessage) {
|
||||
logger.Debug("Received sync message: %v", msg.Data)
|
||||
|
||||
if !o.registered {
|
||||
logger.Warn("Not connected, ignoring sync request")
|
||||
return
|
||||
}
|
||||
|
||||
if o.peerManager == nil {
|
||||
logger.Warn("Peer manager not initialized, ignoring sync request")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling sync data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var syncData SyncData
|
||||
if err := json.Unmarshal(jsonData, &syncData); err != nil {
|
||||
logger.Error("Error unmarshaling sync data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Sync exit nodes for hole punching
|
||||
o.syncExitNodes(syncData.ExitNodes)
|
||||
|
||||
// Build a map of expected peers from the incoming data
|
||||
expectedPeers := make(map[int]peers.SiteConfig)
|
||||
for _, site := range syncData.Sites {
|
||||
expectedPeers[site.SiteId] = site
|
||||
}
|
||||
|
||||
// Get all current peers
|
||||
currentPeers := o.peerManager.GetAllPeers()
|
||||
currentPeerMap := make(map[int]peers.SiteConfig)
|
||||
for _, peer := range currentPeers {
|
||||
currentPeerMap[peer.SiteId] = peer
|
||||
}
|
||||
|
||||
// Find peers to remove (in current but not in expected)
|
||||
for siteId := range currentPeerMap {
|
||||
if _, exists := expectedPeers[siteId]; !exists {
|
||||
logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId)
|
||||
if err := o.peerManager.RemovePeer(siteId); err != nil {
|
||||
logger.Error("Sync: Failed to remove peer %d: %v", siteId, err)
|
||||
} else {
|
||||
// Remove any exit nodes associated with this peer from hole punching
|
||||
if o.holePunchManager != nil {
|
||||
removed := o.holePunchManager.RemoveExitNodesByPeer(siteId)
|
||||
if removed > 0 {
|
||||
logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find peers to add (in expected but not in current) and peers to update
|
||||
for siteId, expectedSite := range expectedPeers {
|
||||
if _, exists := currentPeerMap[siteId]; !exists {
|
||||
// New peer - add it using the add flow (with holepunch)
|
||||
logger.Info("Sync: Adding new peer for site %d", siteId)
|
||||
|
||||
o.holePunchManager.TriggerHolePunch()
|
||||
|
||||
// // TODO: do we need to send the message to the cloud to add the peer that way?
|
||||
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
|
||||
// logger.Error("Sync: Failed to add peer %d: %v", siteId, err)
|
||||
// } else {
|
||||
// logger.Info("Sync: Successfully added peer for site %d", siteId)
|
||||
// }
|
||||
|
||||
// add the peer via the server
|
||||
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
|
||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": expectedSite.SiteId,
|
||||
}, 1*time.Second, 10)
|
||||
|
||||
} else {
|
||||
// Existing peer - check if update is needed
|
||||
currentSite := currentPeerMap[siteId]
|
||||
needsUpdate := false
|
||||
|
||||
// Check if any fields have changed
|
||||
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
|
||||
needsUpdate = true
|
||||
}
|
||||
if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint {
|
||||
needsUpdate = true
|
||||
}
|
||||
if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey {
|
||||
needsUpdate = true
|
||||
}
|
||||
if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP {
|
||||
needsUpdate = true
|
||||
}
|
||||
if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort {
|
||||
needsUpdate = true
|
||||
}
|
||||
// Check remote subnets
|
||||
if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) {
|
||||
needsUpdate = true
|
||||
}
|
||||
// Check aliases
|
||||
if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) {
|
||||
needsUpdate = true
|
||||
}
|
||||
|
||||
if needsUpdate {
|
||||
logger.Info("Sync: Updating peer for site %d", siteId)
|
||||
|
||||
// Merge expected data with current data
|
||||
siteConfig := currentSite
|
||||
if expectedSite.Endpoint != "" {
|
||||
siteConfig.Endpoint = expectedSite.Endpoint
|
||||
}
|
||||
if expectedSite.RelayEndpoint != "" {
|
||||
siteConfig.RelayEndpoint = expectedSite.RelayEndpoint
|
||||
}
|
||||
if expectedSite.PublicKey != "" {
|
||||
siteConfig.PublicKey = expectedSite.PublicKey
|
||||
}
|
||||
if expectedSite.ServerIP != "" {
|
||||
siteConfig.ServerIP = expectedSite.ServerIP
|
||||
}
|
||||
if expectedSite.ServerPort != 0 {
|
||||
siteConfig.ServerPort = expectedSite.ServerPort
|
||||
}
|
||||
if expectedSite.RemoteSubnets != nil {
|
||||
siteConfig.RemoteSubnets = expectedSite.RemoteSubnets
|
||||
}
|
||||
if expectedSite.Aliases != nil {
|
||||
siteConfig.Aliases = expectedSite.Aliases
|
||||
}
|
||||
|
||||
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
|
||||
logger.Error("Sync: Failed to update peer %d: %v", siteId, err)
|
||||
} else {
|
||||
// If the endpoint changed, trigger holepunch to refresh NAT mappings
|
||||
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
|
||||
logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId)
|
||||
o.holePunchManager.TriggerHolePunch()
|
||||
o.holePunchManager.ResetServerHolepunchInterval()
|
||||
}
|
||||
logger.Info("Sync: Successfully updated peer for site %d", siteId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers))
|
||||
}
|
||||
|
||||
// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager
|
||||
func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) {
|
||||
if o.holePunchManager == nil {
|
||||
logger.Warn("Hole punch manager not initialized, skipping exit node sync")
|
||||
return
|
||||
}
|
||||
|
||||
// Build a map of expected exit nodes by endpoint
|
||||
expectedExitNodeMap := make(map[string]SyncExitNode)
|
||||
for _, exitNode := range expectedExitNodes {
|
||||
expectedExitNodeMap[exitNode.Endpoint] = exitNode
|
||||
}
|
||||
|
||||
// Get current exit nodes from hole punch manager
|
||||
currentExitNodes := o.holePunchManager.GetExitNodes()
|
||||
currentExitNodeMap := make(map[string]holepunch.ExitNode)
|
||||
for _, exitNode := range currentExitNodes {
|
||||
currentExitNodeMap[exitNode.Endpoint] = exitNode
|
||||
}
|
||||
|
||||
// Find exit nodes to remove (in current but not in expected)
|
||||
for endpoint := range currentExitNodeMap {
|
||||
if _, exists := expectedExitNodeMap[endpoint]; !exists {
|
||||
logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint)
|
||||
o.holePunchManager.RemoveExitNode(endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
// Find exit nodes to add (in expected but not in current)
|
||||
for endpoint, expectedExitNode := range expectedExitNodeMap {
|
||||
if _, exists := currentExitNodeMap[endpoint]; !exists {
|
||||
logger.Info("Sync: Adding new exit node %s", endpoint)
|
||||
|
||||
relayPort := expectedExitNode.RelayPort
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820 // default relay port
|
||||
}
|
||||
|
||||
hpExitNode := holepunch.ExitNode{
|
||||
Endpoint: expectedExitNode.Endpoint,
|
||||
RelayPort: relayPort,
|
||||
PublicKey: expectedExitNode.PublicKey,
|
||||
SiteIds: expectedExitNode.SiteIds,
|
||||
}
|
||||
|
||||
if o.holePunchManager.AddExitNode(hpExitNode) {
|
||||
logger.Info("Sync: Successfully added exit node %s", endpoint)
|
||||
}
|
||||
o.holePunchManager.TriggerHolePunch()
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap))
|
||||
}
|
||||
938
olm/olm.go
Normal file
938
olm/olm.go
Normal file
@@ -0,0 +1,938 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/clients/permissions"
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/api"
|
||||
olmDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/dns"
|
||||
dnsOverride "github.com/fosrl/olm/dns/override"
|
||||
"github.com/fosrl/olm/peers"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type Olm struct {
|
||||
privateKey wgtypes.Key
|
||||
logFile *os.File
|
||||
|
||||
registered bool
|
||||
tunnelRunning bool
|
||||
|
||||
uapiListener net.Listener
|
||||
dev *device.Device
|
||||
tdev tun.Device
|
||||
middleDev *olmDevice.MiddleDevice
|
||||
sharedBind *bind.SharedBind
|
||||
|
||||
dnsProxy *dns.DNSProxy
|
||||
apiServer *api.API
|
||||
websocket *websocket.Client
|
||||
holePunchManager *holepunch.Manager
|
||||
peerManager *peers.PeerManager
|
||||
// Power mode management
|
||||
currentPowerMode string
|
||||
powerModeMu sync.Mutex
|
||||
wakeUpTimer *time.Timer
|
||||
wakeUpDebounce time.Duration
|
||||
|
||||
olmCtx context.Context
|
||||
tunnelCancel context.CancelFunc
|
||||
|
||||
olmConfig OlmConfig
|
||||
tunnelConfig TunnelConfig
|
||||
|
||||
// Metadata to send alongside pings
|
||||
fingerprint map[string]any
|
||||
postures map[string]any
|
||||
metaMu sync.Mutex
|
||||
|
||||
stopRegister func()
|
||||
updateRegister func(newData any)
|
||||
|
||||
stopPeerSend func()
|
||||
|
||||
// WaitGroup to track tunnel lifecycle
|
||||
tunnelWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// initTunnelInfo creates the shared UDP socket and holepunch manager.
|
||||
// This is used during initial tunnel setup and when switching organizations.
|
||||
func (o *Olm) initTunnelInfo(clientID string) error {
|
||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Error("Failed to generate private key: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
o.privateKey = privateKey
|
||||
|
||||
sourcePort, err := util.FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find available UDP port: %w", err)
|
||||
}
|
||||
|
||||
localAddr := &net.UDPAddr{
|
||||
Port: int(sourcePort),
|
||||
IP: net.IPv4zero,
|
||||
}
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create UDP socket: %w", err)
|
||||
}
|
||||
|
||||
sharedBind, err := bind.New(udpConn)
|
||||
if err != nil {
|
||||
_ = udpConn.Close()
|
||||
return fmt.Errorf("failed to create shared bind: %w", err)
|
||||
}
|
||||
|
||||
o.sharedBind = sharedBind
|
||||
|
||||
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
|
||||
sharedBind.AddRef()
|
||||
|
||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||
|
||||
// Create the holepunch manager
|
||||
o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||
|
||||
// Start pprof server if enabled
|
||||
if config.PprofAddr != "" {
|
||||
go func() {
|
||||
logger.Info("Starting pprof server on %s", config.PprofAddr)
|
||||
if err := http.ListenAndServe(config.PprofAddr, nil); err != nil {
|
||||
logger.Error("Failed to start pprof server: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var logFile *os.File
|
||||
if config.LogFilePath != "" {
|
||||
file, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to open log file: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.SetOutput(file)
|
||||
logFile = file
|
||||
}
|
||||
|
||||
if config.WakeUpDebounce == 0 {
|
||||
config.WakeUpDebounce = 3 * time.Second
|
||||
}
|
||||
|
||||
logger.Debug("Checking permissions for native interface")
|
||||
err := permissions.CheckNativeInterfacePermissions()
|
||||
if err != nil {
|
||||
logger.Fatal("Insufficient permissions to create native TUN interface: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var apiServer *api.API
|
||||
if config.HTTPAddr != "" {
|
||||
apiServer = api.NewAPI(config.HTTPAddr)
|
||||
} else if config.SocketPath != "" {
|
||||
apiServer = api.NewAPISocket(config.SocketPath)
|
||||
} else {
|
||||
// this is so is not null but it cant be started without either the socket path or http addr
|
||||
apiServer = api.NewAPIStub()
|
||||
}
|
||||
|
||||
apiServer.SetVersion(config.Version)
|
||||
apiServer.SetAgent(config.Agent)
|
||||
|
||||
newOlm := &Olm{
|
||||
logFile: logFile,
|
||||
olmCtx: ctx,
|
||||
apiServer: apiServer,
|
||||
olmConfig: config,
|
||||
}
|
||||
|
||||
newOlm.registerAPICallbacks()
|
||||
|
||||
return newOlm, nil
|
||||
}
|
||||
|
||||
func (o *Olm) registerAPICallbacks() {
|
||||
o.apiServer.SetHandlers(
|
||||
// onConnect
|
||||
func(req api.ConnectionRequest) error {
|
||||
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
||||
|
||||
tunnelConfig := TunnelConfig{
|
||||
Endpoint: req.Endpoint,
|
||||
ID: req.ID,
|
||||
Secret: req.Secret,
|
||||
UserToken: req.UserToken,
|
||||
MTU: req.MTU,
|
||||
DNS: req.DNS,
|
||||
UpstreamDNS: req.UpstreamDNS,
|
||||
InterfaceName: req.InterfaceName,
|
||||
Holepunch: req.Holepunch,
|
||||
TlsClientCert: req.TlsClientCert,
|
||||
OrgID: req.OrgID,
|
||||
}
|
||||
|
||||
var err error
|
||||
// Parse ping interval
|
||||
if req.PingInterval != "" {
|
||||
tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval)
|
||||
if err != nil {
|
||||
logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval)
|
||||
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
||||
}
|
||||
} else {
|
||||
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
||||
}
|
||||
// Parse ping timeout
|
||||
if req.PingTimeout != "" {
|
||||
tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout)
|
||||
if err != nil {
|
||||
logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout)
|
||||
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
||||
}
|
||||
} else {
|
||||
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
||||
}
|
||||
if req.MTU == 0 {
|
||||
tunnelConfig.MTU = 1420
|
||||
}
|
||||
if req.DNS == "" {
|
||||
tunnelConfig.DNS = "9.9.9.9"
|
||||
}
|
||||
// DNSProxyIP has no default - it must be provided if DNS proxy is desired
|
||||
// UpstreamDNS defaults to 8.8.8.8 if not provided
|
||||
if len(req.UpstreamDNS) == 0 {
|
||||
tunnelConfig.UpstreamDNS = []string{"8.8.8.8:53"}
|
||||
}
|
||||
if req.InterfaceName == "" {
|
||||
tunnelConfig.InterfaceName = "olm"
|
||||
}
|
||||
|
||||
// Start the tunnel process with the new credentials
|
||||
if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" {
|
||||
logger.Info("Starting tunnel with new credentials")
|
||||
go o.StartTunnel(tunnelConfig)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
// onSwitchOrg
|
||||
func(req api.SwitchOrgRequest) error {
|
||||
logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID)
|
||||
return o.SwitchOrg(req.OrgID)
|
||||
},
|
||||
// onMetadataChange
|
||||
func(req api.MetadataChangeRequest) error {
|
||||
logger.Info("Received change metadata request via API")
|
||||
|
||||
if req.Fingerprint != nil {
|
||||
o.SetFingerprint(req.Fingerprint)
|
||||
}
|
||||
|
||||
if req.Postures != nil {
|
||||
o.SetPostures(req.Postures)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
// onDisconnect
|
||||
func() error {
|
||||
logger.Info("Processing disconnect request via API")
|
||||
return o.StopTunnel()
|
||||
},
|
||||
// onExit
|
||||
func() error {
|
||||
logger.Info("Processing shutdown request via API")
|
||||
o.Close()
|
||||
if o.olmConfig.OnExit != nil {
|
||||
o.olmConfig.OnExit()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
// onRebind
|
||||
func() error {
|
||||
logger.Info("Processing rebind request via API")
|
||||
return o.RebindSocket()
|
||||
},
|
||||
// onPowerMode
|
||||
func(req api.PowerModeRequest) error {
|
||||
logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
|
||||
return o.SetPowerMode(req.Mode)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
if o.tunnelRunning {
|
||||
logger.Info("Tunnel already running")
|
||||
return
|
||||
}
|
||||
|
||||
// debug print out the whole config
|
||||
logger.Debug("Starting tunnel with config: %+v", config)
|
||||
|
||||
o.tunnelRunning = true // Also set it here in case it is called externally
|
||||
o.tunnelConfig = config
|
||||
|
||||
// Reset terminated status when tunnel starts
|
||||
o.apiServer.SetTerminated(false)
|
||||
|
||||
fingerprint := config.InitialFingerprint
|
||||
if fingerprint == nil {
|
||||
fingerprint = make(map[string]any)
|
||||
}
|
||||
|
||||
postures := config.InitialPostures
|
||||
if postures == nil {
|
||||
postures = make(map[string]any)
|
||||
}
|
||||
|
||||
o.SetFingerprint(fingerprint)
|
||||
o.SetPostures(postures)
|
||||
|
||||
// Create a cancellable context for this tunnel process
|
||||
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
|
||||
o.tunnelCancel = cancel
|
||||
|
||||
var (
|
||||
err error
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
userToken = config.UserToken
|
||||
)
|
||||
|
||||
o.tunnelConfig.InterfaceName = config.InterfaceName
|
||||
|
||||
o.apiServer.SetOrgID(config.OrgID)
|
||||
|
||||
// Create a new o.websocket client using the provided credentials
|
||||
o.websocket, err = websocket.NewClient(
|
||||
id,
|
||||
secret,
|
||||
userToken,
|
||||
config.OrgID,
|
||||
config.Endpoint,
|
||||
30*time.Second, // 30 seconds
|
||||
config.PingTimeoutDuration,
|
||||
websocket.WithPingDataProvider(func() map[string]any {
|
||||
o.metaMu.Lock()
|
||||
defer o.metaMu.Unlock()
|
||||
return map[string]any{
|
||||
"fingerprint": o.fingerprint,
|
||||
"postures": o.postures,
|
||||
}
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create olm: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create shared UDP socket and holepunch manager
|
||||
if err := o.initTunnelInfo(id); err != nil {
|
||||
logger.Error("%v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Handlers for managing connection status
|
||||
o.websocket.RegisterHandler("olm/wg/connect", o.handleConnect)
|
||||
o.websocket.RegisterHandler("olm/error", o.handleOlmError)
|
||||
o.websocket.RegisterHandler("olm/terminate", o.handleTerminate)
|
||||
|
||||
// Handlers for managing peers
|
||||
o.websocket.RegisterHandler("olm/wg/peer/add", o.handleWgPeerAdd)
|
||||
o.websocket.RegisterHandler("olm/wg/peer/remove", o.handleWgPeerRemove)
|
||||
o.websocket.RegisterHandler("olm/wg/peer/update", o.handleWgPeerUpdate)
|
||||
o.websocket.RegisterHandler("olm/wg/peer/relay", o.handleWgPeerRelay)
|
||||
o.websocket.RegisterHandler("olm/wg/peer/unrelay", o.handleWgPeerUnrelay)
|
||||
|
||||
// Handlers for managing remote subnets to a peer
|
||||
o.websocket.RegisterHandler("olm/wg/peer/data/add", o.handleWgPeerAddData)
|
||||
o.websocket.RegisterHandler("olm/wg/peer/data/remove", o.handleWgPeerRemoveData)
|
||||
o.websocket.RegisterHandler("olm/wg/peer/data/update", o.handleWgPeerUpdateData)
|
||||
|
||||
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
|
||||
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
|
||||
o.websocket.RegisterHandler("olm/sync", o.handleSync)
|
||||
|
||||
o.websocket.OnConnect(func() error {
|
||||
logger.Info("Websocket Connected")
|
||||
|
||||
o.apiServer.SetConnectionStatus(true)
|
||||
|
||||
if o.registered {
|
||||
o.websocket.StartPingMonitor()
|
||||
|
||||
logger.Debug("Already registered, skipping registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if tunnel is still running before starting registration
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel is no longer running, skipping registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
publicKey := o.privateKey.PublicKey()
|
||||
|
||||
// delay for 500ms to allow for time for the hp to get processed
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check again after sleep in case tunnel was stopped
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped during delay, skipping registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
if o.stopRegister == nil {
|
||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
||||
o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !config.Holepunch,
|
||||
"olmVersion": o.olmConfig.Version,
|
||||
"olmAgent": o.olmConfig.Agent,
|
||||
"orgId": config.OrgID,
|
||||
"userToken": userToken,
|
||||
"fingerprint": o.fingerprint,
|
||||
"postures": o.postures,
|
||||
}, 1*time.Second, 10)
|
||||
|
||||
// Invoke onRegistered callback if configured
|
||||
if o.olmConfig.OnRegistered != nil {
|
||||
go o.olmConfig.OnRegistered()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
o.websocket.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) {
|
||||
// Check if tunnel is still running and hole punch manager exists
|
||||
if !o.tunnelRunning || o.holePunchManager == nil {
|
||||
logger.Debug("Tunnel stopped or hole punch manager nil, ignoring token update")
|
||||
return
|
||||
}
|
||||
|
||||
o.holePunchManager.SetToken(token)
|
||||
|
||||
logger.Debug("Got exit nodes for hole punching: %v", exitNodes)
|
||||
|
||||
// Convert websocket.ExitNode to holepunch.ExitNode
|
||||
hpExitNodes := make([]holepunch.ExitNode, len(exitNodes))
|
||||
for i, node := range exitNodes {
|
||||
relayPort := node.RelayPort
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820 // default relay port
|
||||
}
|
||||
|
||||
hpExitNodes[i] = holepunch.ExitNode{
|
||||
Endpoint: node.Endpoint,
|
||||
RelayPort: relayPort,
|
||||
PublicKey: node.PublicKey,
|
||||
SiteIds: node.SiteIds,
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Updated hole punch exit nodes: %v", hpExitNodes)
|
||||
|
||||
// Start hole punching using the manager
|
||||
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
||||
if err := o.holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil {
|
||||
logger.Warn("Failed to start hole punch: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
o.websocket.OnAuthError(func(statusCode int, message string) {
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring auth error")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message)
|
||||
o.apiServer.SetTerminated(true)
|
||||
o.apiServer.SetConnectionStatus(false)
|
||||
o.apiServer.SetRegistered(false)
|
||||
o.apiServer.ClearOlmError()
|
||||
o.apiServer.ClearPeerStatuses()
|
||||
network.ClearNetworkSettings()
|
||||
|
||||
o.Close()
|
||||
|
||||
if o.olmConfig.OnAuthError != nil {
|
||||
go o.olmConfig.OnAuthError(statusCode, message)
|
||||
}
|
||||
|
||||
if o.olmConfig.OnTerminated != nil {
|
||||
go o.olmConfig.OnTerminated()
|
||||
}
|
||||
})
|
||||
|
||||
// Indicate that tunnel is starting
|
||||
o.tunnelWg.Add(1)
|
||||
defer o.tunnelWg.Done()
|
||||
|
||||
// Connect to the WebSocket server
|
||||
if err := o.websocket.Connect(); err != nil {
|
||||
logger.Error("Failed to connect to server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = o.websocket.Close() }()
|
||||
|
||||
// Wait for context cancellation
|
||||
<-tunnelCtx.Done()
|
||||
logger.Info("Tunnel process context cancelled, cleaning up")
|
||||
}
|
||||
|
||||
func (o *Olm) Close() {
|
||||
// Stop registration first to prevent it from trying to use closed websocket
|
||||
if o.stopRegister != nil {
|
||||
logger.Debug("Stopping registration interval")
|
||||
o.stopRegister()
|
||||
o.stopRegister = nil
|
||||
}
|
||||
|
||||
// send a disconnect message to the cloud to show disconnected
|
||||
if o.websocket != nil {
|
||||
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
||||
// Close the websocket connection after sending disconnect
|
||||
_ = o.websocket.Close()
|
||||
o.websocket = nil
|
||||
}
|
||||
|
||||
// Restore original DNS configuration
|
||||
// we do this first to avoid any DNS issues if something else gets stuck
|
||||
if err := dnsOverride.RestoreDNSOverride(); err != nil {
|
||||
logger.Error("Failed to restore DNS: %v", err)
|
||||
}
|
||||
|
||||
if o.holePunchManager != nil {
|
||||
o.holePunchManager.Stop()
|
||||
o.holePunchManager = nil
|
||||
}
|
||||
|
||||
// Close() also calls Stop() internally
|
||||
if o.peerManager != nil {
|
||||
o.peerManager.Close()
|
||||
o.peerManager = nil
|
||||
}
|
||||
|
||||
if o.uapiListener != nil {
|
||||
_ = o.uapiListener.Close()
|
||||
o.uapiListener = nil
|
||||
}
|
||||
|
||||
if o.logFile != nil {
|
||||
_ = o.logFile.Close()
|
||||
o.logFile = nil
|
||||
}
|
||||
|
||||
// Stop DNS proxy first - it uses the middleDev for packet filtering
|
||||
if o.dnsProxy != nil {
|
||||
logger.Debug("Stopping DNS proxy")
|
||||
o.dnsProxy.Stop()
|
||||
o.dnsProxy = nil
|
||||
}
|
||||
|
||||
// Close MiddleDevice first - this closes the TUN and signals the closed channel
|
||||
// This unblocks the pump goroutine and allows WireGuard's TUN reader to exit
|
||||
// Note: o.tdev is closed by o.middleDev.Close() since middleDev wraps it
|
||||
if o.middleDev != nil {
|
||||
logger.Debug("Closing MiddleDevice")
|
||||
_ = o.middleDev.Close()
|
||||
o.middleDev = nil
|
||||
} else if o.tdev != nil {
|
||||
// If middleDev was never created but tdev exists, close it directly
|
||||
logger.Debug("Closing TUN device directly (no MiddleDevice)")
|
||||
_ = o.tdev.Close()
|
||||
o.tdev = nil
|
||||
} else if o.tunnelConfig.FileDescriptorTun != 0 {
|
||||
// If we never created a device from the FD, close it explicitly
|
||||
// This can happen if tunnel is stopped during registration before handleConnect
|
||||
logger.Debug("Closing unused TUN file descriptor %d", o.tunnelConfig.FileDescriptorTun)
|
||||
if err := syscall.Close(int(o.tunnelConfig.FileDescriptorTun)); err != nil {
|
||||
logger.Error("Failed to close TUN file descriptor: %v", err)
|
||||
} else {
|
||||
logger.Info("Closed unused TUN file descriptor")
|
||||
}
|
||||
o.tunnelConfig.FileDescriptorTun = 0
|
||||
}
|
||||
|
||||
// Now close WireGuard device - its TUN reader should have exited by now
|
||||
// This will call sharedBind.Close() which releases WireGuard's reference
|
||||
if o.dev != nil {
|
||||
logger.Debug("Closing WireGuard device")
|
||||
o.dev.Close()
|
||||
o.dev = nil
|
||||
}
|
||||
|
||||
// Release the hole punch reference to the shared bind (WireGuard already
|
||||
// released its reference via dev.Close())
|
||||
if o.sharedBind != nil {
|
||||
logger.Debug("Releasing shared bind (refcount before release: %d)", o.sharedBind.GetRefCount())
|
||||
_ = o.sharedBind.Release()
|
||||
logger.Info("Released shared UDP bind")
|
||||
o.sharedBind = nil
|
||||
}
|
||||
|
||||
logger.Info("Olm service stopped")
|
||||
}
|
||||
|
||||
// StopTunnel stops just the tunnel process and websocket connection
|
||||
// without shutting down the entire application
|
||||
func (o *Olm) StopTunnel() error {
|
||||
logger.Info("Stopping tunnel process")
|
||||
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel not running, nothing to stop")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset the running state BEFORE cleanup to prevent callbacks from accessing nil pointers
|
||||
o.registered = false
|
||||
o.tunnelRunning = false
|
||||
|
||||
// Cancel the tunnel context if it exists
|
||||
if o.tunnelCancel != nil {
|
||||
logger.Debug("Cancelling tunnel context")
|
||||
o.tunnelCancel()
|
||||
}
|
||||
|
||||
// Wait for the tunnel goroutine to complete
|
||||
logger.Debug("Waiting for tunnel goroutine to finish")
|
||||
o.tunnelWg.Wait()
|
||||
logger.Debug("Tunnel goroutine finished")
|
||||
|
||||
// Close() will handle sending disconnect message and closing websocket
|
||||
o.Close()
|
||||
|
||||
// Update API server status
|
||||
o.apiServer.SetConnectionStatus(false)
|
||||
o.apiServer.SetRegistered(false)
|
||||
o.apiServer.ClearOlmError()
|
||||
|
||||
network.ClearNetworkSettings()
|
||||
o.apiServer.ClearPeerStatuses()
|
||||
|
||||
logger.Info("Tunnel process stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Olm) StopApi() error {
|
||||
if o.apiServer != nil {
|
||||
err := o.apiServer.Stop()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stop API server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Olm) StartApi() error {
|
||||
if o.apiServer != nil {
|
||||
err := o.apiServer.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start API server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Olm) GetStatus() api.StatusResponse {
|
||||
return o.apiServer.GetStatus()
|
||||
}
|
||||
|
||||
func (o *Olm) SwitchOrg(orgID string) error {
|
||||
logger.Info("Processing org switch request to orgId: %s", orgID)
|
||||
// stop the tunnel
|
||||
if err := o.StopTunnel(); err != nil {
|
||||
return fmt.Errorf("failed to stop existing tunnel: %w", err)
|
||||
}
|
||||
|
||||
// Update the org ID in the API server and global config
|
||||
o.apiServer.SetOrgID(orgID)
|
||||
|
||||
o.tunnelConfig.OrgID = orgID
|
||||
|
||||
// Restart the tunnel with the same config but new org ID
|
||||
go o.StartTunnel(o.tunnelConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Olm) SetFingerprint(data map[string]any) {
|
||||
o.metaMu.Lock()
|
||||
defer o.metaMu.Unlock()
|
||||
|
||||
o.fingerprint = data
|
||||
}
|
||||
|
||||
func (o *Olm) SetPostures(data map[string]any) {
|
||||
o.metaMu.Lock()
|
||||
defer o.metaMu.Unlock()
|
||||
|
||||
o.postures = data
|
||||
}
|
||||
|
||||
// SetPowerMode switches between normal and low power modes
|
||||
// In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes
|
||||
// In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored
|
||||
// Wake-up has a 3-second debounce to prevent rapid flip-flopping; sleep is immediate
|
||||
func (o *Olm) SetPowerMode(mode string) error {
|
||||
// Validate mode
|
||||
if mode != "normal" && mode != "low" {
|
||||
return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode)
|
||||
}
|
||||
|
||||
o.powerModeMu.Lock()
|
||||
defer o.powerModeMu.Unlock()
|
||||
|
||||
// If already in the requested mode, return early
|
||||
if o.currentPowerMode == mode {
|
||||
// Cancel any pending wake-up timer if we're already in normal mode
|
||||
if mode == "normal" && o.wakeUpTimer != nil {
|
||||
o.wakeUpTimer.Stop()
|
||||
o.wakeUpTimer = nil
|
||||
}
|
||||
logger.Debug("Already in %s power mode", mode)
|
||||
return nil
|
||||
}
|
||||
|
||||
if mode == "low" {
|
||||
// Low Power Mode: Cancel any pending wake-up and immediately go to sleep
|
||||
|
||||
// Cancel pending wake-up timer if any
|
||||
if o.wakeUpTimer != nil {
|
||||
logger.Debug("Cancelling pending wake-up timer")
|
||||
o.wakeUpTimer.Stop()
|
||||
o.wakeUpTimer = nil
|
||||
}
|
||||
|
||||
logger.Info("Switching to low power mode")
|
||||
|
||||
// Update API server connection status
|
||||
if o.apiServer != nil {
|
||||
o.apiServer.SetConnectionStatus(false)
|
||||
}
|
||||
|
||||
if o.websocket != nil {
|
||||
logger.Info("Disconnecting websocket for low power mode")
|
||||
if err := o.websocket.Disconnect(); err != nil {
|
||||
logger.Error("Error disconnecting websocket: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
lowPowerInterval := 10 * time.Minute
|
||||
|
||||
if o.peerManager != nil {
|
||||
peerMonitor := o.peerManager.GetPeerMonitor()
|
||||
if peerMonitor != nil {
|
||||
peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval)
|
||||
peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval)
|
||||
logger.Info("Set monitoring intervals to 10 minutes for low power mode")
|
||||
}
|
||||
o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable
|
||||
}
|
||||
|
||||
if o.holePunchManager != nil {
|
||||
o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval)
|
||||
}
|
||||
|
||||
o.currentPowerMode = "low"
|
||||
logger.Info("Switched to low power mode")
|
||||
|
||||
} else {
|
||||
// Normal Power Mode: Start debounce timer before actually waking up
|
||||
|
||||
// If there's already a pending wake-up timer, don't start another
|
||||
if o.wakeUpTimer != nil {
|
||||
logger.Debug("Wake-up already pending, ignoring duplicate request")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce)
|
||||
|
||||
o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() {
|
||||
o.powerModeMu.Lock()
|
||||
defer o.powerModeMu.Unlock()
|
||||
|
||||
// Clear the timer reference
|
||||
o.wakeUpTimer = nil
|
||||
|
||||
// Double-check we're still in low power mode (could have changed)
|
||||
if o.currentPowerMode == "normal" {
|
||||
logger.Debug("Already in normal mode after debounce, skipping wake-up")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Debounce complete, switching to normal power mode")
|
||||
|
||||
logger.Info("Reconnecting websocket for normal power mode")
|
||||
if o.websocket != nil {
|
||||
if err := o.websocket.Connect(); err != nil {
|
||||
logger.Error("Failed to reconnect websocket: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Restore intervals and reconnect websocket
|
||||
if o.peerManager != nil {
|
||||
peerMonitor := o.peerManager.GetPeerMonitor()
|
||||
if peerMonitor != nil {
|
||||
peerMonitor.ResetPeerHolepunchInterval()
|
||||
peerMonitor.ResetPeerInterval()
|
||||
}
|
||||
|
||||
o.peerManager.UpdateAllPeersPersistentKeepalive(5)
|
||||
}
|
||||
|
||||
if o.holePunchManager != nil {
|
||||
o.holePunchManager.ResetServerHolepunchInterval()
|
||||
}
|
||||
|
||||
o.currentPowerMode = "normal"
|
||||
logger.Info("Switched to normal power mode")
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RebindSocket recreates the UDP socket when network connectivity changes.
|
||||
// This is necessary on macOS/iOS when transitioning between WiFi and cellular,
|
||||
// as the old socket becomes stale and can no longer route packets.
|
||||
// Call this method when detecting a network path change.
|
||||
func (o *Olm) RebindSocket() error {
|
||||
if o.sharedBind == nil {
|
||||
return fmt.Errorf("shared bind is not initialized")
|
||||
}
|
||||
|
||||
// Close the old socket first to release the port, then try to rebind to the same port
|
||||
currentPort, err := o.sharedBind.CloseSocket()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close old socket: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Rebinding UDP socket (released port: %d)", currentPort)
|
||||
|
||||
// Create a new UDP socket
|
||||
var newConn *net.UDPConn
|
||||
var newPort uint16
|
||||
|
||||
// First try to bind to the same port (now available since we closed the old socket)
|
||||
localAddr := &net.UDPAddr{
|
||||
Port: int(currentPort),
|
||||
IP: net.IPv4zero,
|
||||
}
|
||||
|
||||
newConn, err = net.ListenUDP("udp4", localAddr)
|
||||
if err != nil {
|
||||
// If we can't reuse the port, find a new one
|
||||
logger.Warn("Could not rebind to port %d, finding new port: %v", currentPort, err)
|
||||
newPort, err = util.FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find available UDP port: %w", err)
|
||||
}
|
||||
|
||||
localAddr = &net.UDPAddr{
|
||||
Port: int(newPort),
|
||||
IP: net.IPv4zero,
|
||||
}
|
||||
|
||||
// Use udp4 explicitly to avoid IPv6 dual-stack issues
|
||||
newConn, err = net.ListenUDP("udp4", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new UDP socket: %w", err)
|
||||
}
|
||||
} else {
|
||||
newPort = currentPort
|
||||
}
|
||||
|
||||
// Rebind the shared bind with the new connection
|
||||
if err := o.sharedBind.Rebind(newConn); err != nil {
|
||||
newConn.Close()
|
||||
return fmt.Errorf("failed to rebind shared bind: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Successfully rebound UDP socket on port %d", newPort)
|
||||
|
||||
// Check if we're in low power mode before triggering hole punch
|
||||
o.powerModeMu.Lock()
|
||||
isLowPower := o.currentPowerMode == "low"
|
||||
o.powerModeMu.Unlock()
|
||||
|
||||
// Only trigger hole punch if not in low power mode
|
||||
if !isLowPower && o.holePunchManager != nil {
|
||||
o.holePunchManager.TriggerHolePunch()
|
||||
o.holePunchManager.ResetServerHolepunchInterval()
|
||||
logger.Info("Triggered hole punch after socket rebind")
|
||||
} else if isLowPower {
|
||||
logger.Info("Skipping hole punch trigger due to low power mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Olm) AddDevice(fd uint32) error {
|
||||
if o.middleDev == nil {
|
||||
return fmt.Errorf("middle device is not initialized")
|
||||
}
|
||||
|
||||
if o.tunnelConfig.MTU == 0 {
|
||||
return fmt.Errorf("tunnel MTU is not set")
|
||||
}
|
||||
|
||||
tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TUN device from fd: %v", err)
|
||||
}
|
||||
|
||||
// Update interface name if available
|
||||
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||
o.tunnelConfig.InterfaceName = realInterfaceName
|
||||
}
|
||||
|
||||
// Replace the existing TUN device in the middle device with the new one
|
||||
o.middleDev.AddDevice(tdev)
|
||||
|
||||
logger.Info("Added device from file descriptor %d", fd)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetNetworkSettingsJSON() (string, error) {
|
||||
return network.GetJSON()
|
||||
}
|
||||
|
||||
func GetNetworkSettingsIncrementor() int {
|
||||
return network.GetIncrementor()
|
||||
}
|
||||
282
olm/peer.go
Normal file
282
olm/peer.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/peers"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
)
|
||||
|
||||
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
||||
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring add-peer message")
|
||||
return
|
||||
}
|
||||
|
||||
if o.stopPeerSend != nil {
|
||||
o.stopPeerSend()
|
||||
o.stopPeerSend = nil
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var siteConfig peers.SiteConfig
|
||||
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
|
||||
logger.Error("Error unmarshaling add data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||
|
||||
if err := o.peerManager.AddPeer(siteConfig); err != nil {
|
||||
logger.Error("Failed to add peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
|
||||
}
|
||||
|
||||
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||
logger.Debug("Received remove-peer message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring remove-peer message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var removeData peers.PeerRemove
|
||||
if err := json.Unmarshal(jsonData, &removeData); err != nil {
|
||||
logger.Error("Error unmarshaling remove data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil {
|
||||
logger.Error("Failed to remove peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove any exit nodes associated with this peer from hole punching
|
||||
if o.holePunchManager != nil {
|
||||
removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId)
|
||||
if removed > 0 {
|
||||
logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
|
||||
}
|
||||
|
||||
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
|
||||
logger.Debug("Received update-peer message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring update-peer message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var updateData peers.SiteConfig
|
||||
if err := json.Unmarshal(jsonData, &updateData); err != nil {
|
||||
logger.Error("Error unmarshaling update data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get existing peer from PeerManager
|
||||
existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId)
|
||||
if !exists {
|
||||
logger.Warn("Peer with site ID %d not found", updateData.SiteId)
|
||||
return
|
||||
}
|
||||
|
||||
// Create updated site config by merging with existing data
|
||||
siteConfig := existingPeer
|
||||
|
||||
if updateData.Endpoint != "" {
|
||||
siteConfig.Endpoint = updateData.Endpoint
|
||||
}
|
||||
if updateData.RelayEndpoint != "" {
|
||||
siteConfig.RelayEndpoint = updateData.RelayEndpoint
|
||||
}
|
||||
if updateData.PublicKey != "" {
|
||||
siteConfig.PublicKey = updateData.PublicKey
|
||||
}
|
||||
if updateData.ServerIP != "" {
|
||||
siteConfig.ServerIP = updateData.ServerIP
|
||||
}
|
||||
if updateData.ServerPort != 0 {
|
||||
siteConfig.ServerPort = updateData.ServerPort
|
||||
}
|
||||
if updateData.RemoteSubnets != nil {
|
||||
siteConfig.RemoteSubnets = updateData.RemoteSubnets
|
||||
}
|
||||
|
||||
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
|
||||
logger.Error("Failed to update peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the endpoint changed, trigger holepunch to refresh NAT mappings
|
||||
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
|
||||
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
|
||||
_ = o.holePunchManager.TriggerHolePunch()
|
||||
o.holePunchManager.ResetServerHolepunchInterval()
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||
}
|
||||
|
||||
func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
|
||||
logger.Debug("Received relay-peer message: %v", msg.Data)
|
||||
|
||||
// Check if peerManager is still valid (may be nil during shutdown)
|
||||
if o.peerManager == nil {
|
||||
logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var relayData peers.RelayPeerData
|
||||
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||
logger.Error("Error unmarshaling relay data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve primary relay endpoint: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update HTTP server to mark this peer as using relay
|
||||
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
|
||||
|
||||
o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
|
||||
}
|
||||
|
||||
func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
||||
logger.Debug("Received unrelay-peer message: %v", msg.Data)
|
||||
|
||||
// Check if peerManager is still valid (may be nil during shutdown)
|
||||
if o.peerManager == nil {
|
||||
logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var relayData peers.UnRelayPeerData
|
||||
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||
logger.Error("Error unmarshaling relay data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||
}
|
||||
|
||||
// Update HTTP server to mark this peer as using relay
|
||||
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
|
||||
|
||||
o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
|
||||
}
|
||||
|
||||
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring peer-handshake message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling handshake data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var handshakeData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
ExitNode struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
RelayPort uint16 `json:"relayPort"`
|
||||
} `json:"exitNode"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
|
||||
logger.Error("Error unmarshaling handshake data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get existing peer from PeerManager
|
||||
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
||||
if exists {
|
||||
logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
|
||||
return
|
||||
}
|
||||
|
||||
relayPort := handshakeData.ExitNode.RelayPort
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820 // default relay port
|
||||
}
|
||||
|
||||
siteId := handshakeData.SiteId
|
||||
exitNode := holepunch.ExitNode{
|
||||
Endpoint: handshakeData.ExitNode.Endpoint,
|
||||
RelayPort: relayPort,
|
||||
PublicKey: handshakeData.ExitNode.PublicKey,
|
||||
SiteIds: []int{siteId},
|
||||
}
|
||||
|
||||
added := o.holePunchManager.AddExitNode(exitNode)
|
||||
if added {
|
||||
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
|
||||
} else {
|
||||
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
|
||||
}
|
||||
|
||||
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||
|
||||
// Send handshake acknowledgment back to server with retry
|
||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": handshakeData.SiteId,
|
||||
}, 1*time.Second, 10)
|
||||
|
||||
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||
}
|
||||
89
olm/types.go
Normal file
89
olm/types.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/olm/peers"
|
||||
)
|
||||
|
||||
type WgData struct {
|
||||
Sites []peers.SiteConfig `json:"sites"`
|
||||
TunnelIP string `json:"tunnelIP"`
|
||||
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
|
||||
}
|
||||
|
||||
type SyncData struct {
|
||||
Sites []peers.SiteConfig `json:"sites"`
|
||||
ExitNodes []SyncExitNode `json:"exitNodes"`
|
||||
}
|
||||
|
||||
type SyncExitNode struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
RelayPort uint16 `json:"relayPort"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
SiteIds []int `json:"siteIds"`
|
||||
}
|
||||
|
||||
type OlmConfig struct {
|
||||
// Logging
|
||||
LogLevel string
|
||||
LogFilePath string
|
||||
|
||||
// HTTP server
|
||||
EnableAPI bool
|
||||
HTTPAddr string
|
||||
SocketPath string
|
||||
Version string
|
||||
Agent string
|
||||
|
||||
WakeUpDebounce time.Duration
|
||||
|
||||
// Debugging
|
||||
PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")
|
||||
|
||||
// Callbacks
|
||||
OnRegistered func()
|
||||
OnConnected func()
|
||||
OnTerminated func()
|
||||
OnAuthError func(statusCode int, message string) // Called when auth fails (401/403)
|
||||
OnOlmError func(code string, message string) // Called when registration fails
|
||||
OnExit func() // Called when exit is requested via API
|
||||
}
|
||||
|
||||
type TunnelConfig struct {
|
||||
// Connection settings
|
||||
Endpoint string
|
||||
ID string
|
||||
Secret string
|
||||
UserToken string
|
||||
|
||||
// Network settings
|
||||
MTU int
|
||||
DNS string
|
||||
UpstreamDNS []string
|
||||
InterfaceName string
|
||||
|
||||
// Advanced
|
||||
Holepunch bool
|
||||
TlsClientCert string
|
||||
|
||||
// Parsed values (not in JSON)
|
||||
PingIntervalDuration time.Duration
|
||||
PingTimeoutDuration time.Duration
|
||||
|
||||
OrgID string
|
||||
// DoNotCreateNewClient bool
|
||||
|
||||
FileDescriptorTun uint32
|
||||
FileDescriptorUAPI uint32
|
||||
|
||||
EnableUAPI bool
|
||||
|
||||
OverrideDNS bool
|
||||
TunnelDNS bool
|
||||
|
||||
InitialFingerprint map[string]any
|
||||
InitialPostures map[string]any
|
||||
|
||||
DisableRelay bool
|
||||
}
|
||||
47
olm/util.go
Normal file
47
olm/util.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"github.com/fosrl/olm/peers"
|
||||
)
|
||||
|
||||
// slicesEqual compares two string slices for equality (order-independent)
|
||||
func slicesEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
// Create a map to count occurrences in slice a
|
||||
counts := make(map[string]int)
|
||||
for _, v := range a {
|
||||
counts[v]++
|
||||
}
|
||||
// Check if slice b has the same elements
|
||||
for _, v := range b {
|
||||
counts[v]--
|
||||
if counts[v] < 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// aliasesEqual compares two Alias slices for equality (order-independent)
|
||||
func aliasesEqual(a, b []peers.Alias) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
// Create a map to count occurrences in slice a (using alias+address as key)
|
||||
counts := make(map[string]int)
|
||||
for _, v := range a {
|
||||
key := v.Alias + "|" + v.AliasAddress
|
||||
counts[key]++
|
||||
}
|
||||
// Check if slice b has the same elements
|
||||
for _, v := range b {
|
||||
key := v.Alias + "|" + v.AliasAddress
|
||||
counts[key]--
|
||||
if counts[key] < 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,324 +0,0 @@
|
||||
package peermonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/olm/wgtester"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
// PeerMonitorCallback is the function type for connection status change callbacks
|
||||
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
|
||||
|
||||
// WireGuardConfig holds the WireGuard configuration for a peer
|
||||
type WireGuardConfig struct {
|
||||
SiteID int
|
||||
PublicKey string
|
||||
ServerIP string
|
||||
Endpoint string
|
||||
PrimaryRelay string // The primary relay endpoint
|
||||
}
|
||||
|
||||
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
|
||||
type PeerMonitor struct {
|
||||
monitors map[int]*wgtester.Client
|
||||
configs map[int]*WireGuardConfig
|
||||
callback PeerMonitorCallback
|
||||
mutex sync.Mutex
|
||||
running bool
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
privateKey string
|
||||
wsClient *websocket.Client
|
||||
device *device.Device
|
||||
handleRelaySwitch bool // Whether to handle relay switching
|
||||
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
|
||||
return &PeerMonitor{
|
||||
monitors: make(map[int]*wgtester.Client),
|
||||
configs: make(map[int]*WireGuardConfig),
|
||||
callback: callback,
|
||||
interval: 1 * time.Second, // Default check interval
|
||||
timeout: 2500 * time.Millisecond,
|
||||
maxAttempts: 8,
|
||||
privateKey: privateKey,
|
||||
wsClient: wsClient,
|
||||
device: device,
|
||||
handleRelaySwitch: handleRelaySwitch,
|
||||
}
|
||||
}
|
||||
|
||||
// SetInterval changes how frequently peers are checked
|
||||
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.interval = interval
|
||||
|
||||
// Update interval for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetPacketInterval(interval)
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeout changes the timeout for waiting for responses
|
||||
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.timeout = timeout
|
||||
|
||||
// Update timeout for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.maxAttempts = attempts
|
||||
|
||||
// Update max attempts for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetMaxAttempts(attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a new peer to monitor
|
||||
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// Check if we're already monitoring this peer
|
||||
if _, exists := pm.monitors[siteID]; exists {
|
||||
// Update the endpoint instead of creating a new monitor
|
||||
pm.removePeerUnlocked(siteID)
|
||||
}
|
||||
|
||||
client, err := wgtester.NewClient(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the client with our settings
|
||||
client.SetPacketInterval(pm.interval)
|
||||
client.SetTimeout(pm.timeout)
|
||||
client.SetMaxAttempts(pm.maxAttempts)
|
||||
|
||||
// Store the client and config
|
||||
pm.monitors[siteID] = client
|
||||
pm.configs[siteID] = wgConfig
|
||||
|
||||
// If monitor is already running, start monitoring this peer
|
||||
if pm.running {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||
})
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
||||
// This function assumes the mutex is already held by the caller
|
||||
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
|
||||
client, exists := pm.monitors[siteID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
client.StopMonitor()
|
||||
client.Close()
|
||||
delete(pm.monitors, siteID)
|
||||
delete(pm.configs, siteID)
|
||||
}
|
||||
|
||||
// RemovePeer stops monitoring a peer and removes it from the monitor
|
||||
func (pm *PeerMonitor) RemovePeer(siteID int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.removePeerUnlocked(siteID)
|
||||
}
|
||||
|
||||
// Start begins monitoring all peers
|
||||
func (pm *PeerMonitor) Start() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if pm.running {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
pm.running = true
|
||||
|
||||
// Start monitoring all peers
|
||||
for siteID, client := range pm.monitors {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
|
||||
continue
|
||||
}
|
||||
logger.Info("Started monitoring peer %d\n", siteID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnectionStatusChange is called when a peer's connection status changes
|
||||
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) {
|
||||
// Call the user-provided callback first
|
||||
if pm.callback != nil {
|
||||
pm.callback(siteID, status.Connected, status.RTT)
|
||||
}
|
||||
|
||||
// If disconnected, handle failover
|
||||
if !status.Connected {
|
||||
// Send relay message to the server
|
||||
if pm.wsClient != nil {
|
||||
pm.sendRelay(siteID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleFailover handles failover to the relay server when a peer is disconnected
|
||||
func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
|
||||
pm.mutex.Lock()
|
||||
config, exists := pm.configs[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Configure WireGuard to use the relay
|
||||
wgConfig := fmt.Sprintf(`private_key=%s
|
||||
public_key=%s
|
||||
allowed_ip=%s/32
|
||||
endpoint=%s:21820
|
||||
persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, relayEndpoint)
|
||||
|
||||
err := pm.device.IpcSet(wgConfig)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure WireGuard device: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Adjusted peer %d to point to relay!\n", siteID)
|
||||
}
|
||||
|
||||
// sendRelay sends a relay message to the server
|
||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||
if !pm.handleRelaySwitch {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent relay message")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops monitoring all peers
|
||||
func (pm *PeerMonitor) Stop() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
pm.running = false
|
||||
|
||||
// Stop all monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.StopMonitor()
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops monitoring and cleans up resources
|
||||
func (pm *PeerMonitor) Close() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// Stop and close all clients
|
||||
for siteID, client := range pm.monitors {
|
||||
client.StopMonitor()
|
||||
client.Close()
|
||||
delete(pm.monitors, siteID)
|
||||
}
|
||||
|
||||
pm.running = false
|
||||
}
|
||||
|
||||
// TestPeer tests connectivity to a specific peer
|
||||
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||
pm.mutex.Lock()
|
||||
client, exists := pm.monitors[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
defer cancel()
|
||||
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
return connected, rtt, nil
|
||||
}
|
||||
|
||||
// TestAllPeers tests connectivity to all peers
|
||||
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
} {
|
||||
pm.mutex.Lock()
|
||||
peers := make(map[int]*wgtester.Client, len(pm.monitors))
|
||||
for siteID, client := range pm.monitors {
|
||||
peers[siteID] = client
|
||||
}
|
||||
pm.mutex.Unlock()
|
||||
|
||||
results := make(map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
})
|
||||
for siteID, client := range peers {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
cancel()
|
||||
|
||||
results[siteID] = struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
}{
|
||||
Connected: connected,
|
||||
RTT: rtt,
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
920
peers/manager.go
Normal file
920
peers/manager.go
Normal file
@@ -0,0 +1,920 @@
|
||||
package peers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/api"
|
||||
olmDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/dns"
|
||||
"github.com/fosrl/olm/peers/monitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// PeerManagerConfig contains the configuration for creating a PeerManager
|
||||
type PeerManagerConfig struct {
|
||||
Device *device.Device
|
||||
DNSProxy *dns.DNSProxy
|
||||
InterfaceName string
|
||||
PrivateKey wgtypes.Key
|
||||
// For peer monitoring
|
||||
MiddleDev *olmDevice.MiddleDevice
|
||||
LocalIP string
|
||||
SharedBind *bind.SharedBind
|
||||
// WSClient is optional - if nil, relay messages won't be sent
|
||||
WSClient *websocket.Client
|
||||
APIServer *api.API
|
||||
}
|
||||
|
||||
type PeerManager struct {
|
||||
mu sync.RWMutex
|
||||
device *device.Device
|
||||
peers map[int]SiteConfig
|
||||
peerMonitor *monitor.PeerMonitor
|
||||
dnsProxy *dns.DNSProxy
|
||||
interfaceName string
|
||||
privateKey wgtypes.Key
|
||||
// allowedIPOwners tracks which peer currently "owns" each allowed IP in WireGuard
|
||||
// key is the CIDR string, value is the siteId that has it configured in WG
|
||||
allowedIPOwners map[string]int
|
||||
// allowedIPClaims tracks all peers that claim each allowed IP
|
||||
// key is the CIDR string, value is a set of siteIds that want this IP
|
||||
allowedIPClaims map[string]map[int]bool
|
||||
APIServer *api.API
|
||||
|
||||
PersistentKeepalive int
|
||||
}
|
||||
|
||||
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
|
||||
func NewPeerManager(config PeerManagerConfig) *PeerManager {
|
||||
pm := &PeerManager{
|
||||
device: config.Device,
|
||||
peers: make(map[int]SiteConfig),
|
||||
dnsProxy: config.DNSProxy,
|
||||
interfaceName: config.InterfaceName,
|
||||
privateKey: config.PrivateKey,
|
||||
allowedIPOwners: make(map[string]int),
|
||||
allowedIPClaims: make(map[string]map[int]bool),
|
||||
APIServer: config.APIServer,
|
||||
}
|
||||
|
||||
// Create the peer monitor
|
||||
pm.peerMonitor = monitor.NewPeerMonitor(
|
||||
config.WSClient,
|
||||
config.MiddleDev,
|
||||
config.LocalIP,
|
||||
config.SharedBind,
|
||||
config.APIServer,
|
||||
)
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
peer, ok := pm.peers[siteId]
|
||||
return peer, ok
|
||||
}
|
||||
|
||||
// GetPeerMonitor returns the internal peer monitor instance
|
||||
func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
return pm.peerMonitor
|
||||
}
|
||||
|
||||
func (pm *PeerManager) GetAllPeers() []SiteConfig {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
peers := make([]SiteConfig, 0, len(pm.peers))
|
||||
for _, peer := range pm.peers {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
// build the allowed IPs list from the remote subnets and aliases and add them to the peer
|
||||
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
|
||||
allowedIPs = append(allowedIPs, siteConfig.RemoteSubnets...)
|
||||
for _, alias := range siteConfig.Aliases {
|
||||
allowedIPs = append(allowedIPs, alias.AliasAddress+"/32")
|
||||
}
|
||||
siteConfig.AllowedIps = allowedIPs
|
||||
|
||||
// Register claims for all allowed IPs and determine which ones this peer will own
|
||||
ownedIPs := make([]string, 0, len(allowedIPs))
|
||||
for _, ip := range allowedIPs {
|
||||
pm.claimAllowedIP(siteConfig.SiteId, ip)
|
||||
// Check if this peer became the owner
|
||||
if pm.allowedIPOwners[ip] == siteConfig.SiteId {
|
||||
ownedIPs = append(ownedIPs, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a config with only the owned IPs for WireGuard
|
||||
wgConfig := siteConfig
|
||||
wgConfig.AllowedIps = ownedIPs
|
||||
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := network.AddRouteForServerIP(siteConfig.ServerIP, pm.interfaceName); err != nil {
|
||||
logger.Error("Failed to add route for server IP: %v", err)
|
||||
}
|
||||
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
|
||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||
}
|
||||
for _, alias := range siteConfig.Aliases {
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address == nil {
|
||||
continue
|
||||
}
|
||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
||||
}
|
||||
|
||||
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||
|
||||
err := pm.peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, siteConfig.Endpoint) // always use the real site endpoint for hole punch monitoring
|
||||
if err != nil {
|
||||
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
|
||||
} else {
|
||||
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||
}
|
||||
|
||||
pm.peers[siteConfig.SiteId] = siteConfig
|
||||
|
||||
pm.APIServer.AddPeerStatus(siteConfig.SiteId, siteConfig.Name, false, 0, siteConfig.Endpoint, false)
|
||||
|
||||
// Perform rapid initial holepunch test (outside of lock to avoid blocking)
|
||||
// This quickly determines if holepunch is viable and triggers relay if not
|
||||
go pm.performRapidInitialTest(siteConfig.SiteId, siteConfig.Endpoint)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once
|
||||
// without recreating them. Returns a map of siteId to error for any peers that failed to update.
|
||||
func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
|
||||
pm.PersistentKeepalive = interval
|
||||
|
||||
errors := make(map[int]error)
|
||||
|
||||
for siteId, peer := range pm.peers {
|
||||
err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
|
||||
if err != nil {
|
||||
errors[siteId] = err
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) == 0 {
|
||||
return nil
|
||||
}
|
||||
return errors
|
||||
}
|
||||
|
||||
func (pm *PeerManager) RemovePeer(siteId int) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
if err := RemovePeer(pm.device, siteId, peer.PublicKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := network.RemoveRouteForServerIP(peer.ServerIP, pm.interfaceName); err != nil {
|
||||
logger.Error("Failed to remove route for server IP: %v", err)
|
||||
}
|
||||
|
||||
// Only remove routes for subnets that aren't used by other peers
|
||||
for _, subnet := range peer.RemoteSubnets {
|
||||
subnetStillInUse := false
|
||||
for otherSiteId, otherPeer := range pm.peers {
|
||||
if otherSiteId == siteId {
|
||||
continue // Skip the peer being removed
|
||||
}
|
||||
for _, otherSubnet := range otherPeer.RemoteSubnets {
|
||||
if otherSubnet == subnet {
|
||||
subnetStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if subnetStillInUse {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !subnetStillInUse {
|
||||
if err := network.RemoveRoutes([]string{subnet}); err != nil {
|
||||
logger.Error("Failed to remove route for remote subnet %s: %v", subnet, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For aliases
|
||||
for _, alias := range peer.Aliases {
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address == nil {
|
||||
continue
|
||||
}
|
||||
pm.dnsProxy.RemoveDNSRecord(alias.Alias, address)
|
||||
}
|
||||
|
||||
// Release all IP claims and promote other peers as needed
|
||||
// Collect promotions first to avoid modifying while iterating
|
||||
type promotion struct {
|
||||
newOwner int
|
||||
cidr string
|
||||
}
|
||||
var promotions []promotion
|
||||
|
||||
for _, ip := range peer.AllowedIps {
|
||||
newOwner, promoted := pm.releaseAllowedIP(siteId, ip)
|
||||
if promoted && newOwner >= 0 {
|
||||
promotions = append(promotions, promotion{newOwner: newOwner, cidr: ip})
|
||||
}
|
||||
}
|
||||
|
||||
// Apply promotions - update WireGuard config for newly promoted peers
|
||||
// Group by peer to avoid multiple config updates
|
||||
promotedPeers := make(map[int]bool)
|
||||
for _, p := range promotions {
|
||||
promotedPeers[p.newOwner] = true
|
||||
logger.Info("Promoted peer %d to owner of IP %s", p.newOwner, p.cidr)
|
||||
}
|
||||
|
||||
for promotedPeerId := range promotedPeers {
|
||||
if promotedPeer, exists := pm.peers[promotedPeerId]; exists {
|
||||
// Build the list of IPs this peer now owns
|
||||
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||
wgConfig := promotedPeer
|
||||
wgConfig.AllowedIps = ownedIPs
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
|
||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop monitoring this peer
|
||||
pm.peerMonitor.RemovePeer(siteId)
|
||||
logger.Info("Stopped monitoring for site %d", siteId)
|
||||
|
||||
pm.APIServer.RemovePeerStatus(siteId)
|
||||
|
||||
delete(pm.peers, siteId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
oldPeer, exists := pm.peers[siteConfig.SiteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId)
|
||||
}
|
||||
|
||||
// If public key changed, remove old peer first
|
||||
if siteConfig.PublicKey != oldPeer.PublicKey {
|
||||
if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil {
|
||||
logger.Error("Failed to remove old peer: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build the new allowed IPs list
|
||||
newAllowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
|
||||
newAllowedIPs = append(newAllowedIPs, siteConfig.RemoteSubnets...)
|
||||
for _, alias := range siteConfig.Aliases {
|
||||
newAllowedIPs = append(newAllowedIPs, alias.AliasAddress+"/32")
|
||||
}
|
||||
siteConfig.AllowedIps = newAllowedIPs
|
||||
|
||||
// Handle allowed IP claim changes
|
||||
oldAllowedIPs := make(map[string]bool)
|
||||
for _, ip := range oldPeer.AllowedIps {
|
||||
oldAllowedIPs[ip] = true
|
||||
}
|
||||
newAllowedIPsSet := make(map[string]bool)
|
||||
for _, ip := range newAllowedIPs {
|
||||
newAllowedIPsSet[ip] = true
|
||||
}
|
||||
|
||||
// Track peers that need WireGuard config updates due to promotions
|
||||
peersToUpdate := make(map[int]bool)
|
||||
|
||||
// Release claims for removed IPs and handle promotions
|
||||
for ip := range oldAllowedIPs {
|
||||
if !newAllowedIPsSet[ip] {
|
||||
newOwner, promoted := pm.releaseAllowedIP(siteConfig.SiteId, ip)
|
||||
if promoted && newOwner >= 0 {
|
||||
peersToUpdate[newOwner] = true
|
||||
logger.Info("Promoted peer %d to owner of IP %s", newOwner, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add claims for new IPs
|
||||
for ip := range newAllowedIPsSet {
|
||||
if !oldAllowedIPs[ip] {
|
||||
pm.claimAllowedIP(siteConfig.SiteId, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Build the list of IPs this peer owns for WireGuard config
|
||||
ownedIPs := pm.getOwnedAllowedIPs(siteConfig.SiteId)
|
||||
wgConfig := siteConfig
|
||||
wgConfig.AllowedIps = ownedIPs
|
||||
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update WireGuard config for any promoted peers
|
||||
for promotedPeerId := range peersToUpdate {
|
||||
if promotedPeer, exists := pm.peers[promotedPeerId]; exists {
|
||||
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||
promotedWgConfig := promotedPeer
|
||||
promotedWgConfig.AllowedIps = promotedOwnedIPs
|
||||
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
|
||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle remote subnet route changes
|
||||
// Calculate added and removed subnets
|
||||
oldSubnets := make(map[string]bool)
|
||||
for _, s := range oldPeer.RemoteSubnets {
|
||||
oldSubnets[s] = true
|
||||
}
|
||||
newSubnets := make(map[string]bool)
|
||||
for _, s := range siteConfig.RemoteSubnets {
|
||||
newSubnets[s] = true
|
||||
}
|
||||
|
||||
var addedSubnets []string
|
||||
var removedSubnets []string
|
||||
|
||||
for s := range newSubnets {
|
||||
if !oldSubnets[s] {
|
||||
addedSubnets = append(addedSubnets, s)
|
||||
}
|
||||
}
|
||||
for s := range oldSubnets {
|
||||
if !newSubnets[s] {
|
||||
removedSubnets = append(removedSubnets, s)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove routes for removed subnets (only if no other peer needs them)
|
||||
for _, subnet := range removedSubnets {
|
||||
subnetStillInUse := false
|
||||
for otherSiteId, otherPeer := range pm.peers {
|
||||
if otherSiteId == siteConfig.SiteId {
|
||||
continue // Skip the current peer (already updated)
|
||||
}
|
||||
for _, otherSubnet := range otherPeer.RemoteSubnets {
|
||||
if otherSubnet == subnet {
|
||||
subnetStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if subnetStillInUse {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !subnetStillInUse {
|
||||
if err := network.RemoveRoutes([]string{subnet}); err != nil {
|
||||
logger.Error("Failed to remove route for subnet %s: %v", subnet, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add routes for added subnets
|
||||
if len(addedSubnets) > 0 {
|
||||
if err := network.AddRoutes(addedSubnets, pm.interfaceName); err != nil {
|
||||
logger.Error("Failed to add routes: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update aliases
|
||||
// Remove old aliases
|
||||
for _, alias := range oldPeer.Aliases {
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address == nil {
|
||||
continue
|
||||
}
|
||||
pm.dnsProxy.RemoveDNSRecord(alias.Alias, address)
|
||||
}
|
||||
// Add new aliases
|
||||
for _, alias := range siteConfig.Aliases {
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address == nil {
|
||||
continue
|
||||
}
|
||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
||||
}
|
||||
|
||||
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
|
||||
|
||||
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||
pm.peerMonitor.UpdatePeerEndpoint(siteConfig.SiteId, monitorPeer) // +1 for monitor port
|
||||
|
||||
pm.peers[siteConfig.SiteId] = siteConfig
|
||||
return nil
|
||||
}
|
||||
|
||||
// claimAllowedIP registers a peer's claim to an allowed IP.
|
||||
// If no other peer owns it in WireGuard, this peer becomes the owner.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) claimAllowedIP(siteId int, cidr string) {
|
||||
// Add to claims
|
||||
if pm.allowedIPClaims[cidr] == nil {
|
||||
pm.allowedIPClaims[cidr] = make(map[int]bool)
|
||||
}
|
||||
pm.allowedIPClaims[cidr][siteId] = true
|
||||
|
||||
// If no owner yet, this peer becomes the owner
|
||||
if _, hasOwner := pm.allowedIPOwners[cidr]; !hasOwner {
|
||||
pm.allowedIPOwners[cidr] = siteId
|
||||
}
|
||||
}
|
||||
|
||||
// releaseAllowedIP removes a peer's claim to an allowed IP.
|
||||
// If this peer was the owner, it promotes another claimant to owner.
|
||||
// Returns the new owner's siteId (or -1 if no new owner) and whether promotion occurred.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) releaseAllowedIP(siteId int, cidr string) (newOwner int, promoted bool) {
|
||||
// Remove from claims
|
||||
if claims, exists := pm.allowedIPClaims[cidr]; exists {
|
||||
delete(claims, siteId)
|
||||
if len(claims) == 0 {
|
||||
delete(pm.allowedIPClaims, cidr)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this peer was the owner
|
||||
owner, isOwned := pm.allowedIPOwners[cidr]
|
||||
if !isOwned || owner != siteId {
|
||||
return -1, false // Not the owner, nothing to promote
|
||||
}
|
||||
|
||||
// This peer was the owner, need to find a new owner
|
||||
delete(pm.allowedIPOwners, cidr)
|
||||
|
||||
// Find another claimant to promote
|
||||
if claims, exists := pm.allowedIPClaims[cidr]; exists && len(claims) > 0 {
|
||||
for claimantId := range claims {
|
||||
pm.allowedIPOwners[cidr] = claimantId
|
||||
return claimantId, true
|
||||
}
|
||||
}
|
||||
|
||||
return -1, false
|
||||
}
|
||||
|
||||
// getOwnedAllowedIPs returns the list of allowed IPs that a peer currently owns in WireGuard.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) getOwnedAllowedIPs(siteId int) []string {
|
||||
var owned []string
|
||||
for cidr, owner := range pm.allowedIPOwners {
|
||||
if owner == siteId {
|
||||
owned = append(owned, cidr)
|
||||
}
|
||||
}
|
||||
return owned
|
||||
}
|
||||
|
||||
// addAllowedIp adds an IP (subnet) to the allowed IPs list of a peer
|
||||
// and updates WireGuard configuration if this peer owns the IP.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) addAllowedIp(siteId int, ip string) error {
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
// Check if IP already exists in AllowedIps
|
||||
for _, allowedIp := range peer.AllowedIps {
|
||||
if allowedIp == ip {
|
||||
return nil // Already exists
|
||||
}
|
||||
}
|
||||
|
||||
// Register our claim to this IP
|
||||
pm.claimAllowedIP(siteId, ip)
|
||||
|
||||
peer.AllowedIps = append(peer.AllowedIps, ip)
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
// Only update WireGuard if we own this IP
|
||||
if pm.allowedIPOwners[ip] == siteId {
|
||||
if err := AddAllowedIP(pm.device, peer.PublicKey, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeAllowedIp removes an IP (subnet) from the allowed IPs list of a peer
|
||||
// and updates WireGuard configuration. If this peer owned the IP, it promotes
|
||||
// another peer that also claims this IP. Must be called with lock held.
|
||||
func (pm *PeerManager) removeAllowedIp(siteId int, cidr string) error {
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
found := false
|
||||
|
||||
// Remove from AllowedIps
|
||||
newAllowedIps := make([]string, 0, len(peer.AllowedIps))
|
||||
for _, allowedIp := range peer.AllowedIps {
|
||||
if allowedIp == cidr {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newAllowedIps = append(newAllowedIps, allowedIp)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil // Not found
|
||||
}
|
||||
|
||||
peer.AllowedIps = newAllowedIps
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
// Release our claim and check if we need to promote another peer
|
||||
newOwner, promoted := pm.releaseAllowedIP(siteId, cidr)
|
||||
|
||||
// Build the list of IPs this peer currently owns for the replace operation
|
||||
ownedIPs := pm.getOwnedAllowedIPs(siteId)
|
||||
// Also include the server IP which is always owned
|
||||
serverIP := strings.Split(peer.ServerIP, "/")[0] + "/32"
|
||||
hasServerIP := false
|
||||
for _, ip := range ownedIPs {
|
||||
if ip == serverIP {
|
||||
hasServerIP = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasServerIP {
|
||||
ownedIPs = append([]string{serverIP}, ownedIPs...)
|
||||
}
|
||||
|
||||
// Update WireGuard for this peer using replace_allowed_ips
|
||||
if err := RemoveAllowedIP(pm.device, peer.PublicKey, ownedIPs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If another peer was promoted to owner, add the IP to their WireGuard config
|
||||
if promoted && newOwner >= 0 {
|
||||
if newOwnerPeer, exists := pm.peers[newOwner]; exists {
|
||||
if err := AddAllowedIP(pm.device, newOwnerPeer.PublicKey, cidr); err != nil {
|
||||
logger.Error("Failed to promote peer %d for IP %s: %v", newOwner, cidr, err)
|
||||
} else {
|
||||
logger.Info("Promoted peer %d to owner of IP %s", newOwner, cidr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRemoteSubnet adds an IP (subnet) to the allowed IPs list of a peer
|
||||
func (pm *PeerManager) AddRemoteSubnet(siteId int, cidr string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
// Check if IP already exists in RemoteSubnets
|
||||
for _, subnet := range peer.RemoteSubnets {
|
||||
if subnet == cidr {
|
||||
return nil // Already exists
|
||||
}
|
||||
}
|
||||
|
||||
peer.RemoteSubnets = append(peer.RemoteSubnets, cidr)
|
||||
pm.peers[siteId] = peer // Save before calling addAllowedIp which reads from pm.peers
|
||||
|
||||
// Add to allowed IPs
|
||||
if err := pm.addAllowedIp(siteId, cidr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add route
|
||||
if err := network.AddRoutes([]string{cidr}, pm.interfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRemoteSubnet removes an IP (subnet) from the allowed IPs list of a peer
|
||||
func (pm *PeerManager) RemoveRemoteSubnet(siteId int, ip string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
found := false
|
||||
|
||||
// Remove from RemoteSubnets
|
||||
newSubnets := make([]string, 0, len(peer.RemoteSubnets))
|
||||
for _, subnet := range peer.RemoteSubnets {
|
||||
if subnet == ip {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newSubnets = append(newSubnets, subnet)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil // Not found
|
||||
}
|
||||
|
||||
peer.RemoteSubnets = newSubnets
|
||||
pm.peers[siteId] = peer // Save before calling removeAllowedIp which reads from pm.peers
|
||||
|
||||
// Remove from allowed IPs (this also handles promotion of other peers)
|
||||
if err := pm.removeAllowedIp(siteId, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if any other peer still has this subnet before removing the route
|
||||
subnetStillInUse := false
|
||||
for otherSiteId, otherPeer := range pm.peers {
|
||||
if otherSiteId == siteId {
|
||||
continue // Skip the current peer (already updated above)
|
||||
}
|
||||
for _, subnet := range otherPeer.RemoteSubnets {
|
||||
if subnet == ip {
|
||||
subnetStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if subnetStillInUse {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Only remove route if no other peer needs it
|
||||
if !subnetStillInUse {
|
||||
if err := network.RemoveRoutes([]string{ip}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddAlias adds an alias to a peer
|
||||
func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
peer.Aliases = append(peer.Aliases, alias)
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address != nil {
|
||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
||||
}
|
||||
|
||||
// Add an allowed IP for the alias
|
||||
if err := pm.addAllowedIp(siteId, alias.AliasAddress+"/32"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAlias removes an alias from a peer
|
||||
func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
var aliasToRemove *Alias
|
||||
newAliases := make([]Alias, 0, len(peer.Aliases))
|
||||
for _, a := range peer.Aliases {
|
||||
if a.Alias == aliasName {
|
||||
aliasToRemove = &a
|
||||
continue
|
||||
}
|
||||
newAliases = append(newAliases, a)
|
||||
}
|
||||
|
||||
if aliasToRemove != nil {
|
||||
address := net.ParseIP(aliasToRemove.AliasAddress)
|
||||
if address != nil {
|
||||
pm.dnsProxy.RemoveDNSRecord(aliasName, address)
|
||||
}
|
||||
}
|
||||
|
||||
peer.Aliases = newAliases
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
// Check if any other alias is still using this IP address before removing from allowed IPs
|
||||
ipStillInUse := false
|
||||
aliasIP := aliasToRemove.AliasAddress + "/32"
|
||||
for _, a := range newAliases {
|
||||
if a.AliasAddress+"/32" == aliasIP {
|
||||
ipStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Only remove the allowed IP if no other alias is using it
|
||||
if !ipStillInUse {
|
||||
if err := pm.removeAllowedIp(siteId, aliasIP); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayPeer handles failover to the relay server when a peer is disconnected
|
||||
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string, relayPort uint16) {
|
||||
pm.mu.Lock()
|
||||
peer, exists := pm.peers[siteId]
|
||||
if exists {
|
||||
// Store the relay endpoint
|
||||
peer.RelayEndpoint = relayEndpoint
|
||||
pm.peers[siteId] = peer
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for IPv6 and format the endpoint correctly
|
||||
formattedEndpoint := relayEndpoint
|
||||
if strings.Contains(relayEndpoint, ":") {
|
||||
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
|
||||
}
|
||||
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820 // fall back to 21820 for backward compatibility
|
||||
}
|
||||
|
||||
// Update only the endpoint for this peer (update_only preserves other settings)
|
||||
wgConfig := fmt.Sprintf(`public_key=%s
|
||||
update_only=true
|
||||
endpoint=%s:%d`, util.FixKey(peer.PublicKey), formattedEndpoint, relayPort)
|
||||
|
||||
err := pm.device.IpcSet(wgConfig)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure WireGuard device: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the peer as relayed in the monitor
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.MarkPeerRelayed(siteId, true)
|
||||
}
|
||||
|
||||
logger.Info("Adjusted peer %d to point to relay!\n", siteId)
|
||||
}
|
||||
|
||||
// performRapidInitialTest performs a rapid holepunch test for a newly added peer.
|
||||
// If the test fails, it immediately requests relay to minimize connection delay.
|
||||
// This runs in a goroutine to avoid blocking AddPeer.
|
||||
func (pm *PeerManager) performRapidInitialTest(siteId int, endpoint string) {
|
||||
if pm.peerMonitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Perform rapid test - this takes ~1-2 seconds max
|
||||
holepunchViable := pm.peerMonitor.RapidTestPeer(siteId, endpoint)
|
||||
|
||||
if !holepunchViable {
|
||||
// Holepunch failed rapid test, request relay immediately
|
||||
logger.Info("Rapid test failed for site %d, requesting relay", siteId)
|
||||
if err := pm.peerMonitor.RequestRelay(siteId); err != nil {
|
||||
logger.Error("Failed to request relay for site %d: %v", siteId, err)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Rapid test passed for site %d, using direct connection", siteId)
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the peer monitor
|
||||
func (pm *PeerManager) Start() {
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.Start()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the peer monitor
|
||||
func (pm *PeerManager) Stop() {
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the peer monitor and cleans up resources
|
||||
func (pm *PeerManager) Close() {
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.Close()
|
||||
pm.peerMonitor = nil
|
||||
}
|
||||
}
|
||||
|
||||
// MarkPeerRelayed marks a peer as currently using relay
|
||||
func (pm *PeerManager) MarkPeerRelayed(siteID int, relayed bool) {
|
||||
pm.mu.Lock()
|
||||
if peer, exists := pm.peers[siteID]; exists {
|
||||
if relayed {
|
||||
// We're being relayed, store the current endpoint as the original
|
||||
// (RelayEndpoint is set by HandleFailover)
|
||||
} else {
|
||||
// Clear relay endpoint when switching back to direct
|
||||
peer.RelayEndpoint = ""
|
||||
pm.peers[siteID] = peer
|
||||
}
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.MarkPeerRelayed(siteID, relayed)
|
||||
}
|
||||
}
|
||||
|
||||
// UnRelayPeer switches a peer from relay back to direct connection
|
||||
func (pm *PeerManager) UnRelayPeer(siteId int, endpoint string) error {
|
||||
pm.mu.Lock()
|
||||
peer, exists := pm.peers[siteId]
|
||||
if exists {
|
||||
// Store the relay endpoint
|
||||
peer.Endpoint = endpoint
|
||||
pm.peers[siteId] = peer
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update WireGuard to use the direct endpoint
|
||||
wgConfig := fmt.Sprintf(`public_key=%s
|
||||
update_only=true
|
||||
endpoint=%s`, util.FixKey(peer.PublicKey), endpoint)
|
||||
|
||||
err := pm.device.IpcSet(wgConfig)
|
||||
if err != nil {
|
||||
logger.Error("Failed to switch peer %d to direct connection: %v", siteId, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark as not relayed in monitor
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.MarkPeerRelayed(siteId, false)
|
||||
}
|
||||
|
||||
logger.Info("Switched peer %d back to direct connection at %s", siteId, endpoint)
|
||||
return nil
|
||||
}
|
||||
1000
peers/monitor/monitor.go
Normal file
1000
peers/monitor/monitor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
package wgtester
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -26,17 +26,30 @@ const (
|
||||
|
||||
// Client handles checking connectivity to a server
|
||||
type Client struct {
|
||||
conn *net.UDPConn
|
||||
conn net.Conn
|
||||
serverAddr string
|
||||
monitorRunning bool
|
||||
monitorLock sync.Mutex
|
||||
connLock sync.Mutex // Protects connection operations
|
||||
shutdownCh chan struct{}
|
||||
updateCh chan struct{}
|
||||
packetInterval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
dialer Dialer
|
||||
|
||||
// Exponential backoff fields
|
||||
defaultMinInterval time.Duration // Default minimum interval (initial)
|
||||
defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
|
||||
minInterval time.Duration // Minimum interval (initial)
|
||||
maxInterval time.Duration // Maximum interval (cap for backoff)
|
||||
backoffMultiplier float64 // Multiplier for each stable check
|
||||
stableCountToBackoff int // Number of stable checks before backing off
|
||||
}
|
||||
|
||||
// Dialer is a function that creates a connection
|
||||
type Dialer func(network, addr string) (net.Conn, error)
|
||||
|
||||
// ConnectionStatus represents the current connection state
|
||||
type ConnectionStatus struct {
|
||||
Connected bool
|
||||
@@ -44,29 +57,75 @@ type ConnectionStatus struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new connection test client
|
||||
func NewClient(serverAddr string) (*Client, error) {
|
||||
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
||||
return &Client{
|
||||
serverAddr: serverAddr,
|
||||
shutdownCh: make(chan struct{}),
|
||||
packetInterval: 2 * time.Second,
|
||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||
maxAttempts: 3, // Default max attempts
|
||||
serverAddr: serverAddr,
|
||||
shutdownCh: make(chan struct{}),
|
||||
updateCh: make(chan struct{}, 1),
|
||||
packetInterval: 2 * time.Second,
|
||||
defaultMinInterval: 2 * time.Second,
|
||||
defaultMaxInterval: 30 * time.Second,
|
||||
minInterval: 2 * time.Second,
|
||||
maxInterval: 30 * time.Second,
|
||||
backoffMultiplier: 1.5,
|
||||
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
|
||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||
maxAttempts: 3, // Default max attempts
|
||||
dialer: dialer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetPacketInterval changes how frequently packets are sent in monitor mode
|
||||
func (c *Client) SetPacketInterval(interval time.Duration) {
|
||||
c.packetInterval = interval
|
||||
func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
|
||||
c.monitorLock.Lock()
|
||||
c.packetInterval = minInterval
|
||||
c.minInterval = minInterval
|
||||
c.maxInterval = maxInterval
|
||||
updateCh := c.updateCh
|
||||
monitorRunning := c.monitorRunning
|
||||
c.monitorLock.Unlock()
|
||||
|
||||
// Signal the goroutine to apply the new interval if running
|
||||
if monitorRunning && updateCh != nil {
|
||||
select {
|
||||
case updateCh <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeout changes the timeout for waiting for responses
|
||||
func (c *Client) SetTimeout(timeout time.Duration) {
|
||||
c.timeout = timeout
|
||||
func (c *Client) ResetPacketInterval() {
|
||||
c.monitorLock.Lock()
|
||||
c.packetInterval = c.defaultMinInterval
|
||||
c.minInterval = c.defaultMinInterval
|
||||
c.maxInterval = c.defaultMaxInterval
|
||||
updateCh := c.updateCh
|
||||
monitorRunning := c.monitorRunning
|
||||
c.monitorLock.Unlock()
|
||||
|
||||
// Signal the goroutine to apply the new interval if running
|
||||
if monitorRunning && updateCh != nil {
|
||||
select {
|
||||
case updateCh <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||
func (c *Client) SetMaxAttempts(attempts int) {
|
||||
c.maxAttempts = attempts
|
||||
// UpdateServerAddr updates the server address and resets the connection
|
||||
func (c *Client) UpdateServerAddr(serverAddr string) {
|
||||
c.connLock.Lock()
|
||||
defer c.connLock.Unlock()
|
||||
|
||||
// Close existing connection if any
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
c.serverAddr = serverAddr
|
||||
}
|
||||
|
||||
// Close cleans up client resources
|
||||
@@ -91,12 +150,14 @@ func (c *Client) ensureConnection() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
var err error
|
||||
if c.dialer != nil {
|
||||
c.conn, err = c.dialer("udp", c.serverAddr)
|
||||
} else {
|
||||
// Fallback to standard net.Dial
|
||||
c.conn, err = net.Dial("udp", c.serverAddr)
|
||||
}
|
||||
|
||||
c.conn, err = net.DialUDP("udp", nil, serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -104,9 +165,10 @@ func (c *Client) ensureConnection() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestConnection checks if the connection to the server is working
|
||||
// TestPeerConnection checks if the connection to the server is working
|
||||
// Returns true if connected, false otherwise
|
||||
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||
func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
|
||||
// logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
logger.Warn("Failed to ensure connection: %v", err)
|
||||
return false, 0
|
||||
@@ -117,6 +179,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
|
||||
packet[4] = packetTypeRequest
|
||||
|
||||
// Reusable response buffer
|
||||
responseBuffer := make([]byte, packetSize)
|
||||
|
||||
// Send multiple attempts as specified
|
||||
for attempt := 0; attempt < c.maxAttempts; attempt++ {
|
||||
select {
|
||||
@@ -136,20 +201,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
|
||||
_, err := c.conn.Write(packet)
|
||||
if err != nil {
|
||||
c.connLock.Unlock()
|
||||
logger.Info("Error sending packet: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Debug("Successfully sent monitor packet")
|
||||
|
||||
// Set read deadline
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||
|
||||
// Wait for response
|
||||
responseBuffer := make([]byte, packetSize)
|
||||
n, err := c.conn.Read(responseBuffer)
|
||||
c.connLock.Unlock()
|
||||
|
||||
@@ -190,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
return c.TestConnection(ctx)
|
||||
return c.TestPeerConnection(ctx)
|
||||
}
|
||||
|
||||
// MonitorCallback is the function type for connection status change callbacks
|
||||
@@ -217,28 +279,61 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
|
||||
go func() {
|
||||
var lastConnected bool
|
||||
firstRun := true
|
||||
stableCount := 0
|
||||
currentInterval := c.minInterval
|
||||
|
||||
ticker := time.NewTicker(c.packetInterval)
|
||||
defer ticker.Stop()
|
||||
timer := time.NewTimer(currentInterval)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.shutdownCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
case <-c.updateCh:
|
||||
// Interval settings changed, reset to minimum
|
||||
c.monitorLock.Lock()
|
||||
currentInterval = c.minInterval
|
||||
c.monitorLock.Unlock()
|
||||
|
||||
// Reset backoff state
|
||||
stableCount = 0
|
||||
|
||||
timer.Reset(currentInterval)
|
||||
logger.Debug("Packet interval updated, reset to %v", currentInterval)
|
||||
case <-timer.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
connected, rtt := c.TestConnection(ctx)
|
||||
connected, rtt := c.TestPeerConnection(ctx)
|
||||
cancel()
|
||||
|
||||
statusChanged := connected != lastConnected
|
||||
|
||||
// Callback if status changed or it's the first check
|
||||
if connected != lastConnected || firstRun {
|
||||
if statusChanged || firstRun {
|
||||
callback(ConnectionStatus{
|
||||
Connected: connected,
|
||||
RTT: rtt,
|
||||
})
|
||||
lastConnected = connected
|
||||
firstRun = false
|
||||
// Reset backoff on status change
|
||||
stableCount = 0
|
||||
currentInterval = c.minInterval
|
||||
} else {
|
||||
// Status is stable, increment counter
|
||||
stableCount++
|
||||
|
||||
// Apply exponential backoff after stable threshold
|
||||
if stableCount >= c.stableCountToBackoff {
|
||||
newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier)
|
||||
if newInterval > c.maxInterval {
|
||||
newInterval = c.maxInterval
|
||||
}
|
||||
currentInterval = newInterval
|
||||
}
|
||||
}
|
||||
|
||||
// Reset timer with current interval
|
||||
timer.Reset(currentInterval)
|
||||
}
|
||||
}
|
||||
}()
|
||||
160
peers/peer.go
Normal file
160
peers/peer.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package peers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
|
||||
var endpoint string
|
||||
if relay && siteConfig.RelayEndpoint != "" {
|
||||
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
|
||||
} else {
|
||||
endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||
}
|
||||
siteHost, err := util.ResolveDomain(endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||
}
|
||||
|
||||
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
|
||||
allowedIp := strings.Split(siteConfig.ServerIP, "/")
|
||||
if len(allowedIp) > 1 {
|
||||
allowedIp[1] = "32"
|
||||
} else {
|
||||
allowedIp = append(allowedIp, "32")
|
||||
}
|
||||
allowedIpStr := strings.Join(allowedIp, "/")
|
||||
|
||||
// Collect all allowed IPs in a slice
|
||||
var allowedIPs []string
|
||||
allowedIPs = append(allowedIPs, allowedIpStr)
|
||||
|
||||
// Use AllowedIps if available, otherwise fall back to RemoteSubnets for backwards compatibility
|
||||
subnetsToAdd := siteConfig.AllowedIps
|
||||
|
||||
// If we have anything to add, process them
|
||||
if len(subnetsToAdd) > 0 {
|
||||
// Add each subnet
|
||||
for _, subnet := range subnetsToAdd {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet != "" {
|
||||
allowedIPs = append(allowedIPs, subnet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct WireGuard config for this peer
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", util.FixKey(privateKey.String())))
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(siteConfig.PublicKey)))
|
||||
|
||||
// Add each allowed IP separately
|
||||
for _, allowedIP := range allowedIPs {
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
}
|
||||
|
||||
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive))
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Configuring peer with config: %s", config)
|
||||
|
||||
err = dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeer removes a peer from the WireGuard device
|
||||
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
|
||||
// Construct WireGuard config to remove the peer
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("remove=true\n")
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Removing peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddAllowedIP adds a single allowed IP to an existing peer without reconfiguring the entire peer
|
||||
func AddAllowedIP(dev *device.Device, publicKey string, allowedIP string) error {
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("update_only=true\n")
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Adding allowed IP to peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add allowed IP to WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAllowedIP removes a single allowed IP from an existing peer by replacing the allowed IPs list
|
||||
// This requires providing all the allowed IPs that should remain after removal
|
||||
func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs []string) error {
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("update_only=true\n")
|
||||
configBuilder.WriteString("replace_allowed_ips=true\n")
|
||||
|
||||
// Add each remaining allowed IP
|
||||
for _, allowedIP := range remainingAllowedIPs {
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
}
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Removing allowed IP from peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove allowed IP from WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it
|
||||
func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error {
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("update_only=true\n")
|
||||
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval))
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Updating persistent keepalive for peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatEndpoint(endpoint string) string {
|
||||
if strings.Contains(endpoint, ":") {
|
||||
return endpoint
|
||||
}
|
||||
return endpoint + ":51820"
|
||||
}
|
||||
64
peers/types.go
Normal file
64
peers/types.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package peers
|
||||
|
||||
// PeerAction represents a request to add, update, or remove a peer
|
||||
type PeerAction struct {
|
||||
Action string `json:"action"` // "add", "update", or "remove"
|
||||
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
|
||||
}
|
||||
|
||||
// UpdatePeerData represents the data needed to update a peer
|
||||
type SiteConfig struct {
|
||||
SiteId int `json:"siteId"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
RelayEndpoint string `json:"relayEndpoint,omitempty"`
|
||||
PublicKey string `json:"publicKey,omitempty"`
|
||||
ServerIP string `json:"serverIP,omitempty"`
|
||||
ServerPort uint16 `json:"serverPort,omitempty"`
|
||||
RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access
|
||||
AllowedIps []string `json:"allowedIps,omitempty"` // optional, array of allowed IPs for the peer
|
||||
Aliases []Alias `json:"aliases,omitempty"` // optional, array of alias configurations
|
||||
}
|
||||
|
||||
type Alias struct {
|
||||
Alias string `json:"alias"` // the alias name
|
||||
AliasAddress string `json:"aliasAddress"` // the alias IP address
|
||||
}
|
||||
|
||||
// RemovePeer represents the data needed to remove a peer
|
||||
type PeerRemove struct {
|
||||
SiteId int `json:"siteId"`
|
||||
}
|
||||
|
||||
type RelayPeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
RelayEndpoint string `json:"relayEndpoint"`
|
||||
RelayPort uint16 `json:"relayPort"`
|
||||
}
|
||||
|
||||
type UnRelayPeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
// PeerAdd represents the data needed to add remote subnets to a peer
|
||||
type PeerAdd struct {
|
||||
SiteId int `json:"siteId"`
|
||||
RemoteSubnets []string `json:"remoteSubnets"` // subnets to add
|
||||
Aliases []Alias `json:"aliases,omitempty"` // aliases to add
|
||||
}
|
||||
|
||||
// RemovePeerData represents the data needed to remove remote subnets from a peer
|
||||
type RemovePeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
RemoteSubnets []string `json:"remoteSubnets"` // subnets to remove
|
||||
Aliases []Alias `json:"aliases,omitempty"` // aliases to remove
|
||||
}
|
||||
|
||||
type UpdatePeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
OldRemoteSubnets []string `json:"oldRemoteSubnets"` // old list of remote subnets
|
||||
NewRemoteSubnets []string `json:"newRemoteSubnets"` // new list of remote subnets
|
||||
OldAliases []Alias `json:"oldAliases,omitempty"` // old list of aliases
|
||||
NewAliases []Alias `json:"newAliases,omitempty"` // new list of aliases
|
||||
}
|
||||
@@ -48,3 +48,7 @@ func setupWindowsEventLog() {
|
||||
func watchLogFile(end bool) error {
|
||||
return fmt.Errorf("watching log file is only available on Windows")
|
||||
}
|
||||
|
||||
func showServiceConfig() {
|
||||
fmt.Println("Service configuration is only available on Windows")
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -69,12 +70,6 @@ func loadServiceArgs() ([]string, error) {
|
||||
return nil, fmt.Errorf("failed to read service args: %v", err)
|
||||
}
|
||||
|
||||
// delete the file after reading
|
||||
err = os.Remove(argsPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete service args file: %v", err)
|
||||
}
|
||||
|
||||
var args []string
|
||||
err = json.Unmarshal(data, &args)
|
||||
if err != nil {
|
||||
@@ -95,7 +90,7 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
|
||||
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
||||
changes <- svc.Status{State: svc.StartPending}
|
||||
|
||||
s.elog.Info(1, "Service Execute called, starting main logic")
|
||||
s.elog.Info(1, fmt.Sprintf("Service Execute called with args: %v", args))
|
||||
|
||||
// Load saved service arguments
|
||||
savedArgs, err := loadServiceArgs()
|
||||
@@ -104,7 +99,42 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
|
||||
// Continue with empty args if loading fails
|
||||
savedArgs = []string{}
|
||||
}
|
||||
s.args = savedArgs
|
||||
s.elog.Info(1, fmt.Sprintf("Loaded saved service args: %v", savedArgs))
|
||||
|
||||
// Combine service start args with saved args, giving priority to service start args
|
||||
// Note: When the service is started via SCM, args[0] is the service name
|
||||
// When started via s.Start(args...), the args passed are exactly what we provide
|
||||
finalArgs := []string{}
|
||||
|
||||
// Check if we have args passed directly to Execute (from s.Start())
|
||||
if len(args) > 0 {
|
||||
// The first arg from SCM is the service name, but when we call s.Start(args...),
|
||||
// the args we pass become args[1:] in Execute. However, if started by SCM without
|
||||
// args, args[0] will be the service name.
|
||||
// We need to check if args[0] looks like the service name or a flag
|
||||
if len(args) == 1 && args[0] == serviceName {
|
||||
// Only service name, no actual args
|
||||
s.elog.Info(1, "Only service name in args, checking saved args")
|
||||
} else if len(args) > 1 && args[0] == serviceName {
|
||||
// Service name followed by actual args
|
||||
finalArgs = append(finalArgs, args[1:]...)
|
||||
s.elog.Info(1, fmt.Sprintf("Using service start parameters (after service name): %v", finalArgs))
|
||||
} else {
|
||||
// Args don't start with service name, use them all
|
||||
// This happens when args are passed via s.Start(args...)
|
||||
finalArgs = append(finalArgs, args...)
|
||||
s.elog.Info(1, fmt.Sprintf("Using service start parameters (direct): %v", finalArgs))
|
||||
}
|
||||
}
|
||||
|
||||
// If no service start parameters, use saved args
|
||||
if len(finalArgs) == 0 && len(savedArgs) > 0 {
|
||||
finalArgs = savedArgs
|
||||
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
|
||||
}
|
||||
|
||||
s.elog.Info(1, fmt.Sprintf("Final args to use: %v", finalArgs))
|
||||
s.args = finalArgs
|
||||
|
||||
// Start the main olm functionality
|
||||
olmDone := make(chan struct{})
|
||||
@@ -151,6 +181,9 @@ func (s *olmService) runOlm() {
|
||||
// Create a context that can be cancelled when the service stops
|
||||
s.ctx, s.stop = context.WithCancel(context.Background())
|
||||
|
||||
// Create a separate context for programmatic shutdown (e.g., via API exit)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Setup logging for service mode
|
||||
s.elog.Info(1, "Starting Olm main logic")
|
||||
|
||||
@@ -165,7 +198,8 @@ func (s *olmService) runOlm() {
|
||||
}()
|
||||
|
||||
// Call the main olm function with stored arguments
|
||||
runOlmMainWithArgs(s.ctx, s.args)
|
||||
// Use s.ctx as the signal context since the service manages shutdown
|
||||
runOlmMainWithArgs(ctx, cancel, s.ctx, s.args)
|
||||
}()
|
||||
|
||||
// Wait for either context cancellation or main logic completion
|
||||
@@ -309,12 +343,15 @@ func removeService() error {
|
||||
}
|
||||
|
||||
func startService(args []string) error {
|
||||
// Save the service arguments before starting
|
||||
if len(args) > 0 {
|
||||
err := saveServiceArgs(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save service args: %v", err)
|
||||
}
|
||||
fmt.Printf("Starting service with args: %v\n", args)
|
||||
|
||||
// Always save the service arguments so they can be loaded on service restart
|
||||
err := saveServiceArgs(args)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: failed to save service args: %v\n", err)
|
||||
// Continue anyway, args will still be passed directly
|
||||
} else {
|
||||
fmt.Printf("Saved service args to: %s\n", getServiceArgsPath())
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
@@ -329,7 +366,9 @@ func startService(args []string) error {
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
err = s.Start()
|
||||
// Pass arguments directly to the service start call
|
||||
// Note: These args will appear in Execute() after the service name
|
||||
err = s.Start(args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %v", err)
|
||||
}
|
||||
@@ -379,17 +418,12 @@ func debugService(args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// fmt.Printf("Starting service in debug mode...\n")
|
||||
|
||||
// Start the service
|
||||
err := startService([]string{}) // Pass empty args since we already saved them
|
||||
// Start the service with the provided arguments
|
||||
err := startService(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %v", err)
|
||||
}
|
||||
|
||||
// fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n")
|
||||
// fmt.Printf("================================================================================\n")
|
||||
|
||||
// Watch the log file
|
||||
return watchLogFile(true)
|
||||
}
|
||||
@@ -509,11 +543,89 @@ func getServiceStatus() (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// showServiceConfig displays current saved service configuration
|
||||
func showServiceConfig() {
|
||||
configPath := getServiceArgsPath()
|
||||
fmt.Printf("Service configuration file: %s\n", configPath)
|
||||
|
||||
args, err := loadServiceArgs()
|
||||
if err != nil {
|
||||
fmt.Printf("No saved configuration found or error loading: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
fmt.Println("No saved service arguments found")
|
||||
} else {
|
||||
fmt.Printf("Saved service arguments: %v\n", args)
|
||||
}
|
||||
}
|
||||
|
||||
func isWindowsService() bool {
|
||||
isWindowsService, err := svc.IsWindowsService()
|
||||
return err == nil && isWindowsService
|
||||
}
|
||||
|
||||
// rotateLogFile handles daily log rotation
|
||||
func rotateLogFile(logDir string, logFile string) error {
|
||||
// Get current log file info
|
||||
info, err := os.Stat(logFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // No current log file to rotate
|
||||
}
|
||||
return fmt.Errorf("failed to stat log file: %v", err)
|
||||
}
|
||||
|
||||
// Check if log file is from today
|
||||
now := time.Now()
|
||||
fileTime := info.ModTime()
|
||||
|
||||
// If the log file is from today, no rotation needed
|
||||
if now.Year() == fileTime.Year() && now.YearDay() == fileTime.YearDay() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create rotated filename with date
|
||||
rotatedName := fmt.Sprintf("olm-%s.log", fileTime.Format("2006-01-02"))
|
||||
rotatedPath := filepath.Join(logDir, rotatedName)
|
||||
|
||||
// Rename current log file to dated filename
|
||||
err = os.Rename(logFile, rotatedPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rotate log file: %v", err)
|
||||
}
|
||||
|
||||
// Clean up old log files (keep last 30 days)
|
||||
cleanupOldLogFiles(logDir, 30)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOldLogFiles removes log files older than specified days
|
||||
func cleanupOldLogFiles(logDir string, daysToKeep int) {
|
||||
cutoff := time.Now().AddDate(0, 0, -daysToKeep)
|
||||
|
||||
files, err := os.ReadDir(logDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if !file.IsDir() && strings.HasPrefix(file.Name(), "olm-") && strings.HasSuffix(file.Name(), ".log") {
|
||||
filePath := filepath.Join(logDir, file.Name())
|
||||
info, err := file.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.ModTime().Before(cutoff) {
|
||||
os.Remove(filePath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupWindowsEventLog() {
|
||||
// Create log directory if it doesn't exist
|
||||
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
||||
@@ -524,6 +636,14 @@ func setupWindowsEventLog() {
|
||||
}
|
||||
|
||||
logFile := filepath.Join(logDir, "olm.log")
|
||||
|
||||
// Rotate log file if needed
|
||||
err = rotateLogFile(logDir, logFile)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to rotate log file: %v\n", err)
|
||||
// Continue anyway to create new log file
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to open log file: %v\n", err)
|
||||
|
||||
35
unix.go
35
unix.go
@@ -1,35 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(int(fd), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
return tun.CreateTUNFromFile(file, mtuInt)
|
||||
}
|
||||
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||
return ipc.UAPIOpen(interfaceName)
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||
}
|
||||
920
websocket/client.go
Normal file
920
websocket/client.go
Normal file
@@ -0,0 +1,920 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"software.sslmate.com/src/go-pkcs12"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// AuthError represents an authentication/authorization error (401/403)
|
||||
type AuthError struct {
|
||||
StatusCode int
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *AuthError) Error() string {
|
||||
return fmt.Sprintf("authentication error (status %d): %s", e.StatusCode, e.Message)
|
||||
}
|
||||
|
||||
// IsAuthError checks if an error is an authentication error
|
||||
func IsAuthError(err error) bool {
|
||||
_, ok := err.(*AuthError)
|
||||
return ok
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
Data struct {
|
||||
Token string `json:"token"`
|
||||
ExitNodes []ExitNode `json:"exitNodes"`
|
||||
ServerVersion string `json:"serverVersion"`
|
||||
} `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ExitNode struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
RelayPort uint16 `json:"relayPort"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
SiteIds []int `json:"siteIds"`
|
||||
}
|
||||
|
||||
type WSMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
ConfigVersion int `json:"configVersion,omitempty"`
|
||||
}
|
||||
|
||||
// this is not json anymore
|
||||
type Config struct {
|
||||
ID string
|
||||
Secret string
|
||||
Endpoint string
|
||||
TlsClientCert string // legacy PKCS12 file path
|
||||
UserToken string // optional user token for websocket authentication
|
||||
OrgID string // optional organization ID for websocket authentication
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
config *Config
|
||||
conn *websocket.Conn
|
||||
baseURL string
|
||||
handlers map[string]MessageHandler
|
||||
done chan struct{}
|
||||
handlersMux sync.RWMutex
|
||||
reconnectInterval time.Duration
|
||||
isConnected bool
|
||||
isDisconnected bool // Flag to track if client is intentionally disconnected
|
||||
reconnectMux sync.RWMutex
|
||||
pingInterval time.Duration
|
||||
pingTimeout time.Duration
|
||||
onConnect func() error
|
||||
onTokenUpdate func(token string, exitNodes []ExitNode)
|
||||
onAuthError func(statusCode int, message string) // Callback for auth errors
|
||||
writeMux sync.Mutex
|
||||
clientType string // Type of client (e.g., "newt", "olm")
|
||||
tlsConfig TLSConfig
|
||||
configNeedsSave bool // Flag to track if config needs to be saved
|
||||
configVersion int // Latest config version received from server
|
||||
configVersionMux sync.RWMutex
|
||||
token string // Cached authentication token
|
||||
exitNodes []ExitNode // Cached exit nodes from token response
|
||||
tokenMux sync.RWMutex // Protects token and exitNodes
|
||||
forceNewToken bool // Flag to force fetching a new token on next connection
|
||||
processingMessage bool // Flag to track if a message is currently being processed
|
||||
processingMux sync.RWMutex // Protects processingMessage
|
||||
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
|
||||
getPingData func() map[string]any // Callback to get additional ping data
|
||||
pingStarted bool // Flag to track if ping monitor has been started
|
||||
pingStartedMux sync.Mutex // Protects pingStarted
|
||||
pingDone chan struct{} // Channel to stop the ping monitor independently
|
||||
}
|
||||
|
||||
type ClientOption func(*Client)
|
||||
|
||||
type MessageHandler func(message WSMessage)
|
||||
|
||||
// TLSConfig holds TLS configuration options
|
||||
type TLSConfig struct {
|
||||
// New separate certificate support
|
||||
ClientCertFile string
|
||||
ClientKeyFile string
|
||||
CAFiles []string
|
||||
|
||||
// Existing PKCS12 support (deprecated)
|
||||
PKCS12File string
|
||||
}
|
||||
|
||||
// WithBaseURL sets the base URL for the client
|
||||
func WithBaseURL(url string) ClientOption {
|
||||
return func(c *Client) {
|
||||
c.baseURL = url
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig sets the TLS configuration for the client
|
||||
func WithTLSConfig(config TLSConfig) ClientOption {
|
||||
return func(c *Client) {
|
||||
c.tlsConfig = config
|
||||
// For backward compatibility, also set the legacy field
|
||||
if config.PKCS12File != "" {
|
||||
c.config.TlsClientCert = config.PKCS12File
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithPingDataProvider sets a callback to provide additional data for ping messages
|
||||
func WithPingDataProvider(fn func() map[string]any) ClientOption {
|
||||
return func(c *Client) {
|
||||
c.getPingData = fn
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) OnConnect(callback func() error) {
|
||||
c.onConnect = callback
|
||||
}
|
||||
|
||||
func (c *Client) OnTokenUpdate(callback func(token string, exitNodes []ExitNode)) {
|
||||
c.onTokenUpdate = callback
|
||||
}
|
||||
|
||||
func (c *Client) OnAuthError(callback func(statusCode int, message string)) {
|
||||
c.onAuthError = callback
|
||||
}
|
||||
|
||||
// NewClient creates a new websocket client
|
||||
func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||
config := &Config{
|
||||
ID: ID,
|
||||
Secret: secret,
|
||||
Endpoint: endpoint,
|
||||
UserToken: userToken,
|
||||
OrgID: orgId,
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
config: config,
|
||||
baseURL: endpoint, // default value
|
||||
handlers: make(map[string]MessageHandler),
|
||||
done: make(chan struct{}),
|
||||
reconnectInterval: 3 * time.Second,
|
||||
isConnected: false,
|
||||
pingInterval: pingInterval,
|
||||
pingTimeout: pingTimeout,
|
||||
clientType: "olm",
|
||||
pingDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Apply options before loading config
|
||||
for _, opt := range opts {
|
||||
if opt == nil {
|
||||
continue
|
||||
}
|
||||
opt(client)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetConfig() *Config {
|
||||
return c.config
|
||||
}
|
||||
|
||||
// Connect establishes the WebSocket connection
|
||||
func (c *Client) Connect() error {
|
||||
if c.isDisconnected {
|
||||
c.isDisconnected = false
|
||||
}
|
||||
go c.connectWithRetry()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the WebSocket connection gracefully
|
||||
func (c *Client) Close() error {
|
||||
// Signal shutdown to all goroutines first
|
||||
select {
|
||||
case <-c.done:
|
||||
// Already closed
|
||||
return nil
|
||||
default:
|
||||
close(c.done)
|
||||
}
|
||||
|
||||
// Set connection status to false
|
||||
c.setConnected(false)
|
||||
|
||||
// Close the WebSocket connection gracefully
|
||||
if c.conn != nil {
|
||||
// Send close message
|
||||
c.writeMux.Lock()
|
||||
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
c.writeMux.Unlock()
|
||||
|
||||
// Close the connection
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
|
||||
func (c *Client) Disconnect() error {
|
||||
c.isDisconnected = true
|
||||
c.setConnected(false)
|
||||
|
||||
// Stop the ping monitor
|
||||
c.stopPingMonitor()
|
||||
|
||||
// Wait for any message currently being processed to complete
|
||||
c.processingWg.Wait()
|
||||
|
||||
if c.conn != nil {
|
||||
c.writeMux.Lock()
|
||||
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
c.writeMux.Unlock()
|
||||
err := c.conn.Close()
|
||||
c.conn = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessage sends a message through the WebSocket connection
|
||||
func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
if c.isDisconnected || c.conn == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
msg := WSMessage{
|
||||
Type: messageType,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
|
||||
|
||||
c.writeMux.Lock()
|
||||
defer c.writeMux.Unlock()
|
||||
return c.conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
|
||||
stopChan := make(chan struct{})
|
||||
updateChan := make(chan interface{})
|
||||
var dataMux sync.Mutex
|
||||
currentData := data
|
||||
|
||||
go func() {
|
||||
count := 0
|
||||
|
||||
send := func() {
|
||||
if c.isDisconnected || c.conn == nil {
|
||||
return
|
||||
}
|
||||
err := c.SendMessage(messageType, currentData)
|
||||
if err != nil {
|
||||
logger.Error("websocket: Failed to send message: %v", err)
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
send() // Send immediately
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if maxAttempts != -1 && count >= maxAttempts {
|
||||
logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||
return
|
||||
}
|
||||
dataMux.Lock()
|
||||
send()
|
||||
dataMux.Unlock()
|
||||
case newData := <-updateChan:
|
||||
dataMux.Lock()
|
||||
// Merge newData into currentData if both are maps
|
||||
if currentMap, ok := currentData.(map[string]interface{}); ok {
|
||||
if newMap, ok := newData.(map[string]interface{}); ok {
|
||||
// Update or add keys from newData
|
||||
for key, value := range newMap {
|
||||
currentMap[key] = value
|
||||
}
|
||||
currentData = currentMap
|
||||
} else {
|
||||
// If newData is not a map, replace entirely
|
||||
currentData = newData
|
||||
}
|
||||
} else {
|
||||
// If currentData is not a map, replace entirely
|
||||
currentData = newData
|
||||
}
|
||||
dataMux.Unlock()
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
// Suspend sending if disconnected
|
||||
for c.isDisconnected {
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
close(stopChan)
|
||||
}, func(newData interface{}) {
|
||||
select {
|
||||
case updateChan <- newData:
|
||||
case <-stopChan:
|
||||
// Channel is closed, ignore update
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a specific message type
|
||||
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
||||
c.handlersMux.Lock()
|
||||
defer c.handlersMux.Unlock()
|
||||
c.handlers[messageType] = handler
|
||||
}
|
||||
|
||||
func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
// Parse the base URL to ensure we have the correct hostname
|
||||
baseURL, err := url.Parse(c.baseURL)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to parse base URL: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we have the base URL without trailing slashes
|
||||
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
||||
|
||||
var tlsConfig *tls.Config = nil
|
||||
|
||||
// Use new TLS configuration method
|
||||
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||
tlsConfig, err = c.setupTLS()
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for environment variable to skip TLS verification
|
||||
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = &tls.Config{}
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
tokenData := map[string]interface{}{
|
||||
"olmId": c.config.ID,
|
||||
"secret": c.config.Secret,
|
||||
"orgId": c.config.OrgID,
|
||||
}
|
||||
jsonData, err := json.Marshal(tokenData)
|
||||
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to marshal token request data: %w", err)
|
||||
}
|
||||
|
||||
// Create a new request
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
|
||||
bytes.NewBuffer(jsonData),
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
||||
|
||||
// print out the request for debugging
|
||||
logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
|
||||
|
||||
// Make the request
|
||||
client := &http.Client{}
|
||||
if tlsConfig != nil {
|
||||
client.Transport = &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to request new token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
|
||||
// Return AuthError for 401/403 status codes
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return "", nil, &AuthError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Message: string(body),
|
||||
}
|
||||
}
|
||||
|
||||
// For other errors (5xx, network issues, etc.), return regular error
|
||||
return "", nil, fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
logger.Error("websocket: Failed to decode token response.")
|
||||
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if !tokenResp.Success {
|
||||
return "", nil, fmt.Errorf("failed to get token: %s", tokenResp.Message)
|
||||
}
|
||||
|
||||
if tokenResp.Data.Token == "" {
|
||||
return "", nil, fmt.Errorf("received empty token from server")
|
||||
}
|
||||
|
||||
logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
|
||||
|
||||
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
|
||||
}
|
||||
|
||||
func (c *Client) connectWithRetry() {
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
err := c.establishConnection()
|
||||
if err != nil {
|
||||
// Check if this is an auth error (401/403)
|
||||
var authErr *AuthError
|
||||
if errors.As(err, &authErr) {
|
||||
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
|
||||
// Trigger auth error callback if set (this should terminate the tunnel)
|
||||
if c.onAuthError != nil {
|
||||
c.onAuthError(authErr.StatusCode, authErr.Message)
|
||||
}
|
||||
// Continue retrying after auth error
|
||||
time.Sleep(c.reconnectInterval)
|
||||
continue
|
||||
}
|
||||
// For other errors (5xx, network issues), continue retrying
|
||||
logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||
time.Sleep(c.reconnectInterval)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) establishConnection() error {
|
||||
// Get token for authentication - reuse cached token unless forced to get new one
|
||||
c.tokenMux.Lock()
|
||||
needNewToken := c.token == "" || c.forceNewToken
|
||||
if needNewToken {
|
||||
token, exitNodes, err := c.getToken()
|
||||
if err != nil {
|
||||
c.tokenMux.Unlock()
|
||||
return fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
c.token = token
|
||||
c.exitNodes = exitNodes
|
||||
c.forceNewToken = false
|
||||
|
||||
if c.onTokenUpdate != nil {
|
||||
c.onTokenUpdate(token, exitNodes)
|
||||
}
|
||||
}
|
||||
token := c.token
|
||||
c.tokenMux.Unlock()
|
||||
|
||||
// Parse the base URL to determine protocol and hostname
|
||||
baseURL, err := url.Parse(c.baseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse base URL: %w", err)
|
||||
}
|
||||
|
||||
// Determine WebSocket protocol based on HTTP protocol
|
||||
wsProtocol := "wss"
|
||||
if baseURL.Scheme == "http" {
|
||||
wsProtocol = "ws"
|
||||
}
|
||||
|
||||
// Create WebSocket URL
|
||||
wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
|
||||
u, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse WebSocket URL: %w", err)
|
||||
}
|
||||
|
||||
// Add token to query parameters
|
||||
q := u.Query()
|
||||
q.Set("token", token)
|
||||
q.Set("clientType", c.clientType)
|
||||
if c.config.UserToken != "" {
|
||||
q.Set("userToken", c.config.UserToken)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
// Connect to WebSocket
|
||||
dialer := websocket.DefaultDialer
|
||||
|
||||
// Use new TLS configuration method
|
||||
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||
logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
|
||||
tlsConfig, err := c.setupTLS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||
}
|
||||
dialer.TLSClientConfig = tlsConfig
|
||||
}
|
||||
|
||||
// Check for environment variable to skip TLS verification for WebSocket connection
|
||||
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||
if dialer.TLSClientConfig == nil {
|
||||
dialer.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
dialer.TLSClientConfig.InsecureSkipVerify = true
|
||||
logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
conn, resp, err := dialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
// Check if this is an unauthorized error (401)
|
||||
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
|
||||
logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
|
||||
// Force getting a new token on next reconnect attempt
|
||||
c.tokenMux.Lock()
|
||||
c.forceNewToken = true
|
||||
c.tokenMux.Unlock()
|
||||
return &AuthError{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
Message: "WebSocket connection unauthorized",
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
c.setConnected(true)
|
||||
|
||||
// Note: ping monitor is NOT started here - it will be started when
|
||||
// StartPingMonitor() is called after registration completes
|
||||
|
||||
// Start the read pump with disconnect detection
|
||||
go c.readPumpWithDisconnectDetection()
|
||||
|
||||
if c.onConnect != nil {
|
||||
if err := c.onConnect(); err != nil {
|
||||
logger.Error("websocket: OnConnect callback failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupTLS configures TLS based on the TLS configuration
|
||||
func (c *Client) setupTLS() (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
// Handle new separate certificate configuration
|
||||
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
|
||||
logger.Info("websocket: Loading separate certificate files for mTLS")
|
||||
logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
|
||||
logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
|
||||
|
||||
// Load client certificate and key
|
||||
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load client certificate pair: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
|
||||
// Load CA certificates for remote validation if specified
|
||||
if len(c.tlsConfig.CAFiles) > 0 {
|
||||
logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
||||
caCertPool := x509.NewCertPool()
|
||||
for _, caFile := range c.tlsConfig.CAFiles {
|
||||
caCert, err := os.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err)
|
||||
}
|
||||
|
||||
// Try to parse as PEM first, then DER
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
// If PEM parsing failed, try DER
|
||||
cert, err := x509.ParseCertificate(caCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err)
|
||||
}
|
||||
caCertPool.AddCert(cert)
|
||||
}
|
||||
}
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// Fallback to existing PKCS12 implementation for backward compatibility
|
||||
if c.tlsConfig.PKCS12File != "" {
|
||||
logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
|
||||
return c.setupPKCS12TLS()
|
||||
}
|
||||
|
||||
// Legacy fallback using config.TlsClientCert
|
||||
if c.config.TlsClientCert != "" {
|
||||
logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
||||
return loadClientCertificate(c.config.TlsClientCert)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// setupPKCS12TLS loads TLS configuration from PKCS12 file
|
||||
func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
|
||||
return loadClientCertificate(c.tlsConfig.PKCS12File)
|
||||
}
|
||||
|
||||
// sendPing sends a single ping message
|
||||
func (c *Client) sendPing() {
|
||||
if c.isDisconnected || c.conn == nil {
|
||||
return
|
||||
}
|
||||
// Skip ping if a message is currently being processed
|
||||
c.processingMux.RLock()
|
||||
isProcessing := c.processingMessage
|
||||
c.processingMux.RUnlock()
|
||||
if isProcessing {
|
||||
logger.Debug("websocket: Skipping ping, message is being processed")
|
||||
return
|
||||
}
|
||||
// Send application-level ping with config version
|
||||
c.configVersionMux.RLock()
|
||||
configVersion := c.configVersion
|
||||
c.configVersionMux.RUnlock()
|
||||
|
||||
pingData := map[string]any{
|
||||
"timestamp": time.Now().Unix(),
|
||||
"userToken": c.config.UserToken,
|
||||
}
|
||||
if c.getPingData != nil {
|
||||
for k, v := range c.getPingData() {
|
||||
pingData[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
pingMsg := WSMessage{
|
||||
Type: "olm/ping",
|
||||
Data: pingData,
|
||||
ConfigVersion: configVersion,
|
||||
}
|
||||
|
||||
logger.Debug("websocket: Sending ping: %+v", pingMsg)
|
||||
|
||||
c.writeMux.Lock()
|
||||
err := c.conn.WriteJSON(pingMsg)
|
||||
c.writeMux.Unlock()
|
||||
if err != nil {
|
||||
// Check if we're shutting down before logging error and reconnecting
|
||||
select {
|
||||
case <-c.done:
|
||||
// Expected during shutdown
|
||||
return
|
||||
default:
|
||||
logger.Error("websocket: Ping failed: %v", err)
|
||||
c.reconnect()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pingMonitor sends pings at a short interval and triggers reconnect on failure
|
||||
func (c *Client) pingMonitor() {
|
||||
ticker := time.NewTicker(c.pingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case <-c.pingDone:
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.sendPing()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartPingMonitor starts the ping monitor goroutine.
|
||||
// This should be called after the client is registered and connected.
|
||||
// It is safe to call multiple times - only the first call will start the monitor.
|
||||
func (c *Client) StartPingMonitor() {
|
||||
c.pingStartedMux.Lock()
|
||||
defer c.pingStartedMux.Unlock()
|
||||
|
||||
if c.pingStarted {
|
||||
return
|
||||
}
|
||||
c.pingStarted = true
|
||||
|
||||
// Create a new pingDone channel for this ping monitor instance
|
||||
c.pingDone = make(chan struct{})
|
||||
|
||||
// Send an initial ping immediately
|
||||
go func() {
|
||||
c.sendPing()
|
||||
c.pingMonitor()
|
||||
}()
|
||||
}
|
||||
|
||||
// stopPingMonitor stops the ping monitor goroutine if it's running.
|
||||
func (c *Client) stopPingMonitor() {
|
||||
c.pingStartedMux.Lock()
|
||||
defer c.pingStartedMux.Unlock()
|
||||
|
||||
if !c.pingStarted {
|
||||
return
|
||||
}
|
||||
|
||||
// Close the pingDone channel to stop the monitor
|
||||
close(c.pingDone)
|
||||
c.pingStarted = false
|
||||
}
|
||||
|
||||
// GetConfigVersion returns the current config version
|
||||
func (c *Client) GetConfigVersion() int {
|
||||
c.configVersionMux.RLock()
|
||||
defer c.configVersionMux.RUnlock()
|
||||
return c.configVersion
|
||||
}
|
||||
|
||||
// setConfigVersion updates the config version if the new version is higher
|
||||
func (c *Client) setConfigVersion(version int) {
|
||||
c.configVersionMux.Lock()
|
||||
defer c.configVersionMux.Unlock()
|
||||
logger.Debug("websocket: setting config version to %d", version)
|
||||
c.configVersion = version
|
||||
}
|
||||
|
||||
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
|
||||
func (c *Client) readPumpWithDisconnectDetection() {
|
||||
defer func() {
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
// Only attempt reconnect if we're not shutting down
|
||||
select {
|
||||
case <-c.done:
|
||||
// Shutting down, don't reconnect
|
||||
return
|
||||
default:
|
||||
c.reconnect()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
var msg WSMessage
|
||||
err := c.conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
// Check if we're shutting down or explicitly disconnected before logging error
|
||||
select {
|
||||
case <-c.done:
|
||||
// Expected during shutdown, don't log as error
|
||||
logger.Debug("websocket: connection closed during shutdown")
|
||||
return
|
||||
default:
|
||||
// Check if explicitly disconnected
|
||||
if c.isDisconnected {
|
||||
logger.Debug("websocket: connection closed: client was explicitly disconnected")
|
||||
return
|
||||
}
|
||||
|
||||
// Unexpected error during normal operation
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
||||
logger.Error("websocket: read error: %v", err)
|
||||
} else {
|
||||
logger.Debug("websocket: connection closed: %v", err)
|
||||
}
|
||||
return // triggers reconnect via defer
|
||||
}
|
||||
}
|
||||
|
||||
// Update config version from incoming message
|
||||
c.setConfigVersion(msg.ConfigVersion)
|
||||
|
||||
c.handlersMux.RLock()
|
||||
if handler, ok := c.handlers[msg.Type]; ok {
|
||||
// Mark that we're processing a message
|
||||
c.processingMux.Lock()
|
||||
c.processingMessage = true
|
||||
c.processingMux.Unlock()
|
||||
c.processingWg.Add(1)
|
||||
|
||||
handler(msg)
|
||||
|
||||
// Mark that we're done processing
|
||||
c.processingWg.Done()
|
||||
c.processingMux.Lock()
|
||||
c.processingMessage = false
|
||||
c.processingMux.Unlock()
|
||||
}
|
||||
c.handlersMux.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) reconnect() {
|
||||
c.setConnected(false)
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
// Don't reconnect if explicitly disconnected
|
||||
if c.isDisconnected {
|
||||
logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
|
||||
return
|
||||
}
|
||||
|
||||
// Only reconnect if we're not shutting down
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
go c.connectWithRetry()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) setConnected(status bool) {
|
||||
c.reconnectMux.Lock()
|
||||
defer c.reconnectMux.Unlock()
|
||||
c.isConnected = status
|
||||
}
|
||||
|
||||
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
|
||||
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||
logger.Info("websocket: Loading tls-client-cert %s", p12Path)
|
||||
// Read the PKCS12 file
|
||||
p12Data, err := os.ReadFile(p12Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
|
||||
}
|
||||
|
||||
// Parse PKCS12 with empty password for non-encrypted files
|
||||
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
cert := tls.Certificate{
|
||||
Certificate: [][]byte{certificate.Raw},
|
||||
PrivateKey: privateKey,
|
||||
}
|
||||
|
||||
// Optional: Add CA certificates if present
|
||||
rootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
|
||||
}
|
||||
if len(caCerts) > 0 {
|
||||
for _, caCert := range caCerts {
|
||||
rootCAs.AddCert(caCert)
|
||||
}
|
||||
}
|
||||
|
||||
// Create TLS configuration
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
RootCAs: rootCAs,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user