mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 22:16:42 +00:00
Compare commits
191 Commits
1.0.0-beta
...
bind
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d891cfa97 | ||
|
|
78e3bb374a | ||
|
|
a61c7ca1ee | ||
|
|
7696ba2e36 | ||
|
|
235877c379 | ||
|
|
befab0f8d1 | ||
|
|
914d080a57 | ||
|
|
a274b4b38f | ||
|
|
ce3c585514 | ||
|
|
963d8abad5 | ||
|
|
38eb56381f | ||
|
|
43b3822090 | ||
|
|
b0fb370c4d | ||
|
|
99328ee76f | ||
|
|
36fc3ea253 | ||
|
|
a7979259f3 | ||
|
|
ea6fa72bc0 | ||
|
|
f9adde6b1d | ||
|
|
ba25586646 | ||
|
|
952ab63e8d | ||
|
|
5e84f802ed | ||
|
|
f40b0ff820 | ||
|
|
95a4840374 | ||
|
|
27424170e4 | ||
|
|
a8ace6f64a | ||
|
|
3fa1073f49 | ||
|
|
76d86c10ff | ||
|
|
2d34c6c8b2 | ||
|
|
a7f3477bdd | ||
|
|
af0a72d296 | ||
|
|
d1e836e760 | ||
|
|
8dd45c4ca2 | ||
|
|
9db009058b | ||
|
|
29c01deb05 | ||
|
|
7224d9824d | ||
|
|
8afc28fdff | ||
|
|
4ba2fb7b53 | ||
|
|
2e6076923d | ||
|
|
4c001dc751 | ||
|
|
2b8e240752 | ||
|
|
bee490713d | ||
|
|
1cb7fd94ab | ||
|
|
dc9a547950 | ||
|
|
2be0933246 | ||
|
|
c0b1cd6bde | ||
|
|
dd00289f8e | ||
|
|
b23a02ee97 | ||
|
|
80f726cfea | ||
|
|
aa8828186f | ||
|
|
0a990d196d | ||
|
|
00e8050949 | ||
|
|
18ee4c93fb | ||
|
|
8fa2da00b6 | ||
|
|
b851cd73c9 | ||
|
|
4c19d7ef6d | ||
|
|
cbecb9a0ce | ||
|
|
4fc8db08ba | ||
|
|
7ca46e0a75 | ||
|
|
a4ea5143af | ||
|
|
e9257b6423 | ||
|
|
3c9d3a1d2c | ||
|
|
b426f14190 | ||
|
|
d48acfba39 | ||
|
|
35b48cd8e5 | ||
|
|
15bca53309 | ||
|
|
898b599db5 | ||
|
|
c07bba18bb | ||
|
|
4c24d3b808 | ||
|
|
ad4ab3d04f | ||
|
|
e21153fae1 | ||
|
|
41c3360e23 | ||
|
|
1960d32443 | ||
|
|
74b83b3303 | ||
|
|
c2c3470868 | ||
|
|
2bda3dc3cc | ||
|
|
52573c8664 | ||
|
|
0d3c34e23f | ||
|
|
891df5c74b | ||
|
|
6f3f162d2b | ||
|
|
f6fa5fd02c | ||
|
|
8f4e0ba29e | ||
|
|
32b7dc7c43 | ||
|
|
78d2ebe1de | ||
|
|
014f8eb4e5 | ||
|
|
cd42803291 | ||
|
|
5c5b303994 | ||
|
|
968873da22 | ||
|
|
e2772f918b | ||
|
|
cdf6a31b67 | ||
|
|
2ce72065a7 | ||
|
|
b3e7aafb58 | ||
|
|
dd610ad850 | ||
|
|
b4b0a832e7 | ||
|
|
79963c1f66 | ||
|
|
5f95282161 | ||
|
|
1cca54f9d5 | ||
|
|
219df22919 | ||
|
|
337d9934fd | ||
|
|
bba4d72a78 | ||
|
|
fb5c793126 | ||
|
|
1821dbb672 | ||
|
|
4fda6fe031 | ||
|
|
f286f0faf6 | ||
|
|
cba3d607bf | ||
|
|
5ca12834a1 | ||
|
|
9d41154daa | ||
|
|
63933b57fc | ||
|
|
c25d77597d | ||
|
|
ad080046a1 | ||
|
|
c1f7cf93a5 | ||
|
|
f3f112fc42 | ||
|
|
612a9ddb15 | ||
|
|
ad1fa2e59a | ||
|
|
29235f6100 | ||
|
|
848ac6b0c4 | ||
|
|
6ab66e6c36 | ||
|
|
d7f29d4709 | ||
|
|
6fb2b68e21 | ||
|
|
8d72e77d57 | ||
|
|
25a9b83496 | ||
|
|
4d33016389 | ||
|
|
0f717aec01 | ||
|
|
4c58cd6eff | ||
|
|
3ad36f95e1 | ||
|
|
b58e7c9fad | ||
|
|
85a8a737e8 | ||
|
|
8e83a83294 | ||
|
|
c1ef56001f | ||
|
|
c04e727bd3 | ||
|
|
becc214078 | ||
|
|
13e7f55b30 | ||
|
|
0be3ee7eee | ||
|
|
0b1724a3f3 | ||
|
|
e606264deb | ||
|
|
5497eb8a4e | ||
|
|
2159371371 | ||
|
|
ad8a94fdc8 | ||
|
|
61b7feef80 | ||
|
|
4cb31df3c8 | ||
|
|
3b0eef6d60 | ||
|
|
5d305f1d03 | ||
|
|
31e5d4e3bd | ||
|
|
8fb9468d08 | ||
|
|
5bbd5016aa | ||
|
|
8c40b8c578 | ||
|
|
e35f7c2d36 | ||
|
|
aeb8f203a4 | ||
|
|
73bd036e58 | ||
|
|
eb6b310304 | ||
|
|
3d70ff190f | ||
|
|
7e2d7b93a1 | ||
|
|
f50ff67057 | ||
|
|
76d5e95fbf | ||
|
|
b6db70e285 | ||
|
|
3819823d95 | ||
|
|
b2830e8473 | ||
|
|
43a43b429d | ||
|
|
1593f22691 | ||
|
|
4883402393 | ||
|
|
02eab1ff52 | ||
|
|
e0ca38bb35 | ||
|
|
8d46ae3aa2 | ||
|
|
7424caca8a | ||
|
|
9d9f10a799 | ||
|
|
a42d2b75dd | ||
|
|
8b09545cf6 | ||
|
|
eb77be09e2 | ||
|
|
ad01296c41 | ||
|
|
b553209712 | ||
|
|
c5098f0cd0 | ||
|
|
5ec1aac0d1 | ||
|
|
313ef42883 | ||
|
|
085c98668d | ||
|
|
6107d20e26 | ||
|
|
66edae4288 | ||
|
|
f69a7f647d | ||
|
|
e8bd55bed9 | ||
|
|
b23eda9c06 | ||
|
|
76503f3f2c | ||
|
|
9c3112f9bd | ||
|
|
462af30d16 | ||
|
|
fa6038eb38 | ||
|
|
f346b6cc5d | ||
|
|
f20b9ebb14 | ||
|
|
39bfe5b230 | ||
|
|
a1a3dd9ba2 | ||
|
|
7b1492f327 | ||
|
|
4e50819785 | ||
|
|
f8dccbec80 | ||
|
|
0c5c59cf00 | ||
|
|
868bb55f87 |
@@ -1,6 +1,6 @@
|
|||||||
.gitignore
|
.gitignore
|
||||||
.dockerignore
|
.dockerignore
|
||||||
newt
|
olm
|
||||||
*.json
|
*.json
|
||||||
README.md
|
README.md
|
||||||
Makefile
|
Makefile
|
||||||
|
|||||||
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Summary
|
||||||
|
description: A clear and concise summary of the requested feature.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Motivation
|
||||||
|
description: |
|
||||||
|
Why is this feature important?
|
||||||
|
Explain the problem this feature would solve or what use case it would enable.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Proposed Solution
|
||||||
|
description: |
|
||||||
|
How would you like to see this feature implemented?
|
||||||
|
Provide as much detail as possible about the desired behavior, configuration, or changes.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Alternatives Considered
|
||||||
|
description: Describe any alternative solutions or workarounds you've thought about.
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Additional Context
|
||||||
|
description: Add any other context, mockups, or screenshots about the feature request here.
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Before submitting, please:
|
||||||
|
- Check if there is an existing issue for this feature.
|
||||||
|
- Clearly explain the benefit and use case.
|
||||||
|
- Be as specific as possible to help contributors evaluate and implement.
|
||||||
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Create a bug report
|
||||||
|
labels: []
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Describe the Bug
|
||||||
|
description: A clear and concise description of what the bug is.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Environment
|
||||||
|
description: Please fill out the relevant details below for your environment.
|
||||||
|
value: |
|
||||||
|
- OS Type & Version: (e.g., Ubuntu 22.04)
|
||||||
|
- Pangolin Version:
|
||||||
|
- Gerbil Version:
|
||||||
|
- Traefik Version:
|
||||||
|
- Newt Version:
|
||||||
|
- Olm Version: (if applicable)
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: To Reproduce
|
||||||
|
description: |
|
||||||
|
Steps to reproduce the behavior, please provide a clear description of how to reproduce the issue, based on the linked minimal reproduction. Screenshots can be provided in the issue body below.
|
||||||
|
|
||||||
|
If using code blocks, make sure syntax highlighting is correct and double-check that the rendered preview is not broken.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Expected Behavior
|
||||||
|
description: A clear and concise description of what you expected to happen.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Before posting the issue go through the steps you've written down to make sure the steps provided are detailed and clear.
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Contributors should be able to follow the steps provided in order to reproduce the bug.
|
||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: false
|
||||||
|
contact_links:
|
||||||
|
- name: Need help or have questions?
|
||||||
|
url: https://github.com/orgs/fosrl/discussions
|
||||||
|
about: Ask questions, get help, and discuss with other community members
|
||||||
|
- name: Request a Feature
|
||||||
|
url: https://github.com/orgs/fosrl/discussions/new?category=feature-requests
|
||||||
|
about: Feature requests should be opened as discussions so others can upvote and comment
|
||||||
40
.github/dependabot.yml
vendored
Normal file
40
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: "gomod"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "daily"
|
||||||
|
groups:
|
||||||
|
dev-patch-updates:
|
||||||
|
dependency-type: "development"
|
||||||
|
update-types:
|
||||||
|
- "patch"
|
||||||
|
dev-minor-updates:
|
||||||
|
dependency-type: "development"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
prod-patch-updates:
|
||||||
|
dependency-type: "production"
|
||||||
|
update-types:
|
||||||
|
- "patch"
|
||||||
|
prod-minor-updates:
|
||||||
|
dependency-type: "production"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
|
||||||
|
- package-ecosystem: "docker"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "daily"
|
||||||
|
groups:
|
||||||
|
patch-updates:
|
||||||
|
update-types:
|
||||||
|
- "patch"
|
||||||
|
minor-updates:
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
|
||||||
|
- package-ecosystem: "github-actions"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
60
.github/workflows/cicd.yml
vendored
Normal file
60
.github/workflows/cicd.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
name: CI/CD Pipeline
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- "*"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release:
|
||||||
|
name: Build and Release
|
||||||
|
runs-on: amd64-runner
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Log in to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||||
|
|
||||||
|
- name: Extract tag name
|
||||||
|
id: get-tag
|
||||||
|
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version: 1.25
|
||||||
|
|
||||||
|
- name: Update version in main.go
|
||||||
|
run: |
|
||||||
|
TAG=${{ env.TAG }}
|
||||||
|
if [ -f main.go ]; then
|
||||||
|
sed -i 's/version_replaceme/'"$TAG"'/' main.go
|
||||||
|
echo "Updated main.go with version $TAG"
|
||||||
|
else
|
||||||
|
echo "main.go not found"
|
||||||
|
fi
|
||||||
|
- name: Build and push Docker images
|
||||||
|
run: |
|
||||||
|
TAG=${{ env.TAG }}
|
||||||
|
make docker-build-release tag=$TAG
|
||||||
|
|
||||||
|
- name: Build binaries
|
||||||
|
run: |
|
||||||
|
make go-build-release
|
||||||
|
|
||||||
|
- name: Upload artifacts from /bin
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: binaries
|
||||||
|
path: bin/
|
||||||
132
.github/workflows/mirror.yaml
vendored
Normal file
132
.github/workflows/mirror.yaml
vendored
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
name: Mirror & Sign (Docker Hub to GHCR)
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: {}
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
id-token: write # for keyless OIDC
|
||||||
|
|
||||||
|
env:
|
||||||
|
SOURCE_IMAGE: docker.io/fosrl/olm
|
||||||
|
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
mirror-and-dual-sign:
|
||||||
|
runs-on: amd64-runner
|
||||||
|
steps:
|
||||||
|
- name: Install skopeo + jq
|
||||||
|
run: |
|
||||||
|
sudo apt-get update -y
|
||||||
|
sudo apt-get install -y skopeo jq
|
||||||
|
skopeo --version
|
||||||
|
|
||||||
|
- name: Install cosign
|
||||||
|
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||||
|
|
||||||
|
- name: Input check
|
||||||
|
run: |
|
||||||
|
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
|
||||||
|
echo "Source : ${SOURCE_IMAGE}"
|
||||||
|
echo "Target : ${DEST_IMAGE}"
|
||||||
|
|
||||||
|
# Auth for skopeo (containers-auth)
|
||||||
|
- name: Skopeo login to GHCR
|
||||||
|
run: |
|
||||||
|
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
|
||||||
|
|
||||||
|
# Auth for cosign (docker-config)
|
||||||
|
- name: Docker login to GHCR (for cosign)
|
||||||
|
run: |
|
||||||
|
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
|
||||||
|
|
||||||
|
- name: List source tags
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
|
||||||
|
| jq -r '.Tags[]' | sort -u > src-tags.txt
|
||||||
|
echo "Found source tags: $(wc -l < src-tags.txt)"
|
||||||
|
head -n 20 src-tags.txt || true
|
||||||
|
|
||||||
|
- name: List destination tags (skip existing)
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
|
||||||
|
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
|
||||||
|
else
|
||||||
|
: > dst-tags.txt
|
||||||
|
fi
|
||||||
|
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
|
||||||
|
|
||||||
|
- name: Mirror, dual-sign, and verify
|
||||||
|
env:
|
||||||
|
# keyless
|
||||||
|
COSIGN_YES: "true"
|
||||||
|
# key-based
|
||||||
|
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||||
|
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||||
|
# verify
|
||||||
|
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
copied=0; skipped=0; v_ok=0; errs=0
|
||||||
|
|
||||||
|
issuer="https://token.actions.githubusercontent.com"
|
||||||
|
id_regex="^https://github.com/${{ github.repository }}/.+"
|
||||||
|
|
||||||
|
while read -r tag; do
|
||||||
|
[ -z "$tag" ] && continue
|
||||||
|
|
||||||
|
if grep -Fxq "$tag" dst-tags.txt; then
|
||||||
|
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
|
||||||
|
skipped=$((skipped+1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
|
||||||
|
if ! skopeo copy --all --retry-times 3 \
|
||||||
|
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
|
||||||
|
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
|
||||||
|
errs=$((errs+1)); continue
|
||||||
|
fi
|
||||||
|
copied=$((copied+1))
|
||||||
|
|
||||||
|
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
|
||||||
|
ref="${DEST_IMAGE}@${digest}"
|
||||||
|
|
||||||
|
echo "==> cosign sign (keyless) --recursive ${ref}"
|
||||||
|
if ! cosign sign --recursive "${ref}"; then
|
||||||
|
echo "::warning title=Keyless sign failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> cosign sign (key) --recursive ${ref}"
|
||||||
|
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
|
||||||
|
echo "::warning title=Key sign failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> cosign verify (public key) ${ref}"
|
||||||
|
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
|
||||||
|
echo "::warning title=Verify(pubkey) failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> cosign verify (keyless policy) ${ref}"
|
||||||
|
if ! cosign verify \
|
||||||
|
--certificate-oidc-issuer "${issuer}" \
|
||||||
|
--certificate-identity-regexp "${id_regex}" \
|
||||||
|
"${ref}" -o text; then
|
||||||
|
echo "::warning title=Verify(keyless) failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
else
|
||||||
|
v_ok=$((v_ok+1))
|
||||||
|
fi
|
||||||
|
done < src-tags.txt
|
||||||
|
|
||||||
|
echo "---- Summary ----"
|
||||||
|
echo "Copied : $copied"
|
||||||
|
echo "Skipped : $skipped"
|
||||||
|
echo "Verified OK : $v_ok"
|
||||||
|
echo "Errors : $errs"
|
||||||
28
.github/workflows/test.yml
vendored
Normal file
28
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
name: Run Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- dev
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version: 1.25
|
||||||
|
|
||||||
|
- name: Build go
|
||||||
|
run: go build
|
||||||
|
|
||||||
|
- name: Build Docker image
|
||||||
|
run: make build
|
||||||
|
|
||||||
|
- name: Build binaries
|
||||||
|
run: make go-build-release
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,2 @@
|
|||||||
newt
|
|
||||||
.DS_Store
|
.DS_Store
|
||||||
bin/
|
bin/
|
||||||
1
.go-version
Normal file
1
.go-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
1.25
|
||||||
@@ -4,11 +4,7 @@ Contributions are welcome!
|
|||||||
|
|
||||||
Please see the contribution and local development guide on the docs page before getting started:
|
Please see the contribution and local development guide on the docs page before getting started:
|
||||||
|
|
||||||
https://docs.fossorial.io/development
|
https://docs.pangolin.net/development/contributing
|
||||||
|
|
||||||
For ideas about what features to work on and our future plans, please see the roadmap:
|
|
||||||
|
|
||||||
https://docs.fossorial.io/roadmap
|
|
||||||
|
|
||||||
### Licensing Considerations
|
### Licensing Considerations
|
||||||
|
|
||||||
|
|||||||
12
Dockerfile
12
Dockerfile
@@ -1,4 +1,4 @@
|
|||||||
FROM golang:1.23.1-alpine AS builder
|
FROM golang:1.25-alpine AS builder
|
||||||
|
|
||||||
# Set the working directory inside the container
|
# Set the working directory inside the container
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
@@ -13,15 +13,15 @@ RUN go mod download
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Build the application
|
# Build the application
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /newt
|
RUN CGO_ENABLED=0 GOOS=linux go build -o /olm
|
||||||
|
|
||||||
# Start a new stage from scratch
|
# Start a new stage from scratch
|
||||||
FROM ubuntu:22.04 AS runner
|
FROM alpine:3.22 AS runner
|
||||||
|
|
||||||
RUN apt-get update && apt-get install ca-certificates -y && rm -rf /var/lib/apt/lists/*
|
RUN apk --no-cache add ca-certificates
|
||||||
|
|
||||||
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
||||||
COPY --from=builder /newt /usr/local/bin/
|
COPY --from=builder /olm /usr/local/bin/
|
||||||
COPY entrypoint.sh /
|
COPY entrypoint.sh /
|
||||||
|
|
||||||
RUN chmod +x /entrypoint.sh
|
RUN chmod +x /entrypoint.sh
|
||||||
@@ -30,4 +30,4 @@ RUN chmod +x /entrypoint.sh
|
|||||||
ENTRYPOINT ["/entrypoint.sh"]
|
ENTRYPOINT ["/entrypoint.sh"]
|
||||||
|
|
||||||
# Command to run the executable
|
# Command to run the executable
|
||||||
CMD ["newt"]
|
CMD ["olm"]
|
||||||
38
Makefile
38
Makefile
@@ -1,26 +1,26 @@
|
|||||||
|
|
||||||
all: build push
|
all: go-build-release
|
||||||
|
|
||||||
build:
|
docker-build-release:
|
||||||
docker build -t fosrl/newt:latest .
|
@if [ -z "$(tag)" ]; then \
|
||||||
|
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
|
||||||
push:
|
exit 1; \
|
||||||
docker push fosrl/newt:latest
|
fi
|
||||||
|
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:latest -f Dockerfile --push .
|
||||||
test:
|
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push .
|
||||||
docker run fosrl/newt:latest
|
|
||||||
|
|
||||||
local:
|
local:
|
||||||
CGO_ENABLED=0 go build -o newt
|
CGO_ENABLED=0 go build -o bin/olm
|
||||||
|
|
||||||
release:
|
build:
|
||||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/newt_linux_arm64
|
docker build -t fosrl/olm:latest .
|
||||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64
|
|
||||||
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64
|
go-build-release:
|
||||||
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64
|
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/olm_linux_arm64
|
||||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o newt_windows_amd64.bin/exe
|
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/olm_linux_amd64
|
||||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64
|
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/olm_darwin_arm64
|
||||||
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64
|
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/olm_darwin_amd64
|
||||||
|
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/olm_windows_amd64.exe
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm newt
|
rm olm
|
||||||
301
README.md
301
README.md
@@ -1,46 +1,59 @@
|
|||||||
# Newt
|
# Olm
|
||||||
|
|
||||||
Newt is a fully user space [WireGuard](https://www.wireguard.com/) tunnel client and TCP/UDP proxy, designed to securely expose private resources controlled by Pangolin. By using Newt, you don't need to manage complex WireGuard tunnels and NATing.
|
Olm is a [WireGuard](https://www.wireguard.com/) tunnel client designed to securely connect your computer to Newt sites running on remote networks.
|
||||||
|
|
||||||
### Installation and Documentation
|
### Installation and Documentation
|
||||||
|
|
||||||
Newt is used with Pangolin and Gerbil as part of the larger system. See documentation below:
|
Olm is used with Pangolin and Newt as part of the larger system. See documentation below:
|
||||||
|
|
||||||
- [Installation Instructions](https://docs.fossorial.io)
|
- [Full Documentation](https://docs.pangolin.net)
|
||||||
- [Full Documentation](https://docs.fossorial.io)
|
|
||||||
|
|
||||||
## Preview
|
|
||||||
|
|
||||||
<img src="public/screenshots/preview.png" alt="Preview"/>
|
|
||||||
|
|
||||||
_Sample output of a Newt container connected to Pangolin and hosting various resource target proxies._
|
|
||||||
|
|
||||||
## Key Functions
|
## Key Functions
|
||||||
|
|
||||||
### Registers with Pangolin
|
### Registers with Pangolin
|
||||||
|
|
||||||
Using the Newt ID and a secret, the client will make HTTP requests to Pangolin to receive a session token. Using that token, it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket.
|
Using the Olm ID and a secret, the olm will make HTTP requests to Pangolin to receive a session token. Using that token, it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket.
|
||||||
|
|
||||||
### Receives WireGuard Control Messages
|
### Receives WireGuard Control Messages
|
||||||
|
|
||||||
When Newt receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel using [netstack](https://github.com/WireGuard/wireguard-go/blob/master/tun/netstack/examples/http_server.go) fully in user space. It will ping over the tunnel to ensure the peer on the Gerbil side is brought up.
|
When Olm receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel on your computer to a remote Newt. It will ping over the tunnel to ensure the peer is brought up.
|
||||||
|
|
||||||
### Receives Proxy Control Messages
|
|
||||||
|
|
||||||
When Newt receives WireGuard control messages, it will use the information encoded to create a local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets.
|
|
||||||
|
|
||||||
## CLI Args
|
## CLI Args
|
||||||
|
|
||||||
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
|
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
|
||||||
- `id`: Newt ID generated by Pangolin to identify the client.
|
- `id`: Olm ID generated by Pangolin to identify the olm.
|
||||||
- `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands.
|
- `secret`: A unique secret (not shared and kept private) used to authenticate the olm ID with the websocket in order to receive commands.
|
||||||
- `dns`: DNS server to use to resolve the endpoint
|
- `mtu` (optional): MTU for the internal WG interface. Default: 1280
|
||||||
- `log-level` (optional): The log level to use. Default: INFO
|
- `dns` (optional): DNS server to use to resolve the endpoint. Default: 8.8.8.8
|
||||||
|
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO
|
||||||
|
- `ping-interval` (optional): Interval for pinging the server. Default: 3s
|
||||||
|
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
|
||||||
|
- `interface` (optional): Name of the WireGuard interface. Default: olm
|
||||||
|
- `enable-http` (optional): Enable HTTP server for receiving connection requests. Default: false
|
||||||
|
- `http-addr` (optional): HTTP server address (e.g., ':9452'). Default: :9452
|
||||||
|
- `holepunch` (optional): Enable hole punching. Default: false
|
||||||
|
|
||||||
Example:
|
## Environment Variables
|
||||||
|
|
||||||
|
All CLI arguments can also be set via environment variables:
|
||||||
|
|
||||||
|
- `PANGOLIN_ENDPOINT`: Equivalent to `--endpoint`
|
||||||
|
- `OLM_ID`: Equivalent to `--id`
|
||||||
|
- `OLM_SECRET`: Equivalent to `--secret`
|
||||||
|
- `MTU`: Equivalent to `--mtu`
|
||||||
|
- `DNS`: Equivalent to `--dns`
|
||||||
|
- `LOG_LEVEL`: Equivalent to `--log-level`
|
||||||
|
- `INTERFACE`: Equivalent to `--interface`
|
||||||
|
- `HTTP_ADDR`: Equivalent to `--http-addr`
|
||||||
|
- `PING_INTERVAL`: Equivalent to `--ping-interval`
|
||||||
|
- `PING_TIMEOUT`: Equivalent to `--ping-timeout`
|
||||||
|
- `HOLEPUNCH`: Set to "true" to enable hole punching (equivalent to `--holepunch`)
|
||||||
|
- `CONFIG_FILE`: Set to the location of a JSON file to load secret values
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./newt \
|
olm \
|
||||||
--id 31frd0uzbjvp721 \
|
--id 31frd0uzbjvp721 \
|
||||||
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||||
--endpoint https://example.com
|
--endpoint https://example.com
|
||||||
@@ -50,40 +63,230 @@ You can also run it with Docker compose. For example, a service in your `docker-
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
services:
|
services:
|
||||||
newt:
|
olm:
|
||||||
image: fosrl/newt
|
image: fosrl/olm
|
||||||
container_name: newt
|
container_name: olm
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
environment:
|
network_mode: host
|
||||||
- PANGOLIN_ENDPOINT=https://example.com
|
devices:
|
||||||
- NEWT_ID=2ix2t8xk22ubpfy
|
- /dev/net/tun:/dev/net/tun
|
||||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
environment:
|
||||||
|
- PANGOLIN_ENDPOINT=https://example.com
|
||||||
|
- OLM_ID=31frd0uzbjvp721
|
||||||
|
- OLM_SECRET=h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also pass the CLI args to the container:
|
You can also pass the CLI args to the container:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
services:
|
services:
|
||||||
newt:
|
olm:
|
||||||
image: fosrl/newt
|
image: fosrl/olm
|
||||||
container_name: newt
|
container_name: olm
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
command:
|
network_mode: host
|
||||||
- --id 31frd0uzbjvp721
|
devices:
|
||||||
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
|
- /dev/net/tun:/dev/net/tun
|
||||||
- --endpoint https://example.com
|
command:
|
||||||
|
- --id 31frd0uzbjvp721
|
||||||
|
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
|
||||||
|
- --endpoint https://example.com
|
||||||
|
```
|
||||||
|
|
||||||
|
**Docker Configuration Notes:**
|
||||||
|
|
||||||
|
- `network_mode: host` brings the olm network interface to the host system, allowing the WireGuard tunnel to function properly
|
||||||
|
- `devices: - /dev/net/tun:/dev/net/tun` is required to give the container access to the TUN device for creating WireGuard interfaces
|
||||||
|
|
||||||
|
## Loading secrets from files
|
||||||
|
|
||||||
|
You can use `CONFIG_FILE` to define a location of a config file to store the credentials between runs.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cat ~/.config/olm-client/config.json
|
||||||
|
{
|
||||||
|
"id": "spmzu8rbpzj1qq6",
|
||||||
|
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
|
||||||
|
"endpoint": "https://app.pangolin.net",
|
||||||
|
"tlsClientCert": ""
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This file is also written to when newt first starts up. So you do not need to run every time with --id and secret if you have run it once!
|
||||||
|
|
||||||
|
Default locations:
|
||||||
|
|
||||||
|
- **macOS**: `~/Library/Application Support/olm-client/config.json`
|
||||||
|
- **Windows**: `%PROGRAMDATA%\olm\olm-client\config.json`
|
||||||
|
- **Linux/Others**: `~/.config/olm-client/config.json`
|
||||||
|
|
||||||
|
## Hole Punching
|
||||||
|
|
||||||
|
In the default mode, olm "relays" traffic through Gerbil in the cloud to get down to newt. This is a little more reliable. Support for NAT hole punching is also EXPERIMENTAL right now using the `--holepunch` flag. This will attempt to orchestrate a NAT hole punch between the two sites so that traffic flows directly. This will save data costs and speed. If it fails it should fall back to relaying.
|
||||||
|
|
||||||
|
Right now, basic NAT hole punching is supported. We plan to add:
|
||||||
|
|
||||||
|
- [ ] Birthday paradox
|
||||||
|
- [ ] UPnP
|
||||||
|
- [ ] LAN detection
|
||||||
|
|
||||||
|
## Windows Service
|
||||||
|
|
||||||
|
On Windows, olm has to be installed and run as a Windows service. When running it with the cli args live above it will attempt to install and run the service to function like a cli tool. You can also run the following:
|
||||||
|
|
||||||
|
### Service Management Commands
|
||||||
|
|
||||||
|
```
|
||||||
|
# Install the service
|
||||||
|
olm.exe install
|
||||||
|
|
||||||
|
# Start the service
|
||||||
|
olm.exe start
|
||||||
|
|
||||||
|
# Stop the service
|
||||||
|
olm.exe stop
|
||||||
|
|
||||||
|
# Check service status
|
||||||
|
olm.exe status
|
||||||
|
|
||||||
|
# Remove the service
|
||||||
|
olm.exe remove
|
||||||
|
|
||||||
|
# Run in debug mode (console output) with our without id & secret
|
||||||
|
olm.exe debug
|
||||||
|
|
||||||
|
# Show help
|
||||||
|
olm.exe help
|
||||||
|
```
|
||||||
|
|
||||||
|
Note running the service requires credentials in `%PROGRAMDATA%\olm\olm-client\config.json`.
|
||||||
|
|
||||||
|
### Service Configuration
|
||||||
|
|
||||||
|
When running as a service, Olm will read configuration from environment variables or you can modify the service to include command-line arguments:
|
||||||
|
|
||||||
|
1. Install the service: `olm.exe install`
|
||||||
|
2. Set the credentials in `%PROGRAMDATA%\olm\olm-client\config.json`. Hint: if you run olm once with --id and --secret this file will be populated!
|
||||||
|
3. Start the service: `olm.exe start`
|
||||||
|
|
||||||
|
### Service Logs
|
||||||
|
|
||||||
|
When running as a service, logs are written to:
|
||||||
|
|
||||||
|
- Windows Event Log (Application log, source: "OlmWireguardService")
|
||||||
|
- Log files in: `%PROGRAMDATA%\olm\logs\olm.log`
|
||||||
|
|
||||||
|
You can view the Windows Event Log using Event Viewer or PowerShell:
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
Get-EventLog -LogName Application -Source "OlmWireguardService" -Newest 10
|
||||||
|
```
|
||||||
|
|
||||||
|
## HTTP Endpoints
|
||||||
|
|
||||||
|
Olm can be controlled with an embedded http server when using `--enable-http`. This allows you to start it as a daemon and trigger it with the following endpoints:
|
||||||
|
|
||||||
|
### POST /connect
|
||||||
|
Initiates a new connection request.
|
||||||
|
|
||||||
|
**Request Body:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "string",
|
||||||
|
"secret": "string",
|
||||||
|
"endpoint": "string"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Required Fields:**
|
||||||
|
- `id`: Connection identifier
|
||||||
|
- `secret`: Authentication secret
|
||||||
|
- `endpoint`: Target endpoint URL
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
- **Status Code:** `202 Accepted`
|
||||||
|
- **Content-Type:** `application/json`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "connection request accepted"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Error Responses:**
|
||||||
|
- `405 Method Not Allowed` - Non-POST requests
|
||||||
|
- `400 Bad Request` - Invalid JSON or missing required fields
|
||||||
|
|
||||||
|
### GET /status
|
||||||
|
Returns the current connection status and peer information.
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
- **Status Code:** `200 OK`
|
||||||
|
- **Content-Type:** `application/json`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "connected",
|
||||||
|
"connected": true,
|
||||||
|
"tunnelIP": "100.89.128.3/20",
|
||||||
|
"version": "version_replaceme",
|
||||||
|
"peers": {
|
||||||
|
"10": {
|
||||||
|
"siteId": 10,
|
||||||
|
"connected": true,
|
||||||
|
"rtt": 145338339,
|
||||||
|
"lastSeen": "2025-08-13T14:39:17.208334428-07:00",
|
||||||
|
"endpoint": "p.fosrl.io:21820",
|
||||||
|
"isRelay": true
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"siteId": 8,
|
||||||
|
"connected": false,
|
||||||
|
"rtt": 0,
|
||||||
|
"lastSeen": "2025-08-13T14:39:19.663823645-07:00",
|
||||||
|
"endpoint": "p.fosrl.io:21820",
|
||||||
|
"isRelay": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Fields:**
|
||||||
|
- `status`: Overall connection status ("connected" or "disconnected")
|
||||||
|
- `connected`: Boolean connection state
|
||||||
|
- `tunnelIP`: IP address and subnet of the tunnel (when connected)
|
||||||
|
- `version`: Olm version string
|
||||||
|
- `peers`: Map of peer statuses by site ID
|
||||||
|
- `siteId`: Peer site identifier
|
||||||
|
- `connected`: Boolean peer connection state
|
||||||
|
- `rtt`: Peer round-trip time (integer, nanoseconds)
|
||||||
|
- `lastSeen`: Last time peer was seen (RFC3339 timestamp)
|
||||||
|
- `endpoint`: Peer endpoint address
|
||||||
|
- `isRelay`: Whether the peer is relayed (true) or direct (false)
|
||||||
|
|
||||||
|
**Error Responses:**
|
||||||
|
- `405 Method Not Allowed` - Non-GET requests
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### Connect to a peer
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/connect \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"id": "31frd0uzbjvp721",
|
||||||
|
"secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
|
||||||
|
"endpoint": "https://example.com"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Check connection status
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/status
|
||||||
```
|
```
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
### Container
|
|
||||||
|
|
||||||
Ensure Docker is installed.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make
|
|
||||||
```
|
|
||||||
|
|
||||||
### Binary
|
### Binary
|
||||||
|
|
||||||
Make sure to have Go 1.23.1 installed.
|
Make sure to have Go 1.23.1 installed.
|
||||||
@@ -94,7 +297,7 @@ make local
|
|||||||
|
|
||||||
## Licensing
|
## Licensing
|
||||||
|
|
||||||
Newt is dual licensed under the AGPLv3 and the Fossorial Commercial license. For inquiries about commercial licensing, please contact us.
|
Olm is dual licensed under the AGPLv3 and the Fossorial Commercial license. For inquiries about commercial licensing, please contact us.
|
||||||
|
|
||||||
## Contributions
|
## Contributions
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||||
|
|
||||||
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||||
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||||
|
|
||||||
- Description and location of the vulnerability.
|
- Description and location of the vulnerability.
|
||||||
- Potential impact of the vulnerability.
|
- Potential impact of the vulnerability.
|
||||||
|
|||||||
411
api/api.go
Normal file
411
api/api.go
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionRequest defines the structure for an incoming connection request
|
||||||
|
type ConnectionRequest struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
UserToken string `json:"userToken,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwitchOrgRequest defines the structure for switching organizations
|
||||||
|
type SwitchOrgRequest struct {
|
||||||
|
OrgID string `json:"orgId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerStatus represents the status of a peer connection
|
||||||
|
type PeerStatus struct {
|
||||||
|
SiteID int `json:"siteId"`
|
||||||
|
Connected bool `json:"connected"`
|
||||||
|
RTT time.Duration `json:"rtt"`
|
||||||
|
LastSeen time.Time `json:"lastSeen"`
|
||||||
|
Endpoint string `json:"endpoint,omitempty"`
|
||||||
|
IsRelay bool `json:"isRelay"`
|
||||||
|
PeerIP string `json:"peerAddress,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusResponse is returned by the status endpoint
|
||||||
|
type StatusResponse struct {
|
||||||
|
Connected bool `json:"connected"`
|
||||||
|
Registered bool `json:"registered"`
|
||||||
|
TunnelIP string `json:"tunnelIP,omitempty"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
OrgID string `json:"orgId,omitempty"`
|
||||||
|
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// API represents the HTTP server and its state
|
||||||
|
type API struct {
|
||||||
|
addr string
|
||||||
|
socketPath string
|
||||||
|
listener net.Listener
|
||||||
|
server *http.Server
|
||||||
|
connectionChan chan ConnectionRequest
|
||||||
|
switchOrgChan chan SwitchOrgRequest
|
||||||
|
shutdownChan chan struct{}
|
||||||
|
disconnectChan chan struct{}
|
||||||
|
statusMu sync.RWMutex
|
||||||
|
peerStatuses map[int]*PeerStatus
|
||||||
|
connectedAt time.Time
|
||||||
|
isConnected bool
|
||||||
|
isRegistered bool
|
||||||
|
tunnelIP string
|
||||||
|
version string
|
||||||
|
orgID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||||
|
func NewAPI(addr string) *API {
|
||||||
|
s := &API{
|
||||||
|
addr: addr,
|
||||||
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
|
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
disconnectChan: make(chan struct{}, 1),
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
|
||||||
|
func NewAPISocket(socketPath string) *API {
|
||||||
|
s := &API{
|
||||||
|
socketPath: socketPath,
|
||||||
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
|
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
disconnectChan: make(chan struct{}, 1),
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the HTTP server
|
||||||
|
func (s *API) Start() error {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/connect", s.handleConnect)
|
||||||
|
mux.HandleFunc("/status", s.handleStatus)
|
||||||
|
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||||
|
mux.HandleFunc("/disconnect", s.handleDisconnect)
|
||||||
|
mux.HandleFunc("/exit", s.handleExit)
|
||||||
|
|
||||||
|
s.server = &http.Server{
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if s.socketPath != "" {
|
||||||
|
// Use platform-specific socket listener
|
||||||
|
s.listener, err = createSocketListener(s.socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create socket listener: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("Starting HTTP server on socket %s", s.socketPath)
|
||||||
|
} else {
|
||||||
|
// Use TCP listener
|
||||||
|
s.listener, err = net.Listen("tcp", s.addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create TCP listener: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("Starting HTTP server on %s", s.addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
|
||||||
|
logger.Error("HTTP server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the HTTP server
|
||||||
|
func (s *API) Stop() error {
|
||||||
|
logger.Info("Stopping api server")
|
||||||
|
|
||||||
|
// Close the server first, which will also close the listener gracefully
|
||||||
|
if s.server != nil {
|
||||||
|
s.server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up socket file if using Unix socket
|
||||||
|
if s.socketPath != "" {
|
||||||
|
cleanupSocket(s.socketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnectionChannel returns the channel for receiving connection requests
|
||||||
|
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
|
||||||
|
return s.connectionChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
||||||
|
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
||||||
|
return s.switchOrgChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetShutdownChannel returns the channel for receiving shutdown requests
|
||||||
|
func (s *API) GetShutdownChannel() <-chan struct{} {
|
||||||
|
return s.shutdownChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisconnectChannel returns the channel for receiving disconnect requests
|
||||||
|
func (s *API) GetDisconnectChannel() <-chan struct{} {
|
||||||
|
return s.disconnectChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||||
|
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
|
||||||
|
status, exists := s.peerStatuses[siteID]
|
||||||
|
if !exists {
|
||||||
|
status = &PeerStatus{
|
||||||
|
SiteID: siteID,
|
||||||
|
}
|
||||||
|
s.peerStatuses[siteID] = status
|
||||||
|
}
|
||||||
|
|
||||||
|
status.Connected = connected
|
||||||
|
status.RTT = rtt
|
||||||
|
status.LastSeen = time.Now()
|
||||||
|
status.Endpoint = endpoint
|
||||||
|
status.IsRelay = isRelay
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConnectionStatus sets the overall connection status
|
||||||
|
func (s *API) SetConnectionStatus(isConnected bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
|
||||||
|
s.isConnected = isConnected
|
||||||
|
|
||||||
|
if isConnected {
|
||||||
|
s.connectedAt = time.Now()
|
||||||
|
} else {
|
||||||
|
// Clear peer statuses when disconnected
|
||||||
|
s.peerStatuses = make(map[int]*PeerStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *API) SetRegistered(registered bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.isRegistered = registered
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTunnelIP sets the tunnel IP address
|
||||||
|
func (s *API) SetTunnelIP(tunnelIP string) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.tunnelIP = tunnelIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVersion sets the olm version
|
||||||
|
func (s *API) SetVersion(version string) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.version = version
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOrgID sets the organization ID
|
||||||
|
func (s *API) SetOrgID(orgID string) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.orgID = orgID
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||||
|
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
|
||||||
|
status, exists := s.peerStatuses[siteID]
|
||||||
|
if !exists {
|
||||||
|
status = &PeerStatus{
|
||||||
|
SiteID: siteID,
|
||||||
|
}
|
||||||
|
s.peerStatuses[siteID] = status
|
||||||
|
}
|
||||||
|
|
||||||
|
status.Endpoint = endpoint
|
||||||
|
status.IsRelay = isRelay
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConnect handles the /connect endpoint
|
||||||
|
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we are already connected, reject new connection requests
|
||||||
|
s.statusMu.RLock()
|
||||||
|
alreadyConnected := s.isConnected
|
||||||
|
s.statusMu.RUnlock()
|
||||||
|
if alreadyConnected {
|
||||||
|
http.Error(w, "Already connected to a server. Disconnect first before connecting again.", http.StatusConflict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ConnectionRequest
|
||||||
|
decoder := json.NewDecoder(r.Body)
|
||||||
|
if err := decoder.Decode(&req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required fields
|
||||||
|
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
|
||||||
|
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the request to the main goroutine
|
||||||
|
s.connectionChan <- req
|
||||||
|
|
||||||
|
// Return a success response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "connection request accepted",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStatus handles the /status endpoint
|
||||||
|
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.statusMu.RLock()
|
||||||
|
defer s.statusMu.RUnlock()
|
||||||
|
|
||||||
|
resp := StatusResponse{
|
||||||
|
Connected: s.isConnected,
|
||||||
|
Registered: s.isRegistered,
|
||||||
|
TunnelIP: s.tunnelIP,
|
||||||
|
Version: s.version,
|
||||||
|
OrgID: s.orgID,
|
||||||
|
PeerStatuses: s.peerStatuses,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleExit handles the /exit endpoint
|
||||||
|
func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Received exit request via API")
|
||||||
|
|
||||||
|
// Send shutdown signal
|
||||||
|
select {
|
||||||
|
case s.shutdownChan <- struct{}{}:
|
||||||
|
// Signal sent successfully
|
||||||
|
default:
|
||||||
|
// Channel already has a signal, don't block
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a success response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "shutdown initiated",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSwitchOrg handles the /switch-org endpoint
|
||||||
|
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req SwitchOrgRequest
|
||||||
|
decoder := json.NewDecoder(r.Body)
|
||||||
|
if err := decoder.Decode(&req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required fields
|
||||||
|
if req.OrgID == "" {
|
||||||
|
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||||
|
|
||||||
|
// Send the request to the main goroutine
|
||||||
|
select {
|
||||||
|
case s.switchOrgChan <- req:
|
||||||
|
// Signal sent successfully
|
||||||
|
default:
|
||||||
|
// Channel already has a pending request
|
||||||
|
http.Error(w, "Org switch already in progress", http.StatusConflict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a success response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "org switch request accepted",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDisconnect handles the /disconnect endpoint
|
||||||
|
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we are already disconnected, reject new disconnect requests
|
||||||
|
s.statusMu.RLock()
|
||||||
|
alreadyDisconnected := !s.isConnected
|
||||||
|
s.statusMu.RUnlock()
|
||||||
|
if alreadyDisconnected {
|
||||||
|
http.Error(w, "Not currently connected to a server.", http.StatusConflict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Received disconnect request via API")
|
||||||
|
|
||||||
|
// Send disconnect signal
|
||||||
|
select {
|
||||||
|
case s.disconnectChan <- struct{}{}:
|
||||||
|
// Signal sent successfully
|
||||||
|
default:
|
||||||
|
// Channel already has a signal, don't block
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a success response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "disconnect initiated",
|
||||||
|
})
|
||||||
|
}
|
||||||
50
api/api_unix.go
Normal file
50
api/api_unix.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSocketListener creates a Unix domain socket listener
|
||||||
|
func createSocketListener(socketPath string) (net.Listener, error) {
|
||||||
|
// Ensure the directory exists
|
||||||
|
dir := filepath.Dir(socketPath)
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create socket directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove existing socket file if it exists
|
||||||
|
if err := os.RemoveAll(socketPath); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set socket permissions to allow access
|
||||||
|
if err := os.Chmod(socketPath, 0666); err != nil {
|
||||||
|
listener.Close()
|
||||||
|
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Created Unix socket at %s", socketPath)
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupSocket removes the Unix socket file
|
||||||
|
func cleanupSocket(socketPath string) {
|
||||||
|
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Removed Unix socket at %s", socketPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
41
api/api_windows.go
Normal file
41
api/api_windows.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio"
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSocketListener creates a Windows named pipe listener
|
||||||
|
func createSocketListener(pipePath string) (net.Listener, error) {
|
||||||
|
// Ensure the pipe path has the correct format
|
||||||
|
if pipePath[0] != '\\' {
|
||||||
|
pipePath = `\\.\pipe\` + pipePath
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a pipe configuration that allows everyone to write
|
||||||
|
config := &winio.PipeConfig{
|
||||||
|
// Set security descriptor to allow everyone full access
|
||||||
|
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
|
||||||
|
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a named pipe listener using go-winio with the configuration
|
||||||
|
listener, err := winio.ListenPipe(pipePath, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
|
||||||
|
func cleanupSocket(pipePath string) {
|
||||||
|
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
|
||||||
|
}
|
||||||
378
bind/shared_bind.go
Normal file
378
bind/shared_bind.go
Normal file
@@ -0,0 +1,378 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Endpoint represents a network endpoint for the SharedBind
|
||||||
|
type Endpoint struct {
|
||||||
|
AddrPort netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSrc implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) ClearSrc() {}
|
||||||
|
|
||||||
|
// DstIP implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) DstIP() netip.Addr {
|
||||||
|
return e.AddrPort.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SrcIP implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) SrcIP() netip.Addr {
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DstToBytes implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) DstToBytes() []byte {
|
||||||
|
b, _ := e.AddrPort.MarshalBinary()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// DstToString implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) DstToString() string {
|
||||||
|
return e.AddrPort.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SrcToString implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
|
||||||
|
// and hole punch senders. It wraps a single UDP connection and implements
|
||||||
|
// reference counting to prevent premature closure.
|
||||||
|
type SharedBind struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// The underlying UDP connection
|
||||||
|
udpConn *net.UDPConn
|
||||||
|
|
||||||
|
// IPv4 and IPv6 packet connections for advanced features
|
||||||
|
ipv4PC *ipv4.PacketConn
|
||||||
|
ipv6PC *ipv6.PacketConn
|
||||||
|
|
||||||
|
// Reference counting to prevent closing while in use
|
||||||
|
refCount atomic.Int32
|
||||||
|
closed atomic.Bool
|
||||||
|
|
||||||
|
// Channels for receiving data
|
||||||
|
recvFuncs []wgConn.ReceiveFunc
|
||||||
|
|
||||||
|
// Port binding information
|
||||||
|
port uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new SharedBind from an existing UDP connection.
|
||||||
|
// The SharedBind takes ownership of the connection and will close it
|
||||||
|
// when all references are released.
|
||||||
|
func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
||||||
|
if udpConn == nil {
|
||||||
|
return nil, fmt.Errorf("udpConn cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
bind := &SharedBind{
|
||||||
|
udpConn: udpConn,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize reference count to 1 (the creator holds the first reference)
|
||||||
|
bind.refCount.Store(1)
|
||||||
|
|
||||||
|
// Get the local port
|
||||||
|
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
|
||||||
|
bind.port = uint16(addr.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bind, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRef increments the reference count. Call this when sharing
|
||||||
|
// the bind with another component.
|
||||||
|
func (b *SharedBind) AddRef() {
|
||||||
|
newCount := b.refCount.Add(1)
|
||||||
|
// Optional: Add logging for debugging
|
||||||
|
_ = newCount // Placeholder for potential logging
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release decrements the reference count. When it reaches zero,
|
||||||
|
// the underlying UDP connection is closed.
|
||||||
|
func (b *SharedBind) Release() error {
|
||||||
|
newCount := b.refCount.Add(-1)
|
||||||
|
// Optional: Add logging for debugging
|
||||||
|
_ = newCount // Placeholder for potential logging
|
||||||
|
|
||||||
|
if newCount < 0 {
|
||||||
|
// This should never happen with proper usage
|
||||||
|
b.refCount.Store(0)
|
||||||
|
return fmt.Errorf("SharedBind reference count went negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
if newCount == 0 {
|
||||||
|
return b.closeConnection()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeConnection actually closes the UDP connection
|
||||||
|
func (b *SharedBind) closeConnection() error {
|
||||||
|
if !b.closed.CompareAndSwap(false, true) {
|
||||||
|
// Already closed
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if b.udpConn != nil {
|
||||||
|
err = b.udpConn.Close()
|
||||||
|
b.udpConn = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ipv4PC = nil
|
||||||
|
b.ipv6PC = nil
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUDPConn returns the underlying UDP connection.
|
||||||
|
// The caller must not close this connection directly.
|
||||||
|
func (b *SharedBind) GetUDPConn() *net.UDPConn {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
return b.udpConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRefCount returns the current reference count (for debugging)
|
||||||
|
func (b *SharedBind) GetRefCount() int32 {
|
||||||
|
return b.refCount.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed returns whether the bind is closed
|
||||||
|
func (b *SharedBind) IsClosed() bool {
|
||||||
|
return b.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteToUDP writes data to a specific UDP address.
|
||||||
|
// This is thread-safe and can be used by hole punch senders.
|
||||||
|
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.WriteToUDP(data, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the WireGuard Bind interface.
|
||||||
|
// It decrements the reference count and closes the connection if no references remain.
|
||||||
|
func (b *SharedBind) Close() error {
|
||||||
|
return b.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open implements the WireGuard Bind interface.
|
||||||
|
// Since the connection is already open, this just sets up the receive functions.
|
||||||
|
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return nil, 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
if b.udpConn == nil {
|
||||||
|
return nil, 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up IPv4 and IPv6 packet connections for advanced features
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
|
||||||
|
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create receive functions
|
||||||
|
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
|
||||||
|
|
||||||
|
// Add IPv4 receive function
|
||||||
|
if b.ipv4PC != nil || runtime.GOOS != "linux" {
|
||||||
|
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IPv6 receive function if needed
|
||||||
|
// For now, we focus on IPv4 for hole punching use case
|
||||||
|
|
||||||
|
b.recvFuncs = recvFuncs
|
||||||
|
return recvFuncs, b.port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeReceiveIPv4 creates a receive function for IPv4 packets
|
||||||
|
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
pc := b.ipv4PC
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use batch reading on Linux for performance
|
||||||
|
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||||
|
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to simple read for other platforms
|
||||||
|
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveIPv4Batch uses batch reading for better performance on Linux
|
||||||
|
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||||
|
// Create messages for batch reading
|
||||||
|
msgs := make([]ipv4.Message, len(bufs))
|
||||||
|
for i := range bufs {
|
||||||
|
msgs[i].Buffers = [][]byte{bufs[i]}
|
||||||
|
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
|
||||||
|
}
|
||||||
|
|
||||||
|
numMsgs, err := pc.ReadBatch(msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
sizes[i] = msgs[i].N
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgs[i].Addr != nil {
|
||||||
|
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
|
||||||
|
addrPort := udpAddr.AddrPort()
|
||||||
|
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return numMsgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
|
||||||
|
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||||
|
n, addr, err := conn.ReadFromUDP(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sizes[0] = n
|
||||||
|
if addr != nil {
|
||||||
|
addrPort := addr.AddrPort()
|
||||||
|
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send implements the WireGuard Bind interface.
|
||||||
|
// It sends packets to the specified endpoint.
|
||||||
|
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the destination address from the endpoint
|
||||||
|
var destAddr *net.UDPAddr
|
||||||
|
|
||||||
|
// Try to cast to StdNetEndpoint first
|
||||||
|
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
|
||||||
|
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
|
||||||
|
} else {
|
||||||
|
// Fallback: construct from DstIP and DstToBytes
|
||||||
|
dstBytes := ep.DstToBytes()
|
||||||
|
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
|
||||||
|
var addr netip.Addr
|
||||||
|
var port uint16
|
||||||
|
|
||||||
|
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
|
||||||
|
addr, _ = netip.AddrFromSlice(dstBytes[:16])
|
||||||
|
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
|
||||||
|
} else { // IPv4
|
||||||
|
addr, _ = netip.AddrFromSlice(dstBytes[:4])
|
||||||
|
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.IsValid() {
|
||||||
|
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if destAddr == nil {
|
||||||
|
return fmt.Errorf("could not extract destination address from endpoint")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send all buffers to the destination
|
||||||
|
for _, buf := range bufs {
|
||||||
|
_, err := conn.WriteToUDP(buf, destAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMark implements the WireGuard Bind interface.
|
||||||
|
// It's a no-op for this implementation.
|
||||||
|
func (b *SharedBind) SetMark(mark uint32) error {
|
||||||
|
// Not implemented for this use case
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the preferred batch size for sending packets.
|
||||||
|
func (b *SharedBind) BatchSize() int {
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
return wgConn.IdealBatchSize
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseEndpoint creates a new endpoint from a string address.
|
||||||
|
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
|
||||||
|
addrPort, err := netip.ParseAddrPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
|
||||||
|
}
|
||||||
424
bind/shared_bind_test.go
Normal file
424
bind/shared_bind_test.go
Normal file
@@ -0,0 +1,424 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSharedBindCreation tests basic creation and initialization
|
||||||
|
func TestSharedBindCreation(t *testing.T) {
|
||||||
|
// Create a UDP connection
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer udpConn.Close()
|
||||||
|
|
||||||
|
// Create SharedBind
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bind == nil {
|
||||||
|
t.Fatal("SharedBind is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify initial reference count
|
||||||
|
if bind.refCount.Load() != 1 {
|
||||||
|
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
if err := bind.Close(); err != nil {
|
||||||
|
t.Errorf("Failed to close SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindReferenceCount tests reference counting
|
||||||
|
func TestSharedBindReferenceCount(t *testing.T) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add references
|
||||||
|
bind.AddRef()
|
||||||
|
if bind.refCount.Load() != 2 {
|
||||||
|
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.AddRef()
|
||||||
|
if bind.refCount.Load() != 3 {
|
||||||
|
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release references
|
||||||
|
bind.Release()
|
||||||
|
if bind.refCount.Load() != 2 {
|
||||||
|
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.Release()
|
||||||
|
bind.Release() // This should close the connection
|
||||||
|
|
||||||
|
if !bind.closed.Load() {
|
||||||
|
t.Error("Expected bind to be closed after all references released")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
|
||||||
|
func TestSharedBindWriteToUDP(t *testing.T) {
|
||||||
|
// Create sender
|
||||||
|
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderBind, err := New(senderConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer senderBind.Close()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
// Send data
|
||||||
|
testData := []byte("Hello, SharedBind!")
|
||||||
|
n, err := senderBind.WriteToUDP(testData, receiverAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteToUDP failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != len(testData) {
|
||||||
|
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive data
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, _, err = receiverConn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to receive data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(testData) {
|
||||||
|
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindConcurrentWrites tests thread-safety
|
||||||
|
func TestSharedBindConcurrentWrites(t *testing.T) {
|
||||||
|
// Create sender
|
||||||
|
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderBind, err := New(senderConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer senderBind.Close()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
// Launch concurrent writes
|
||||||
|
numGoroutines := 100
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
data := []byte{byte(id)}
|
||||||
|
_, err := senderBind.WriteToUDP(data, receiverAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
|
||||||
|
func TestSharedBindWireGuardInterface(t *testing.T) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer bind.Close()
|
||||||
|
|
||||||
|
// Test Open
|
||||||
|
recvFuncs, port, err := bind.Open(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(recvFuncs) == 0 {
|
||||||
|
t.Error("Expected at least one receive function")
|
||||||
|
}
|
||||||
|
|
||||||
|
if port == 0 {
|
||||||
|
t.Error("Expected non-zero port")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test SetMark (should be a no-op)
|
||||||
|
if err := bind.SetMark(0); err != nil {
|
||||||
|
t.Errorf("SetMark failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test BatchSize
|
||||||
|
batchSize := bind.BatchSize()
|
||||||
|
if batchSize <= 0 {
|
||||||
|
t.Error("Expected positive batch size")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindSend tests the Send method with WireGuard endpoints
|
||||||
|
func TestSharedBindSend(t *testing.T) {
|
||||||
|
// Create sender
|
||||||
|
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderBind, err := New(senderConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer senderBind.Close()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
// Create an endpoint
|
||||||
|
addrPort := receiverAddr.AddrPort()
|
||||||
|
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
|
||||||
|
// Send data
|
||||||
|
testData := []byte("WireGuard packet")
|
||||||
|
bufs := [][]byte{testData}
|
||||||
|
err = senderBind.Send(bufs, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Send failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive data
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, _, err := receiverConn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to receive data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(testData) {
|
||||||
|
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
|
||||||
|
func TestSharedBindMultipleUsers(t *testing.T) {
|
||||||
|
// Create shared bind
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedBind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add reference for hole punch sender
|
||||||
|
sharedBind.AddRef()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Simulate WireGuard using the bind
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
addrPort := receiverAddr.AddrPort()
|
||||||
|
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data := []byte("WireGuard packet")
|
||||||
|
bufs := [][]byte{data}
|
||||||
|
if err := sharedBind.Send(bufs, endpoint); err != nil {
|
||||||
|
t.Errorf("WireGuard Send failed: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Simulate hole punch sender using the bind
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data := []byte("Hole punch packet")
|
||||||
|
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
|
||||||
|
t.Errorf("Hole punch WriteToUDP failed: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Release the hole punch reference
|
||||||
|
sharedBind.Release()
|
||||||
|
|
||||||
|
// Close WireGuard's reference (should close the connection)
|
||||||
|
sharedBind.Close()
|
||||||
|
|
||||||
|
if !sharedBind.closed.Load() {
|
||||||
|
t.Error("Expected bind to be closed after all users released it")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEndpoint tests the Endpoint implementation
|
||||||
|
func TestEndpoint(t *testing.T) {
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
addrPort := netip.AddrPortFrom(addr, 51820)
|
||||||
|
|
||||||
|
ep := &Endpoint{AddrPort: addrPort}
|
||||||
|
|
||||||
|
// Test DstIP
|
||||||
|
if ep.DstIP() != addr {
|
||||||
|
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DstToString
|
||||||
|
expected := "192.168.1.1:51820"
|
||||||
|
if ep.DstToString() != expected {
|
||||||
|
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DstToBytes
|
||||||
|
bytes := ep.DstToBytes()
|
||||||
|
if len(bytes) == 0 {
|
||||||
|
t.Error("Expected DstToBytes to return non-empty slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test SrcIP (should be zero)
|
||||||
|
if ep.SrcIP().IsValid() {
|
||||||
|
t.Error("Expected SrcIP to be invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ClearSrc (should not panic)
|
||||||
|
ep.ClearSrc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseEndpoint tests the ParseEndpoint method
|
||||||
|
func TestParseEndpoint(t *testing.T) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer bind.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr bool
|
||||||
|
checkAddr func(*testing.T, wgConn.Endpoint)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid IPv4",
|
||||||
|
input: "192.168.1.1:51820",
|
||||||
|
wantErr: false,
|
||||||
|
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||||
|
if ep.DstToString() != "192.168.1.1:51820" {
|
||||||
|
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid IPv6",
|
||||||
|
input: "[::1]:51820",
|
||||||
|
wantErr: false,
|
||||||
|
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||||
|
if ep.DstToString() != "[::1]:51820" {
|
||||||
|
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid - missing port",
|
||||||
|
input: "192.168.1.1",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid - bad format",
|
||||||
|
input: "not-an-address",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ep, err := bind.ParseEndpoint(tt.input)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.wantErr && tt.checkAddr != nil {
|
||||||
|
tt.checkAddr(t, ep)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
562
config.go
Normal file
562
config.go
Normal file
@@ -0,0 +1,562 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OlmConfig holds all configuration options for the Olm client
|
||||||
|
type OlmConfig struct {
|
||||||
|
// Connection settings
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
OrgID string `json:"org"`
|
||||||
|
UserToken string `json:"userToken"`
|
||||||
|
|
||||||
|
// Network settings
|
||||||
|
MTU int `json:"mtu"`
|
||||||
|
DNS string `json:"dns"`
|
||||||
|
InterfaceName string `json:"interface"`
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
LogLevel string `json:"logLevel"`
|
||||||
|
|
||||||
|
// HTTP server
|
||||||
|
EnableAPI bool `json:"enableApi"`
|
||||||
|
HTTPAddr string `json:"httpAddr"`
|
||||||
|
SocketPath string `json:"socketPath"`
|
||||||
|
|
||||||
|
// Ping settings
|
||||||
|
PingInterval string `json:"pingInterval"`
|
||||||
|
PingTimeout string `json:"pingTimeout"`
|
||||||
|
|
||||||
|
// Advanced
|
||||||
|
Holepunch bool `json:"holepunch"`
|
||||||
|
TlsClientCert string `json:"tlsClientCert"`
|
||||||
|
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
|
||||||
|
|
||||||
|
// Parsed values (not in JSON)
|
||||||
|
PingIntervalDuration time.Duration `json:"-"`
|
||||||
|
PingTimeoutDuration time.Duration `json:"-"`
|
||||||
|
|
||||||
|
// Source tracking (not in JSON)
|
||||||
|
sources map[string]string `json:"-"`
|
||||||
|
|
||||||
|
Version string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigSource tracks where each config value came from
|
||||||
|
type ConfigSource string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SourceDefault ConfigSource = "default"
|
||||||
|
SourceFile ConfigSource = "file"
|
||||||
|
SourceEnv ConfigSource = "environment"
|
||||||
|
SourceCLI ConfigSource = "cli"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultConfig returns a config with default values
|
||||||
|
func DefaultConfig() *OlmConfig {
|
||||||
|
// Set OS-specific socket path
|
||||||
|
var socketPath string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
socketPath = "olm"
|
||||||
|
default: // darwin, linux, and others
|
||||||
|
socketPath = "/var/run/olm.sock"
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &OlmConfig{
|
||||||
|
MTU: 1280,
|
||||||
|
DNS: "8.8.8.8",
|
||||||
|
LogLevel: "INFO",
|
||||||
|
InterfaceName: "olm",
|
||||||
|
EnableAPI: false,
|
||||||
|
SocketPath: socketPath,
|
||||||
|
PingInterval: "3s",
|
||||||
|
PingTimeout: "5s",
|
||||||
|
Holepunch: false,
|
||||||
|
// DoNotCreateNewClient: false,
|
||||||
|
sources: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track default sources
|
||||||
|
config.sources["mtu"] = string(SourceDefault)
|
||||||
|
config.sources["dns"] = string(SourceDefault)
|
||||||
|
config.sources["logLevel"] = string(SourceDefault)
|
||||||
|
config.sources["interface"] = string(SourceDefault)
|
||||||
|
config.sources["enableApi"] = string(SourceDefault)
|
||||||
|
config.sources["httpAddr"] = string(SourceDefault)
|
||||||
|
config.sources["socketPath"] = string(SourceDefault)
|
||||||
|
config.sources["pingInterval"] = string(SourceDefault)
|
||||||
|
config.sources["pingTimeout"] = string(SourceDefault)
|
||||||
|
config.sources["holepunch"] = string(SourceDefault)
|
||||||
|
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOlmConfigPath returns the path to the olm config file
|
||||||
|
func getOlmConfigPath() string {
|
||||||
|
configFile := os.Getenv("CONFIG_FILE")
|
||||||
|
if configFile != "" {
|
||||||
|
return configFile
|
||||||
|
}
|
||||||
|
|
||||||
|
var configDir string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
|
||||||
|
case "windows":
|
||||||
|
configDir = filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "olm-client")
|
||||||
|
default: // linux and others
|
||||||
|
configDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||||
|
fmt.Printf("Warning: Failed to create config directory: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(configDir, "config.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig loads configuration from file, env vars, and CLI args
|
||||||
|
// Priority: CLI args > Env vars > Config file > Defaults
|
||||||
|
// Returns: (config, showVersion, showConfig, error)
|
||||||
|
func LoadConfig(args []string) (*OlmConfig, bool, bool, error) {
|
||||||
|
// Start with defaults
|
||||||
|
config := DefaultConfig()
|
||||||
|
|
||||||
|
// Load from config file (if exists)
|
||||||
|
fileConfig, err := loadConfigFromFile()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, false, fmt.Errorf("failed to load config file: %w", err)
|
||||||
|
}
|
||||||
|
if fileConfig != nil {
|
||||||
|
mergeConfigs(config, fileConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override with environment variables
|
||||||
|
loadConfigFromEnv(config)
|
||||||
|
|
||||||
|
// Override with CLI arguments
|
||||||
|
showVersion, showConfig, err := loadConfigFromCLI(config, args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse duration strings
|
||||||
|
if err := config.parseDurations(); err != nil {
|
||||||
|
return nil, false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, showVersion, showConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadConfigFromFile loads configuration from the JSON config file
|
||||||
|
func loadConfigFromFile() (*OlmConfig, error) {
|
||||||
|
configPath := getOlmConfigPath()
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil // File doesn't exist, not an error
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var config OlmConfig
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadConfigFromEnv loads configuration from environment variables
|
||||||
|
func loadConfigFromEnv(config *OlmConfig) {
|
||||||
|
if val := os.Getenv("PANGOLIN_ENDPOINT"); val != "" {
|
||||||
|
config.Endpoint = val
|
||||||
|
config.sources["endpoint"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("OLM_ID"); val != "" {
|
||||||
|
config.ID = val
|
||||||
|
config.sources["id"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("OLM_SECRET"); val != "" {
|
||||||
|
config.Secret = val
|
||||||
|
config.sources["secret"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("ORG"); val != "" {
|
||||||
|
config.OrgID = val
|
||||||
|
config.sources["org"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("USER_TOKEN"); val != "" {
|
||||||
|
config.UserToken = val
|
||||||
|
config.sources["userToken"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("MTU"); val != "" {
|
||||||
|
if mtu, err := strconv.Atoi(val); err == nil {
|
||||||
|
config.MTU = mtu
|
||||||
|
config.sources["mtu"] = string(SourceEnv)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Invalid MTU value: %s, keeping current value\n", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv("DNS"); val != "" {
|
||||||
|
config.DNS = val
|
||||||
|
config.sources["dns"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("LOG_LEVEL"); val != "" {
|
||||||
|
config.LogLevel = val
|
||||||
|
config.sources["logLevel"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("INTERFACE"); val != "" {
|
||||||
|
config.InterfaceName = val
|
||||||
|
config.sources["interface"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("HTTP_ADDR"); val != "" {
|
||||||
|
config.HTTPAddr = val
|
||||||
|
config.sources["httpAddr"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("PING_INTERVAL"); val != "" {
|
||||||
|
config.PingInterval = val
|
||||||
|
config.sources["pingInterval"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("PING_TIMEOUT"); val != "" {
|
||||||
|
config.PingTimeout = val
|
||||||
|
config.sources["pingTimeout"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("ENABLE_API"); val == "true" {
|
||||||
|
config.EnableAPI = true
|
||||||
|
config.sources["enableApi"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("SOCKET_PATH"); val != "" {
|
||||||
|
config.SocketPath = val
|
||||||
|
config.sources["socketPath"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("HOLEPUNCH"); val == "true" {
|
||||||
|
config.Holepunch = true
|
||||||
|
config.sources["holepunch"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
|
||||||
|
// config.DoNotCreateNewClient = true
|
||||||
|
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadConfigFromCLI loads configuration from command-line arguments
|
||||||
|
func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||||
|
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
||||||
|
|
||||||
|
// Store original values to detect changes
|
||||||
|
origValues := map[string]interface{}{
|
||||||
|
"endpoint": config.Endpoint,
|
||||||
|
"id": config.ID,
|
||||||
|
"secret": config.Secret,
|
||||||
|
"org": config.OrgID,
|
||||||
|
"userToken": config.UserToken,
|
||||||
|
"mtu": config.MTU,
|
||||||
|
"dns": config.DNS,
|
||||||
|
"logLevel": config.LogLevel,
|
||||||
|
"interface": config.InterfaceName,
|
||||||
|
"httpAddr": config.HTTPAddr,
|
||||||
|
"socketPath": config.SocketPath,
|
||||||
|
"pingInterval": config.PingInterval,
|
||||||
|
"pingTimeout": config.PingTimeout,
|
||||||
|
"enableApi": config.EnableAPI,
|
||||||
|
"holepunch": config.Holepunch,
|
||||||
|
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define flags
|
||||||
|
serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server")
|
||||||
|
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
|
||||||
|
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
|
||||||
|
serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID")
|
||||||
|
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
|
||||||
|
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
|
||||||
|
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
|
||||||
|
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||||
|
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
|
||||||
|
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
|
||||||
|
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
|
||||||
|
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
|
||||||
|
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
||||||
|
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
|
||||||
|
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
|
||||||
|
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
|
||||||
|
|
||||||
|
version := serviceFlags.Bool("version", false, "Print the version")
|
||||||
|
showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit")
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
if err := serviceFlags.Parse(args); err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track which values were changed by CLI args
|
||||||
|
if config.Endpoint != origValues["endpoint"].(string) {
|
||||||
|
config.sources["endpoint"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.ID != origValues["id"].(string) {
|
||||||
|
config.sources["id"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.Secret != origValues["secret"].(string) {
|
||||||
|
config.sources["secret"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.OrgID != origValues["org"].(string) {
|
||||||
|
config.sources["org"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.UserToken != origValues["userToken"].(string) {
|
||||||
|
config.sources["userToken"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.MTU != origValues["mtu"].(int) {
|
||||||
|
config.sources["mtu"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.DNS != origValues["dns"].(string) {
|
||||||
|
config.sources["dns"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.LogLevel != origValues["logLevel"].(string) {
|
||||||
|
config.sources["logLevel"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.InterfaceName != origValues["interface"].(string) {
|
||||||
|
config.sources["interface"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.HTTPAddr != origValues["httpAddr"].(string) {
|
||||||
|
config.sources["httpAddr"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.SocketPath != origValues["socketPath"].(string) {
|
||||||
|
config.sources["socketPath"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.PingInterval != origValues["pingInterval"].(string) {
|
||||||
|
config.sources["pingInterval"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.PingTimeout != origValues["pingTimeout"].(string) {
|
||||||
|
config.sources["pingTimeout"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.EnableAPI != origValues["enableApi"].(bool) {
|
||||||
|
config.sources["enableApi"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.Holepunch != origValues["holepunch"].(bool) {
|
||||||
|
config.sources["holepunch"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
|
||||||
|
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
|
||||||
|
// }
|
||||||
|
|
||||||
|
return *version, *showConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDurations parses the duration strings into time.Duration
|
||||||
|
func (c *OlmConfig) parseDurations() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Parse ping interval
|
||||||
|
if c.PingInterval != "" {
|
||||||
|
c.PingIntervalDuration, err = time.ParseDuration(c.PingInterval)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", c.PingInterval)
|
||||||
|
c.PingIntervalDuration = 3 * time.Second
|
||||||
|
c.PingInterval = "3s"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.PingIntervalDuration = 3 * time.Second
|
||||||
|
c.PingInterval = "3s"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse ping timeout
|
||||||
|
if c.PingTimeout != "" {
|
||||||
|
c.PingTimeoutDuration, err = time.ParseDuration(c.PingTimeout)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", c.PingTimeout)
|
||||||
|
c.PingTimeoutDuration = 5 * time.Second
|
||||||
|
c.PingTimeout = "5s"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.PingTimeoutDuration = 5 * time.Second
|
||||||
|
c.PingTimeout = "5s"
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeConfigs merges source config into destination (only non-empty values)
|
||||||
|
// Also tracks that these values came from a file
|
||||||
|
func mergeConfigs(dest, src *OlmConfig) {
|
||||||
|
if src.Endpoint != "" {
|
||||||
|
dest.Endpoint = src.Endpoint
|
||||||
|
dest.sources["endpoint"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.ID != "" {
|
||||||
|
dest.ID = src.ID
|
||||||
|
dest.sources["id"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.Secret != "" {
|
||||||
|
dest.Secret = src.Secret
|
||||||
|
dest.sources["secret"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.OrgID != "" {
|
||||||
|
dest.OrgID = src.OrgID
|
||||||
|
dest.sources["org"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.UserToken != "" {
|
||||||
|
dest.UserToken = src.UserToken
|
||||||
|
dest.sources["userToken"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.MTU != 0 && src.MTU != 1280 {
|
||||||
|
dest.MTU = src.MTU
|
||||||
|
dest.sources["mtu"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.DNS != "" && src.DNS != "8.8.8.8" {
|
||||||
|
dest.DNS = src.DNS
|
||||||
|
dest.sources["dns"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.LogLevel != "" && src.LogLevel != "INFO" {
|
||||||
|
dest.LogLevel = src.LogLevel
|
||||||
|
dest.sources["logLevel"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.InterfaceName != "" && src.InterfaceName != "olm" {
|
||||||
|
dest.InterfaceName = src.InterfaceName
|
||||||
|
dest.sources["interface"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.HTTPAddr != "" && src.HTTPAddr != ":9452" {
|
||||||
|
dest.HTTPAddr = src.HTTPAddr
|
||||||
|
dest.sources["httpAddr"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.SocketPath != "" {
|
||||||
|
// Check if it's not the default for any OS
|
||||||
|
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
|
||||||
|
if !isDefault {
|
||||||
|
dest.SocketPath = src.SocketPath
|
||||||
|
dest.sources["socketPath"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if src.PingInterval != "" && src.PingInterval != "3s" {
|
||||||
|
dest.PingInterval = src.PingInterval
|
||||||
|
dest.sources["pingInterval"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.PingTimeout != "" && src.PingTimeout != "5s" {
|
||||||
|
dest.PingTimeout = src.PingTimeout
|
||||||
|
dest.sources["pingTimeout"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.TlsClientCert != "" {
|
||||||
|
dest.TlsClientCert = src.TlsClientCert
|
||||||
|
dest.sources["tlsClientCert"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
// For booleans, we always take the source value if explicitly set
|
||||||
|
if src.EnableAPI {
|
||||||
|
dest.EnableAPI = src.EnableAPI
|
||||||
|
dest.sources["enableApi"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.Holepunch {
|
||||||
|
dest.Holepunch = src.Holepunch
|
||||||
|
dest.sources["holepunch"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
// if src.DoNotCreateNewClient {
|
||||||
|
// dest.DoNotCreateNewClient = src.DoNotCreateNewClient
|
||||||
|
// dest.sources["doNotCreateNewClient"] = string(SourceFile)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveConfig saves the current configuration to the config file
|
||||||
|
func SaveConfig(config *OlmConfig) error {
|
||||||
|
configPath := getOlmConfigPath()
|
||||||
|
data, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal config: %w", err)
|
||||||
|
}
|
||||||
|
return os.WriteFile(configPath, data, 0644)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShowConfig prints the configuration and the source of each value
|
||||||
|
func (c *OlmConfig) ShowConfig() {
|
||||||
|
configPath := getOlmConfigPath()
|
||||||
|
|
||||||
|
fmt.Println("\n=== Olm Configuration ===\n")
|
||||||
|
fmt.Printf("Config File: %s\n", configPath)
|
||||||
|
|
||||||
|
// Check if config file exists
|
||||||
|
if _, err := os.Stat(configPath); err == nil {
|
||||||
|
fmt.Printf("Config File Status: ✓ exists\n")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Config File Status: ✗ not found\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n--- Configuration Values ---")
|
||||||
|
fmt.Println("(Format: Setting = Value [source])\n")
|
||||||
|
|
||||||
|
// Helper to get source or default
|
||||||
|
getSource := func(key string) string {
|
||||||
|
if source, ok := c.sources[key]; ok {
|
||||||
|
return source
|
||||||
|
}
|
||||||
|
return string(SourceDefault)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to format value (mask secrets)
|
||||||
|
formatValue := func(key, value string) string {
|
||||||
|
if key == "secret" && value != "" {
|
||||||
|
if len(value) > 8 {
|
||||||
|
return value[:4] + "****" + value[len(value)-4:]
|
||||||
|
}
|
||||||
|
return "****"
|
||||||
|
}
|
||||||
|
if value == "" {
|
||||||
|
return "(not set)"
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection settings
|
||||||
|
fmt.Println("Connection:")
|
||||||
|
fmt.Printf(" endpoint = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint"))
|
||||||
|
fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
|
||||||
|
fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
|
||||||
|
fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org"))
|
||||||
|
fmt.Printf(" user-token = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken"))
|
||||||
|
|
||||||
|
// Network settings
|
||||||
|
fmt.Println("\nNetwork:")
|
||||||
|
fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu"))
|
||||||
|
fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns"))
|
||||||
|
fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface"))
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
fmt.Println("\nLogging:")
|
||||||
|
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
|
||||||
|
|
||||||
|
// API server
|
||||||
|
fmt.Println("\nAPI Server:")
|
||||||
|
fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
|
||||||
|
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
|
||||||
|
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
|
||||||
|
|
||||||
|
// Timing
|
||||||
|
fmt.Println("\nTiming:")
|
||||||
|
fmt.Printf(" ping-interval = %s [%s]\n", c.PingInterval, getSource("pingInterval"))
|
||||||
|
fmt.Printf(" ping-timeout = %s [%s]\n", c.PingTimeout, getSource("pingTimeout"))
|
||||||
|
|
||||||
|
// Advanced
|
||||||
|
fmt.Println("\nAdvanced:")
|
||||||
|
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
|
||||||
|
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
|
||||||
|
if c.TlsClientCert != "" {
|
||||||
|
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Source legend
|
||||||
|
fmt.Println("\n--- Source Legend ---")
|
||||||
|
fmt.Println(" default = Built-in default value")
|
||||||
|
fmt.Println(" file = Loaded from config file")
|
||||||
|
fmt.Println(" environment = Set via environment variable")
|
||||||
|
fmt.Println(" cli = Provided as command-line argument")
|
||||||
|
fmt.Println("\nPriority: cli > environment > file > default")
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
523
diff
Normal file
523
diff
Normal file
@@ -0,0 +1,523 @@
|
|||||||
|
diff --git a/api/api.go b/api/api.go
|
||||||
|
index dd07751..0d2e4ef 100644
|
||||||
|
--- a/api/api.go
|
||||||
|
+++ b/api/api.go
|
||||||
|
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
+// SwitchOrgRequest defines the structure for switching organizations
|
||||||
|
+type SwitchOrgRequest struct {
|
||||||
|
+ OrgID string `json:"orgId"`
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
// PeerStatus represents the status of a peer connection
|
||||||
|
type PeerStatus struct {
|
||||||
|
SiteID int `json:"siteId"`
|
||||||
|
@@ -35,6 +40,7 @@ type StatusResponse struct {
|
||||||
|
Registered bool `json:"registered"`
|
||||||
|
TunnelIP string `json:"tunnelIP,omitempty"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
+ OrgID string `json:"orgId,omitempty"`
|
||||||
|
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -46,6 +52,7 @@ type API struct {
|
||||||
|
server *http.Server
|
||||||
|
connectionChan chan ConnectionRequest
|
||||||
|
shutdownChan chan struct{}
|
||||||
|
+ switchOrgChan chan SwitchOrgRequest
|
||||||
|
statusMu sync.RWMutex
|
||||||
|
peerStatuses map[int]*PeerStatus
|
||||||
|
connectedAt time.Time
|
||||||
|
@@ -53,6 +60,7 @@ type API struct {
|
||||||
|
isRegistered bool
|
||||||
|
tunnelIP string
|
||||||
|
version string
|
||||||
|
+ orgID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||||
|
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
|
||||||
|
addr: addr,
|
||||||
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
|
||||||
|
socketPath: socketPath,
|
||||||
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -85,6 +95,7 @@ func (s *API) Start() error {
|
||||||
|
mux.HandleFunc("/connect", s.handleConnect)
|
||||||
|
mux.HandleFunc("/status", s.handleStatus)
|
||||||
|
mux.HandleFunc("/exit", s.handleExit)
|
||||||
|
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||||
|
|
||||||
|
s.server = &http.Server{
|
||||||
|
Handler: mux,
|
||||||
|
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
|
||||||
|
return s.shutdownChan
|
||||||
|
}
|
||||||
|
|
||||||
|
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
||||||
|
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
||||||
|
+ return s.switchOrgChan
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||||
|
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
|
||||||
|
s.version = version
|
||||||
|
}
|
||||||
|
|
||||||
|
+// SetOrgID sets the org ID
|
||||||
|
+func (s *API) SetOrgID(orgID string) {
|
||||||
|
+ s.statusMu.Lock()
|
||||||
|
+ defer s.statusMu.Unlock()
|
||||||
|
+ s.orgID = orgID
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||||
|
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
Registered: s.isRegistered,
|
||||||
|
TunnelIP: s.tunnelIP,
|
||||||
|
Version: s.version,
|
||||||
|
+ OrgID: s.orgID,
|
||||||
|
PeerStatuses: s.peerStatuses,
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||||
|
"status": "shutdown initiated",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
+
|
||||||
|
+// handleSwitchOrg handles the /switch-org endpoint
|
||||||
|
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||||
|
+ if r.Method != http.MethodPost {
|
||||||
|
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ var req SwitchOrgRequest
|
||||||
|
+ decoder := json.NewDecoder(r.Body)
|
||||||
|
+ if err := decoder.Decode(&req); err != nil {
|
||||||
|
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Validate required fields
|
||||||
|
+ if req.OrgID == "" {
|
||||||
|
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||||
|
+
|
||||||
|
+ // Send the request to the main goroutine
|
||||||
|
+ select {
|
||||||
|
+ case s.switchOrgChan <- req:
|
||||||
|
+ // Signal sent successfully
|
||||||
|
+ default:
|
||||||
|
+ // Channel already has a signal, don't block
|
||||||
|
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Return a success response
|
||||||
|
+ w.Header().Set("Content-Type", "application/json")
|
||||||
|
+ w.WriteHeader(http.StatusAccepted)
|
||||||
|
+ json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
+ "status": "org switch initiated",
|
||||||
|
+ "orgId": req.OrgID,
|
||||||
|
+ })
|
||||||
|
+}
|
||||||
|
diff --git a/olm/olm.go b/olm/olm.go
|
||||||
|
index 78080c4..5e292d6 100644
|
||||||
|
--- a/olm/olm.go
|
||||||
|
+++ b/olm/olm.go
|
||||||
|
@@ -58,6 +58,58 @@ type Config struct {
|
||||||
|
OrgID string
|
||||||
|
}
|
||||||
|
|
||||||
|
+// tunnelState holds all the active tunnel resources that need cleanup
|
||||||
|
+type tunnelState struct {
|
||||||
|
+ dev *device.Device
|
||||||
|
+ tdev tun.Device
|
||||||
|
+ uapiListener net.Listener
|
||||||
|
+ peerMonitor *peermonitor.PeerMonitor
|
||||||
|
+ stopRegister func()
|
||||||
|
+ connected bool
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
+// teardownTunnel cleans up all tunnel resources
|
||||||
|
+func teardownTunnel(state *tunnelState) {
|
||||||
|
+ if state == nil {
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ logger.Info("Tearing down tunnel...")
|
||||||
|
+
|
||||||
|
+ // Stop registration messages
|
||||||
|
+ if state.stopRegister != nil {
|
||||||
|
+ state.stopRegister()
|
||||||
|
+ state.stopRegister = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Stop peer monitor
|
||||||
|
+ if state.peerMonitor != nil {
|
||||||
|
+ state.peerMonitor.Stop()
|
||||||
|
+ state.peerMonitor = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Close UAPI listener
|
||||||
|
+ if state.uapiListener != nil {
|
||||||
|
+ state.uapiListener.Close()
|
||||||
|
+ state.uapiListener = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Close WireGuard device
|
||||||
|
+ if state.dev != nil {
|
||||||
|
+ state.dev.Close()
|
||||||
|
+ state.dev = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Close TUN device
|
||||||
|
+ if state.tdev != nil {
|
||||||
|
+ state.tdev.Close()
|
||||||
|
+ state.tdev = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ state.connected = false
|
||||||
|
+ logger.Info("Tunnel teardown complete")
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
func Run(ctx context.Context, config Config) {
|
||||||
|
// Create a cancellable context for internal shutdown control
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
pingTimeout = config.PingTimeoutDuration
|
||||||
|
doHolepunch = config.Holepunch
|
||||||
|
privateKey wgtypes.Key
|
||||||
|
- connected bool
|
||||||
|
- dev *device.Device
|
||||||
|
wgData WgData
|
||||||
|
holePunchData HolePunchData
|
||||||
|
- uapiListener net.Listener
|
||||||
|
- tdev tun.Device
|
||||||
|
+ orgID = config.OrgID
|
||||||
|
)
|
||||||
|
|
||||||
|
+ // Tunnel state that can be torn down and recreated
|
||||||
|
+ tunnel := &tunnelState{}
|
||||||
|
+
|
||||||
|
stopHolepunch = make(chan struct{})
|
||||||
|
stopPing = make(chan struct{})
|
||||||
|
|
||||||
|
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
|
||||||
|
apiServer.SetVersion(config.Version)
|
||||||
|
+ apiServer.SetOrgID(orgID)
|
||||||
|
if err := apiServer.Start(); err != nil {
|
||||||
|
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||||
|
}
|
||||||
|
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
- if connected {
|
||||||
|
+ if tunnel.connected {
|
||||||
|
logger.Info("Already connected. Ignoring new connection request.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
- if stopRegister != nil {
|
||||||
|
- stopRegister()
|
||||||
|
- stopRegister = nil
|
||||||
|
+ if tunnel.stopRegister != nil {
|
||||||
|
+ tunnel.stopRegister()
|
||||||
|
+ tunnel.stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
close(stopHolepunch)
|
||||||
|
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
// if there is an existing tunnel then close it
|
||||||
|
- if dev != nil {
|
||||||
|
+ if tunnel.dev != nil {
|
||||||
|
logger.Info("Got new message. Closing existing tunnel!")
|
||||||
|
- dev.Close()
|
||||||
|
+ tunnel.dev.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
- tdev, err = func() (tun.Device, error) {
|
||||||
|
+ tunnel.tdev, err = func() (tun.Device, error) {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
interfaceName, err := findUnusedUTUN()
|
||||||
|
if err != nil {
|
||||||
|
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||||
|
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
|
||||||
|
interfaceName = realInterfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
|
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
|
|
||||||
|
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
|
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
- conn, err := uapiListener.Accept()
|
||||||
|
+ conn, err := tunnel.uapiListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
- go dev.IpcHandle(conn)
|
||||||
|
+ go tunnel.dev.IpcHandle(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
logger.Info("UAPI listener started")
|
||||||
|
|
||||||
|
- if err = dev.Up(); err != nil {
|
||||||
|
+ if err = tunnel.dev.Up(); err != nil {
|
||||||
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||||
|
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
apiServer.SetTunnelIP(wgData.TunnelIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
- peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
|
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
|
if apiServer != nil {
|
||||||
|
// Find the site config to get endpoint information
|
||||||
|
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
},
|
||||||
|
fixKey(privateKey.String()),
|
||||||
|
olm,
|
||||||
|
- dev,
|
||||||
|
+ tunnel.dev,
|
||||||
|
doHolepunch,
|
||||||
|
)
|
||||||
|
|
||||||
|
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
// Format the endpoint before configuring the peer.
|
||||||
|
site.Endpoint = formatEndpoint(site.Endpoint)
|
||||||
|
|
||||||
|
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
||||||
|
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to configure peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
- peerMonitor.Start()
|
||||||
|
+ tunnel.peerMonitor.Start()
|
||||||
|
|
||||||
|
if apiServer != nil {
|
||||||
|
apiServer.SetRegistered(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
- connected = true
|
||||||
|
+ tunnel.connected = true
|
||||||
|
|
||||||
|
logger.Info("WireGuard device created.")
|
||||||
|
})
|
||||||
|
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the peer in WireGuard
|
||||||
|
- if dev != nil {
|
||||||
|
+ if tunnel.dev != nil {
|
||||||
|
// Find the existing peer to get old data
|
||||||
|
var oldRemoteSubnets string
|
||||||
|
var oldPublicKey string
|
||||||
|
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
// If the public key has changed, remove the old peer first
|
||||||
|
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
||||||
|
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
||||||
|
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||||
|
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||||
|
logger.Error("Failed to remove old peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
// Format the endpoint before updating the peer.
|
||||||
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
|
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to update peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the peer to WireGuard
|
||||||
|
- if dev != nil {
|
||||||
|
+ if tunnel.dev != nil {
|
||||||
|
// Format the endpoint before adding the new peer.
|
||||||
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
|
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to add peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the peer from WireGuard
|
||||||
|
- if dev != nil {
|
||||||
|
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||||
|
+ if tunnel.dev != nil {
|
||||||
|
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||||
|
logger.Error("Failed to remove peer: %v", err)
|
||||||
|
// Send error response if needed
|
||||||
|
return
|
||||||
|
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||||
|
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
||||||
|
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
apiServer.SetConnectionStatus(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
- if connected {
|
||||||
|
+ if tunnel.connected {
|
||||||
|
logger.Debug("Already connected, skipping registration")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
|
||||||
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
||||||
|
|
||||||
|
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": publicKey.String(),
|
||||||
|
"relay": !doHolepunch,
|
||||||
|
"olmVersion": config.Version,
|
||||||
|
- "orgId": config.OrgID,
|
||||||
|
+ "orgId": orgID,
|
||||||
|
}, 1*time.Second)
|
||||||
|
|
||||||
|
go keepSendingPing(olm)
|
||||||
|
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
defer olm.Close()
|
||||||
|
|
||||||
|
+ // Listen for org switch requests from the API (after olm is created)
|
||||||
|
+ if apiServer != nil {
|
||||||
|
+ go func() {
|
||||||
|
+ for req := range apiServer.GetSwitchOrgChannel() {
|
||||||
|
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
|
||||||
|
+
|
||||||
|
+ // Update the orgId
|
||||||
|
+ orgID = req.OrgID
|
||||||
|
+
|
||||||
|
+ // Teardown existing tunnel
|
||||||
|
+ teardownTunnel(tunnel)
|
||||||
|
+
|
||||||
|
+ // Reset tunnel state
|
||||||
|
+ tunnel = &tunnelState{}
|
||||||
|
+
|
||||||
|
+ // Stop holepunch
|
||||||
|
+ select {
|
||||||
|
+ case <-stopHolepunch:
|
||||||
|
+ // Channel already closed
|
||||||
|
+ default:
|
||||||
|
+ close(stopHolepunch)
|
||||||
|
+ }
|
||||||
|
+ stopHolepunch = make(chan struct{})
|
||||||
|
+
|
||||||
|
+ // Clear API server state
|
||||||
|
+ apiServer.SetRegistered(false)
|
||||||
|
+ apiServer.SetTunnelIP("")
|
||||||
|
+ apiServer.SetOrgID(orgID)
|
||||||
|
+
|
||||||
|
+ // Send new registration message with updated orgId
|
||||||
|
+ publicKey := privateKey.PublicKey()
|
||||||
|
+ logger.Info("Sending registration message with new orgId: %s", orgID)
|
||||||
|
+
|
||||||
|
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
+ "publicKey": publicKey.String(),
|
||||||
|
+ "relay": !doHolepunch,
|
||||||
|
+ "olmVersion": config.Version,
|
||||||
|
+ "orgId": orgID,
|
||||||
|
+ }, 1*time.Second)
|
||||||
|
+ }
|
||||||
|
+ }()
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Info("Context cancelled")
|
||||||
|
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
close(stopHolepunch)
|
||||||
|
}
|
||||||
|
|
||||||
|
- if stopRegister != nil {
|
||||||
|
- stopRegister()
|
||||||
|
- stopRegister = nil
|
||||||
|
+ if tunnel.stopRegister != nil {
|
||||||
|
+ tunnel.stopRegister()
|
||||||
|
+ tunnel.stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
close(stopPing)
|
||||||
|
}
|
||||||
|
|
||||||
|
- if peerMonitor != nil {
|
||||||
|
- peerMonitor.Stop()
|
||||||
|
- }
|
||||||
|
-
|
||||||
|
- if uapiListener != nil {
|
||||||
|
- uapiListener.Close()
|
||||||
|
- }
|
||||||
|
- if dev != nil {
|
||||||
|
- dev.Close()
|
||||||
|
- }
|
||||||
|
+ // Use teardownTunnel to clean up all tunnel resources
|
||||||
|
+ teardownTunnel(tunnel)
|
||||||
|
|
||||||
|
if apiServer != nil {
|
||||||
|
apiServer.Stop()
|
||||||
@@ -1,10 +1,15 @@
|
|||||||
services:
|
services:
|
||||||
newt:
|
olm:
|
||||||
image: fosrl/newt:latest
|
image: fosrl/olm:latest
|
||||||
container_name: newt
|
container_name: olm
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
environment:
|
environment:
|
||||||
- PANGOLIN_ENDPOINT=https://example.com
|
- PANGOLIN_ENDPOINT=https://example.com
|
||||||
- NEWT_ID=2ix2t8xk22ubpfy
|
- OLM_ID=vdqnz8rwgb95cnp
|
||||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
- OLM_SECRET=1sw05qv1tkfdb1k81zpw05nahnnjvmhxjvf746umwagddmdg
|
||||||
- LOG_LEVEL=DEBUG
|
cap_add:
|
||||||
|
- NET_ADMIN
|
||||||
|
- SYS_MODULE
|
||||||
|
devices:
|
||||||
|
- /dev/net/tun:/dev/net/tun
|
||||||
|
network_mode: host
|
||||||
@@ -4,7 +4,7 @@ set -e
|
|||||||
|
|
||||||
# first arg is `-f` or `--some-option`
|
# first arg is `-f` or `--some-option`
|
||||||
if [ "${1#-}" != "$1" ]; then
|
if [ "${1#-}" != "$1" ]; then
|
||||||
set -- newt "$@"
|
set -- olm "$@"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
exec "$@"
|
exec "$@"
|
||||||
279
get-olm.sh
Normal file
279
get-olm.sh
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Get Olm - Cross-platform installation script
|
||||||
|
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/olm/refs/heads/main/get-olm.sh | bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# GitHub repository info
|
||||||
|
REPO="fosrl/olm"
|
||||||
|
GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
|
||||||
|
|
||||||
|
# Function to print colored output
|
||||||
|
print_status() {
|
||||||
|
echo -e "${GREEN}[INFO]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
print_warning() {
|
||||||
|
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
print_error() {
|
||||||
|
echo -e "${RED}[ERROR]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Function to get latest version from GitHub API
|
||||||
|
get_latest_version() {
|
||||||
|
local latest_info
|
||||||
|
|
||||||
|
if command -v curl >/dev/null 2>&1; then
|
||||||
|
latest_info=$(curl -fsSL "$GITHUB_API_URL" 2>/dev/null)
|
||||||
|
elif command -v wget >/dev/null 2>&1; then
|
||||||
|
latest_info=$(wget -qO- "$GITHUB_API_URL" 2>/dev/null)
|
||||||
|
else
|
||||||
|
print_error "Neither curl nor wget is available. Please install one of them." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "$latest_info" ]; then
|
||||||
|
print_error "Failed to fetch latest version information" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract version from JSON response (works without jq)
|
||||||
|
local version=$(echo "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/')
|
||||||
|
|
||||||
|
if [ -z "$version" ]; then
|
||||||
|
print_error "Could not parse version from GitHub API response" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Remove 'v' prefix if present
|
||||||
|
version=$(echo "$version" | sed 's/^v//')
|
||||||
|
|
||||||
|
echo "$version"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Detect OS and architecture
|
||||||
|
detect_platform() {
|
||||||
|
local os arch
|
||||||
|
|
||||||
|
# Detect OS
|
||||||
|
case "$(uname -s)" in
|
||||||
|
Linux*) os="linux" ;;
|
||||||
|
Darwin*) os="darwin" ;;
|
||||||
|
MINGW*|MSYS*|CYGWIN*) os="windows" ;;
|
||||||
|
FreeBSD*) os="freebsd" ;;
|
||||||
|
*)
|
||||||
|
print_error "Unsupported operating system: $(uname -s)"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
# Detect architecture
|
||||||
|
case "$(uname -m)" in
|
||||||
|
x86_64|amd64) arch="amd64" ;;
|
||||||
|
arm64|aarch64) arch="arm64" ;;
|
||||||
|
armv7l|armv6l)
|
||||||
|
if [ "$os" = "linux" ]; then
|
||||||
|
if [ "$(uname -m)" = "armv6l" ]; then
|
||||||
|
arch="arm32v6"
|
||||||
|
else
|
||||||
|
arch="arm32"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
arch="arm64" # Default for non-Linux ARM
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
riscv64)
|
||||||
|
if [ "$os" = "linux" ]; then
|
||||||
|
arch="riscv64"
|
||||||
|
else
|
||||||
|
print_error "RISC-V architecture only supported on Linux"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
print_error "Unsupported architecture: $(uname -m)"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
echo "${os}_${arch}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get installation directory
|
||||||
|
get_install_dir() {
|
||||||
|
local platform="$1"
|
||||||
|
|
||||||
|
if [[ "$platform" == *"windows"* ]]; then
|
||||||
|
echo "$HOME/bin"
|
||||||
|
else
|
||||||
|
# For Unix-like systems, prioritize system-wide directories for sudo access
|
||||||
|
# Check in order of preference: /usr/local/bin, /usr/bin, ~/.local/bin
|
||||||
|
if [ -d "/usr/local/bin" ]; then
|
||||||
|
echo "/usr/local/bin"
|
||||||
|
elif [ -d "/usr/bin" ]; then
|
||||||
|
echo "/usr/bin"
|
||||||
|
else
|
||||||
|
# Fallback to user directory if system directories don't exist
|
||||||
|
echo "$HOME/.local/bin"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if we need sudo for installation
|
||||||
|
need_sudo() {
|
||||||
|
local install_dir="$1"
|
||||||
|
|
||||||
|
# If installing to system directory and we don't have write permission, need sudo
|
||||||
|
if [[ "$install_dir" == "/usr/local/bin" || "$install_dir" == "/usr/bin" ]]; then
|
||||||
|
if [ ! -w "$install_dir" ] 2>/dev/null; then
|
||||||
|
return 0 # Need sudo
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
return 1 # Don't need sudo
|
||||||
|
}
|
||||||
|
|
||||||
|
# Download and install olm
|
||||||
|
install_olm() {
|
||||||
|
local platform="$1"
|
||||||
|
local install_dir="$2"
|
||||||
|
local binary_name="olm_${platform}"
|
||||||
|
local exe_suffix=""
|
||||||
|
|
||||||
|
# Add .exe suffix for Windows
|
||||||
|
if [[ "$platform" == *"windows"* ]]; then
|
||||||
|
binary_name="${binary_name}.exe"
|
||||||
|
exe_suffix=".exe"
|
||||||
|
fi
|
||||||
|
|
||||||
|
local download_url="${BASE_URL}/${binary_name}"
|
||||||
|
local temp_file="/tmp/olm${exe_suffix}"
|
||||||
|
local final_path="${install_dir}/olm${exe_suffix}"
|
||||||
|
|
||||||
|
print_status "Downloading olm from ${download_url}"
|
||||||
|
|
||||||
|
# Download the binary
|
||||||
|
if command -v curl >/dev/null 2>&1; then
|
||||||
|
curl -fsSL "$download_url" -o "$temp_file"
|
||||||
|
elif command -v wget >/dev/null 2>&1; then
|
||||||
|
wget -q "$download_url" -O "$temp_file"
|
||||||
|
else
|
||||||
|
print_error "Neither curl nor wget is available. Please install one of them."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if we need sudo for installation
|
||||||
|
local use_sudo=""
|
||||||
|
if need_sudo "$install_dir"; then
|
||||||
|
print_status "Administrator privileges required for system-wide installation"
|
||||||
|
if command -v sudo >/dev/null 2>&1; then
|
||||||
|
use_sudo="sudo"
|
||||||
|
else
|
||||||
|
print_error "sudo is required for system-wide installation but not available"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create install directory if it doesn't exist
|
||||||
|
if [ -n "$use_sudo" ]; then
|
||||||
|
$use_sudo mkdir -p "$install_dir"
|
||||||
|
else
|
||||||
|
mkdir -p "$install_dir"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Move binary to install directory
|
||||||
|
if [ -n "$use_sudo" ]; then
|
||||||
|
$use_sudo mv "$temp_file" "$final_path"
|
||||||
|
$use_sudo chmod +x "$final_path"
|
||||||
|
else
|
||||||
|
mv "$temp_file" "$final_path"
|
||||||
|
chmod +x "$final_path"
|
||||||
|
fi
|
||||||
|
|
||||||
|
print_status "olm installed to ${final_path}"
|
||||||
|
|
||||||
|
# Check if install directory is in PATH (only warn for non-system directories)
|
||||||
|
if [[ "$install_dir" != "/usr/local/bin" && "$install_dir" != "/usr/bin" ]]; then
|
||||||
|
if ! echo "$PATH" | grep -q "$install_dir"; then
|
||||||
|
print_warning "Install directory ${install_dir} is not in your PATH."
|
||||||
|
print_warning "Add it to your PATH by adding this line to your shell profile:"
|
||||||
|
print_warning " export PATH=\"${install_dir}:\$PATH\""
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Verify installation
|
||||||
|
verify_installation() {
|
||||||
|
local install_dir="$1"
|
||||||
|
local exe_suffix=""
|
||||||
|
|
||||||
|
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||||
|
exe_suffix=".exe"
|
||||||
|
fi
|
||||||
|
|
||||||
|
local olm_path="${install_dir}/olm${exe_suffix}"
|
||||||
|
|
||||||
|
if [ -f "$olm_path" ] && [ -x "$olm_path" ]; then
|
||||||
|
print_status "Installation successful!"
|
||||||
|
print_status "olm version: $("$olm_path" --version 2>/dev/null || echo "unknown")"
|
||||||
|
return 0
|
||||||
|
else
|
||||||
|
print_error "Installation failed. Binary not found or not executable."
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Main installation process
|
||||||
|
main() {
|
||||||
|
print_status "Installing latest version of olm..."
|
||||||
|
|
||||||
|
# Get latest version
|
||||||
|
print_status "Fetching latest version from GitHub..."
|
||||||
|
VERSION=$(get_latest_version)
|
||||||
|
print_status "Latest version: v${VERSION}"
|
||||||
|
|
||||||
|
# Set base URL with the fetched version
|
||||||
|
BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"
|
||||||
|
|
||||||
|
# Detect platform
|
||||||
|
PLATFORM=$(detect_platform)
|
||||||
|
print_status "Detected platform: ${PLATFORM}"
|
||||||
|
|
||||||
|
# Get install directory
|
||||||
|
INSTALL_DIR=$(get_install_dir "$PLATFORM")
|
||||||
|
print_status "Install directory: ${INSTALL_DIR}"
|
||||||
|
|
||||||
|
# Inform user about system-wide installation
|
||||||
|
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
|
||||||
|
print_status "Installing system-wide for sudo access"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install olm
|
||||||
|
install_olm "$PLATFORM" "$INSTALL_DIR"
|
||||||
|
|
||||||
|
# Verify installation
|
||||||
|
if verify_installation "$INSTALL_DIR"; then
|
||||||
|
print_status "olm is ready to use!"
|
||||||
|
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
|
||||||
|
print_status "olm is installed system-wide and accessible via sudo"
|
||||||
|
fi
|
||||||
|
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||||
|
print_status "Run 'olm --help' to get started"
|
||||||
|
else
|
||||||
|
print_status "Run 'olm --help' or 'sudo olm --help' to get started"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run main function
|
||||||
|
main "$@"
|
||||||
34
go.mod
34
go.mod
@@ -1,20 +1,22 @@
|
|||||||
module github.com/fosrl/newt
|
module github.com/fosrl/olm
|
||||||
|
|
||||||
go 1.23.1
|
go 1.25
|
||||||
|
|
||||||
toolchain go1.23.2
|
|
||||||
|
|
||||||
require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/btree v1.1.2 // indirect
|
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7
|
||||||
github.com/gorilla/websocket v1.5.3 // indirect
|
github.com/vishvananda/netlink v1.3.1
|
||||||
golang.org/x/crypto v0.28.0 // indirect
|
golang.org/x/crypto v0.43.0
|
||||||
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
||||||
golang.org/x/net v0.30.0 // indirect
|
golang.org/x/sys v0.37.0
|
||||||
golang.org/x/sys v0.26.0 // indirect
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||||
golang.org/x/time v0.7.0 // indirect
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
)
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
|
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
require (
|
||||||
|
github.com/gorilla/websocket v1.5.3 // indirect
|
||||||
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
|
golang.org/x/net v0.45.0 // indirect
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
|
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
48
go.sum
48
go.sum
@@ -1,22 +1,34 @@
|
|||||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o=
|
||||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM=
|
||||||
|
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||||
|
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||||
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA=
|
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||||
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
|
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
||||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||||
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||||
|
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs=
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||||
|
|||||||
351
holepunch/holepunch.go
Normal file
351
holepunch/holepunch.go
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
package holepunch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/bind"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DomainResolver is a function type for resolving domains to IP addresses
|
||||||
|
type DomainResolver func(string) (string, error)
|
||||||
|
|
||||||
|
// ExitNode represents a WireGuard exit node for hole punching
|
||||||
|
type ExitNode struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager handles UDP hole punching operations
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
running bool
|
||||||
|
stopChan chan struct{}
|
||||||
|
sharedBind *bind.SharedBind
|
||||||
|
olmID string
|
||||||
|
token string
|
||||||
|
domainResolver DomainResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new hole punch manager
|
||||||
|
func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
sharedBind: sharedBind,
|
||||||
|
olmID: olmID,
|
||||||
|
domainResolver: domainResolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetToken updates the authentication token used for hole punching
|
||||||
|
func (m *Manager) SetToken(token string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.token = token
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning returns whether hole punching is currently active
|
||||||
|
func (m *Manager) IsRunning() bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.running
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops any ongoing hole punch operations
|
||||||
|
func (m *Manager) Stop() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if !m.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.stopChan != nil {
|
||||||
|
close(m.stopChan)
|
||||||
|
m.stopChan = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.running = false
|
||||||
|
logger.Info("Hole punch manager stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartMultipleExitNodes starts hole punching to multiple exit nodes
|
||||||
|
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
|
||||||
|
if m.running {
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Debug("UDP hole punch already running, skipping new request")
|
||||||
|
return fmt.Errorf("hole punch already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(exitNodes) == 0 {
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Warn("No exit nodes provided for hole punching")
|
||||||
|
return fmt.Errorf("no exit nodes provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.running = true
|
||||||
|
m.stopChan = make(chan struct{})
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||||
|
|
||||||
|
go m.runMultipleExitNodes(exitNodes)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
|
||||||
|
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
|
||||||
|
if m.running {
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Debug("UDP hole punch already running, skipping new request")
|
||||||
|
return fmt.Errorf("hole punch already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.running = true
|
||||||
|
m.stopChan = make(chan struct{})
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
||||||
|
|
||||||
|
go m.runSingleEndpoint(endpoint, serverPubKey)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runMultipleExitNodes performs hole punching to multiple exit nodes
|
||||||
|
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
|
||||||
|
defer func() {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.running = false
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Resolve all endpoints upfront
|
||||||
|
type resolvedExitNode struct {
|
||||||
|
remoteAddr *net.UDPAddr
|
||||||
|
publicKey string
|
||||||
|
endpointName string
|
||||||
|
}
|
||||||
|
|
||||||
|
var resolvedNodes []resolvedExitNode
|
||||||
|
for _, exitNode := range exitNodes {
|
||||||
|
host, err := m.domainResolver(exitNode.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := net.JoinHostPort(host, "21820")
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||||
|
remoteAddr: remoteAddr,
|
||||||
|
publicKey: exitNode.PublicKey,
|
||||||
|
endpointName: exitNode.Endpoint,
|
||||||
|
})
|
||||||
|
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resolvedNodes) == 0 {
|
||||||
|
logger.Error("No exit nodes could be resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send initial hole punch to all exit nodes
|
||||||
|
for _, node := range resolvedNodes {
|
||||||
|
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||||
|
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.stopChan:
|
||||||
|
logger.Debug("Hole punch stopped by signal")
|
||||||
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Debug("Hole punch timeout reached")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
// Send hole punch to all exit nodes
|
||||||
|
for _, node := range resolvedNodes {
|
||||||
|
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||||
|
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSingleEndpoint performs hole punching to a single endpoint
|
||||||
|
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
|
||||||
|
defer func() {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.running = false
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||||
|
}()
|
||||||
|
|
||||||
|
host, err := m.domainResolver(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := net.JoinHostPort(host, "21820")
|
||||||
|
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute once immediately before starting the loop
|
||||||
|
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||||
|
logger.Warn("Failed to send initial hole punch: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.stopChan:
|
||||||
|
logger.Debug("Hole punch stopped by signal")
|
||||||
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Debug("Hole punch timeout reached")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||||
|
logger.Debug("Failed to send hole punch: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendHolePunch sends an encrypted hole punch packet using the shared bind
|
||||||
|
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
token := m.token
|
||||||
|
olmID := m.olmID
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if serverPubKey == "" || token == "" {
|
||||||
|
return fmt.Errorf("server public key or OLM token is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := struct {
|
||||||
|
OlmID string `json:"olmId"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}{
|
||||||
|
OlmID: olmID,
|
||||||
|
Token: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert payload to JSON
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the payload using the server's WireGuard public key
|
||||||
|
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(encryptedPayload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
|
||||||
|
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
||||||
|
// Generate an ephemeral keypair for this message
|
||||||
|
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||||
|
}
|
||||||
|
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||||
|
|
||||||
|
// Parse the server's public key
|
||||||
|
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use X25519 for key exchange
|
||||||
|
var ephPrivKeyFixed [32]byte
|
||||||
|
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||||
|
|
||||||
|
// Perform X25519 key exchange
|
||||||
|
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an AEAD cipher using the shared secret
|
||||||
|
aead, err := chacha20poly1305.New(sharedSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random nonce
|
||||||
|
nonce := make([]byte, aead.NonceSize())
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the payload
|
||||||
|
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||||
|
|
||||||
|
// Prepare the final encrypted message
|
||||||
|
encryptedMsg := struct {
|
||||||
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||||
|
Nonce []byte `json:"nonce"`
|
||||||
|
Ciphertext []byte `json:"ciphertext"`
|
||||||
|
}{
|
||||||
|
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||||
|
Nonce: nonce,
|
||||||
|
Ciphertext: ciphertext,
|
||||||
|
}
|
||||||
|
|
||||||
|
return encryptedMsg, nil
|
||||||
|
}
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
package logger
|
|
||||||
|
|
||||||
type LogLevel int
|
|
||||||
|
|
||||||
const (
|
|
||||||
DEBUG LogLevel = iota
|
|
||||||
INFO
|
|
||||||
WARN
|
|
||||||
ERROR
|
|
||||||
FATAL
|
|
||||||
)
|
|
||||||
|
|
||||||
var levelStrings = map[LogLevel]string{
|
|
||||||
DEBUG: "DEBUG",
|
|
||||||
INFO: "INFO",
|
|
||||||
WARN: "WARN",
|
|
||||||
ERROR: "ERROR",
|
|
||||||
FATAL: "FATAL",
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns the string representation of the log level
|
|
||||||
func (l LogLevel) String() string {
|
|
||||||
if s, ok := levelStrings[l]; ok {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return "UNKNOWN"
|
|
||||||
}
|
|
||||||
106
logger/logger.go
106
logger/logger.go
@@ -1,106 +0,0 @@
|
|||||||
package logger
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Logger struct holds the logger instance
|
|
||||||
type Logger struct {
|
|
||||||
logger *log.Logger
|
|
||||||
level LogLevel
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
defaultLogger *Logger
|
|
||||||
once sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewLogger creates a new logger instance
|
|
||||||
func NewLogger() *Logger {
|
|
||||||
return &Logger{
|
|
||||||
logger: log.New(os.Stdout, "", 0),
|
|
||||||
level: DEBUG,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Init initializes the default logger
|
|
||||||
func Init() *Logger {
|
|
||||||
once.Do(func() {
|
|
||||||
defaultLogger = NewLogger()
|
|
||||||
})
|
|
||||||
return defaultLogger
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLogger returns the default logger instance
|
|
||||||
func GetLogger() *Logger {
|
|
||||||
if defaultLogger == nil {
|
|
||||||
Init()
|
|
||||||
}
|
|
||||||
return defaultLogger
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLevel sets the minimum logging level
|
|
||||||
func (l *Logger) SetLevel(level LogLevel) {
|
|
||||||
l.level = level
|
|
||||||
}
|
|
||||||
|
|
||||||
// log handles the actual logging
|
|
||||||
func (l *Logger) log(level LogLevel, format string, args ...interface{}) {
|
|
||||||
if level < l.level {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
|
||||||
message := fmt.Sprintf(format, args...)
|
|
||||||
l.logger.Printf("%s: %s %s", level.String(), timestamp, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug logs debug level messages
|
|
||||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
|
||||||
l.log(DEBUG, format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Info logs info level messages
|
|
||||||
func (l *Logger) Info(format string, args ...interface{}) {
|
|
||||||
l.log(INFO, format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warn logs warning level messages
|
|
||||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
|
||||||
l.log(WARN, format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error logs error level messages
|
|
||||||
func (l *Logger) Error(format string, args ...interface{}) {
|
|
||||||
l.log(ERROR, format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fatal logs fatal level messages and exits
|
|
||||||
func (l *Logger) Fatal(format string, args ...interface{}) {
|
|
||||||
l.log(FATAL, format, args...)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Global helper functions
|
|
||||||
func Debug(format string, args ...interface{}) {
|
|
||||||
GetLogger().Debug(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Info(format string, args ...interface{}) {
|
|
||||||
GetLogger().Info(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Warn(format string, args ...interface{}) {
|
|
||||||
GetLogger().Warn(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Error(format string, args ...interface{}) {
|
|
||||||
GetLogger().Error(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Fatal(format string, args ...interface{}) {
|
|
||||||
GetLogger().Fatal(format, args...)
|
|
||||||
}
|
|
||||||
770
main.go
770
main.go
@@ -1,608 +1,220 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strconv"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/proxy"
|
"github.com/fosrl/olm/olm"
|
||||||
"github.com/fosrl/newt/websocket"
|
|
||||||
|
|
||||||
"golang.org/x/net/icmp"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgData struct {
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
ServerIP string `json:"serverIP"`
|
|
||||||
TunnelIP string `json:"tunnelIP"`
|
|
||||||
Targets TargetsByType `json:"targets"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TargetsByType struct {
|
|
||||||
UDP []string `json:"udp"`
|
|
||||||
TCP []string `json:"tcp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TargetData struct {
|
|
||||||
Targets []string `json:"targets"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func fixKey(key string) string {
|
|
||||||
// Remove any whitespace
|
|
||||||
key = strings.TrimSpace(key)
|
|
||||||
|
|
||||||
// Decode from base64
|
|
||||||
decoded, err := base64.StdEncoding.DecodeString(key)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Error decoding base64:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to hex
|
|
||||||
return hex.EncodeToString(decoded)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ping(tnet *netstack.Net, dst string) error {
|
|
||||||
logger.Info("Pinging %s", dst)
|
|
||||||
socket, err := tnet.Dial("ping4", dst)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create ICMP socket: %w", err)
|
|
||||||
}
|
|
||||||
defer socket.Close()
|
|
||||||
|
|
||||||
requestPing := icmp.Echo{
|
|
||||||
Seq: rand.Intn(1 << 16),
|
|
||||||
Data: []byte("gopher burrow"),
|
|
||||||
}
|
|
||||||
|
|
||||||
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal ICMP message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
|
||||||
return fmt.Errorf("failed to set read deadline: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
_, err = socket.Write(icmpBytes)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write ICMP packet: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := socket.Read(icmpBytes[:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read ICMP packet: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse ICMP packet: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
replyPing, ok := replyPacket.Body.(*icmp.Echo)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
|
|
||||||
return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
|
|
||||||
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Ping latency: %v", time.Since(start))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
|
|
||||||
ticker := time.NewTicker(10 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
err := ping(tnet, serverIP)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Periodic ping failed: %v", err)
|
|
||||||
logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?")
|
|
||||||
}
|
|
||||||
case <-stopChan:
|
|
||||||
logger.Info("Stopping ping check")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
|
||||||
const (
|
|
||||||
maxAttempts = 5
|
|
||||||
retryDelay = 2 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
var lastErr error
|
|
||||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
|
||||||
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
|
|
||||||
|
|
||||||
if err := ping(tnet, dst); err != nil {
|
|
||||||
lastErr = err
|
|
||||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
|
||||||
|
|
||||||
if attempt < maxAttempts {
|
|
||||||
time.Sleep(retryDelay)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return fmt.Errorf("all ping attempts failed after %d tries, last error: %w",
|
|
||||||
maxAttempts, lastErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Successful ping
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// This shouldn't be reached due to the return in the loop, but added for completeness
|
|
||||||
return fmt.Errorf("unexpected error: all ping attempts failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseLogLevel(level string) logger.LogLevel {
|
|
||||||
switch strings.ToUpper(level) {
|
|
||||||
case "DEBUG":
|
|
||||||
return logger.DEBUG
|
|
||||||
case "INFO":
|
|
||||||
return logger.INFO
|
|
||||||
case "WARN":
|
|
||||||
return logger.WARN
|
|
||||||
case "ERROR":
|
|
||||||
return logger.ERROR
|
|
||||||
case "FATAL":
|
|
||||||
return logger.FATAL
|
|
||||||
default:
|
|
||||||
return logger.INFO // default to INFO if invalid level provided
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
|
||||||
switch level {
|
|
||||||
case logger.DEBUG:
|
|
||||||
return device.LogLevelVerbose
|
|
||||||
// case logger.INFO:
|
|
||||||
// return device.LogLevel
|
|
||||||
case logger.WARN:
|
|
||||||
return device.LogLevelError
|
|
||||||
case logger.ERROR, logger.FATAL:
|
|
||||||
return device.LogLevelSilent
|
|
||||||
default:
|
|
||||||
return device.LogLevelSilent
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveDomain(domain string) (string, error) {
|
|
||||||
// Check if there's a port in the domain
|
|
||||||
host, port, err := net.SplitHostPort(domain)
|
|
||||||
if err != nil {
|
|
||||||
// No port found, use the domain as is
|
|
||||||
host = domain
|
|
||||||
port = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove any protocol prefix if present
|
|
||||||
if strings.HasPrefix(host, "http://") {
|
|
||||||
host = strings.TrimPrefix(host, "http://")
|
|
||||||
} else if strings.HasPrefix(host, "https://") {
|
|
||||||
host = strings.TrimPrefix(host, "https://")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lookup IP addresses
|
|
||||||
ips, err := net.LookupIP(host)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ips) == 0 {
|
|
||||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the first IPv4 address if available
|
|
||||||
var ipAddr string
|
|
||||||
for _, ip := range ips {
|
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
|
||||||
ipAddr = ipv4.String()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no IPv4 found, use the first IP (might be IPv6)
|
|
||||||
if ipAddr == "" {
|
|
||||||
ipAddr = ips[0].String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add port back if it existed
|
|
||||||
if port != "" {
|
|
||||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
// Check if we're running as a Windows service
|
||||||
endpoint string
|
if isWindowsService() {
|
||||||
id string
|
runService("OlmWireguardService", false, os.Args[1:])
|
||||||
secret string
|
fmt.Println("Running as Windows service")
|
||||||
mtu string
|
return
|
||||||
mtuInt int
|
|
||||||
dns string
|
|
||||||
privateKey wgtypes.Key
|
|
||||||
err error
|
|
||||||
logLevel string
|
|
||||||
)
|
|
||||||
|
|
||||||
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
|
|
||||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
|
||||||
id = os.Getenv("NEWT_ID")
|
|
||||||
secret = os.Getenv("NEWT_SECRET")
|
|
||||||
mtu = os.Getenv("MTU")
|
|
||||||
dns = os.Getenv("DNS")
|
|
||||||
logLevel = os.Getenv("LOG_LEVEL")
|
|
||||||
|
|
||||||
if endpoint == "" {
|
|
||||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
|
||||||
}
|
|
||||||
if id == "" {
|
|
||||||
flag.StringVar(&id, "id", "", "Newt ID")
|
|
||||||
}
|
|
||||||
if secret == "" {
|
|
||||||
flag.StringVar(&secret, "secret", "", "Newt secret")
|
|
||||||
}
|
|
||||||
if mtu == "" {
|
|
||||||
flag.StringVar(&mtu, "mtu", "1280", "MTU to use")
|
|
||||||
}
|
|
||||||
if dns == "" {
|
|
||||||
flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
|
|
||||||
}
|
|
||||||
if logLevel == "" {
|
|
||||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
|
||||||
}
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
logger.Init()
|
|
||||||
loggerLevel := parseLogLevel(logLevel)
|
|
||||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
|
||||||
|
|
||||||
// Validate required fields
|
|
||||||
if endpoint == "" || id == "" || secret == "" {
|
|
||||||
logger.Fatal("endpoint, id, and secret are required either via CLI flags or environment variables")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse the mtu string into an int
|
// Handle service management commands on Windows
|
||||||
mtuInt, err = strconv.Atoi(mtu)
|
if runtime.GOOS == "windows" {
|
||||||
if err != nil {
|
var command string
|
||||||
logger.Fatal("Failed to parse MTU: %v", err)
|
if len(os.Args) > 1 {
|
||||||
}
|
command = os.Args[1]
|
||||||
|
} else {
|
||||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
command = "default"
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new client
|
|
||||||
client, err := websocket.NewClient(
|
|
||||||
id, // CLI arg takes precedence
|
|
||||||
secret, // CLI arg takes precedence
|
|
||||||
endpoint,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to create client: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create TUN device and network stack
|
|
||||||
var tun tun.Device
|
|
||||||
var tnet *netstack.Net
|
|
||||||
var dev *device.Device
|
|
||||||
var pm *proxy.ProxyManager
|
|
||||||
var connected bool
|
|
||||||
var wgData WgData
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received terminate message")
|
|
||||||
if pm != nil {
|
|
||||||
pm.Stop()
|
|
||||||
}
|
}
|
||||||
if dev != nil {
|
|
||||||
dev.Close()
|
|
||||||
}
|
|
||||||
client.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
pingStopChan := make(chan struct{})
|
switch command {
|
||||||
defer close(pingStopChan)
|
case "install":
|
||||||
|
err := installService()
|
||||||
// Register handlers for different message types
|
|
||||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received registration message")
|
|
||||||
|
|
||||||
if connected {
|
|
||||||
logger.Info("Already connected! But I will send a ping anyway...")
|
|
||||||
// ping(tnet, wgData.ServerIP)
|
|
||||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Handle complete failure after all retries
|
fmt.Printf("Failed to install service: %v\n", err)
|
||||||
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
|
os.Exit(1)
|
||||||
logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?")
|
|
||||||
}
|
}
|
||||||
|
fmt.Println("Service installed successfully")
|
||||||
return
|
return
|
||||||
}
|
case "remove", "uninstall":
|
||||||
|
err := removeService()
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
tun, tnet, err = netstack.CreateNetTUN(
|
|
||||||
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
|
||||||
[]netip.Addr{netip.MustParseAddr(dns)},
|
|
||||||
mtuInt)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to create TUN device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create WireGuard device
|
|
||||||
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
|
|
||||||
mapToWireGuardLogLevel(loggerLevel),
|
|
||||||
"wireguard: ",
|
|
||||||
))
|
|
||||||
|
|
||||||
endpoint, err := resolveDomain(wgData.Endpoint)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve endpoint: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Configure WireGuard
|
|
||||||
config := fmt.Sprintf(`private_key=%s
|
|
||||||
public_key=%s
|
|
||||||
allowed_ip=%s/32
|
|
||||||
endpoint=%s
|
|
||||||
persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
|
||||||
|
|
||||||
err = dev.IpcSet(config)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to configure WireGuard device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bring up the device
|
|
||||||
err = dev.Up()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
|
||||||
// Ping to bring the tunnel up on the server side quickly
|
|
||||||
// ping(tnet, wgData.ServerIP)
|
|
||||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
|
||||||
if err != nil {
|
|
||||||
// Handle complete failure after all retries
|
|
||||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !connected {
|
|
||||||
logger.Info("Starting ping check")
|
|
||||||
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create proxy manager
|
|
||||||
pm = proxy.NewProxyManager(tnet)
|
|
||||||
|
|
||||||
connected = true
|
|
||||||
|
|
||||||
// add the targets if there are any
|
|
||||||
if len(wgData.Targets.TCP) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wgData.Targets.UDP) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
|
||||||
}
|
|
||||||
|
|
||||||
err = pm.Start()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to start proxy manager: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.OnConnect(func() error {
|
|
||||||
publicKey := privateKey.PublicKey()
|
|
||||||
logger.Debug("Public key: %s", publicKey)
|
|
||||||
|
|
||||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
|
||||||
"publicKey": fmt.Sprintf("%s", publicKey),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Sent registration message")
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// Connect to the WebSocket server
|
|
||||||
if err := client.Connect(); err != nil {
|
|
||||||
logger.Fatal("Failed to connect to server: %v", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
// Wait for interrupt signal
|
|
||||||
sigCh := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
<-sigCh
|
|
||||||
|
|
||||||
// Cleanup
|
|
||||||
dev.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTargetData(data interface{}) (TargetData, error) {
|
|
||||||
var targetData TargetData
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
return targetData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
|
||||||
for _, t := range targetData.Targets {
|
|
||||||
// Split the first number off of the target with : separator and use as the port
|
|
||||||
parts := strings.Split(t, ":")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
logger.Info("Invalid target format: %s", t)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the port as an int
|
|
||||||
port := 0
|
|
||||||
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Invalid port: %s", parts[0])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if action == "add" {
|
|
||||||
target := parts[1] + ":" + parts[2]
|
|
||||||
// Only remove the specific target if it exists
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Ignore "target not found" errors as this is expected for new targets
|
fmt.Printf("Failed to remove service: %v\n", err)
|
||||||
if !strings.Contains(err.Error(), "target not found") {
|
os.Exit(1)
|
||||||
logger.Error("Failed to remove existing target: %v", err)
|
}
|
||||||
|
fmt.Println("Service removed successfully")
|
||||||
|
return
|
||||||
|
case "start":
|
||||||
|
// Pass the remaining arguments (after "start") to the service
|
||||||
|
serviceArgs := os.Args[2:]
|
||||||
|
err := startService(serviceArgs)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to start service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Println("Service started successfully")
|
||||||
|
return
|
||||||
|
case "stop":
|
||||||
|
err := stopService()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to stop service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Println("Service stopped successfully")
|
||||||
|
return
|
||||||
|
case "status":
|
||||||
|
status, err := getServiceStatus()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to get service status: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("Service status: %s\n", status)
|
||||||
|
return
|
||||||
|
case "debug":
|
||||||
|
// get the status and if it is Not Installed then install it first
|
||||||
|
status, err := getServiceStatus()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to get service status: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if status == "Not Installed" {
|
||||||
|
err := installService()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to install service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
fmt.Println("Service installed successfully, now running in debug mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the new target
|
// Pass the remaining arguments (after "debug") to the service
|
||||||
pm.AddTarget(proto, tunnelIP, port, target)
|
serviceArgs := os.Args[2:]
|
||||||
|
err = debugService(serviceArgs)
|
||||||
} else if action == "remove" {
|
|
||||||
logger.Info("Removing target with port %d", port)
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to remove target: %v", err)
|
fmt.Printf("Failed to debug service: %v\n", err)
|
||||||
return err
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
case "logs":
|
||||||
|
err := watchLogFile(false)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to watch log file: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case "config":
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
showServiceConfig()
|
||||||
|
} else {
|
||||||
|
fmt.Println("Service configuration is only available on Windows")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case "help", "--help", "-h":
|
||||||
|
fmt.Println("Olm WireGuard VPN Client")
|
||||||
|
fmt.Println("\nWindows Service Management:")
|
||||||
|
fmt.Println(" install Install the service")
|
||||||
|
fmt.Println(" remove Remove the service")
|
||||||
|
fmt.Println(" start [args] Start the service with optional arguments")
|
||||||
|
fmt.Println(" stop Stop the service")
|
||||||
|
fmt.Println(" status Show service status")
|
||||||
|
fmt.Println(" debug [args] Run service in debug mode with optional arguments")
|
||||||
|
fmt.Println(" logs Tail the service log file")
|
||||||
|
fmt.Println(" config Show current service configuration")
|
||||||
|
fmt.Println("\nExamples:")
|
||||||
|
fmt.Println(" olm start --enable-http --http-addr :9452")
|
||||||
|
fmt.Println(" olm debug --endpoint https://example.com --id myid --secret mysecret")
|
||||||
|
fmt.Println("\nFor console mode, run without arguments or with standard flags.")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// get the status and if it is Not Installed then install it first
|
||||||
|
status, err := getServiceStatus()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to get service status: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if status == "Not Installed" {
|
||||||
|
err := installService()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to install service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Println("Service installed successfully, now running")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass the remaining arguments (after "debug") to the service
|
||||||
|
serviceArgs := os.Args[1:]
|
||||||
|
err = debugService(serviceArgs)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to debug service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Setup Windows event logging if on Windows
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
setupWindowsEventLog()
|
||||||
|
} else {
|
||||||
|
// Initialize logger for non-Windows platforms
|
||||||
|
logger.Init()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load configuration from file, env vars, and CLI args
|
||||||
|
// Priority: CLI args > Env vars > Config file > Defaults
|
||||||
|
config, showVersion, showConfig, err := LoadConfig(os.Args[1:])
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to load configuration: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle --show-config flag
|
||||||
|
if showConfig {
|
||||||
|
config.ShowConfig()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
olmVersion := "version_replaceme"
|
||||||
|
if showVersion {
|
||||||
|
fmt.Println("Olm version " + olmVersion)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
logger.Info("Olm version " + olmVersion)
|
||||||
|
|
||||||
|
config.Version = olmVersion
|
||||||
|
|
||||||
|
if err := SaveConfig(config); err != nil {
|
||||||
|
logger.Error("Failed to save full olm config: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Saved full olm config with all options")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new olm.Config struct and copy values from the main config
|
||||||
|
olmConfig := olm.Config{
|
||||||
|
Endpoint: config.Endpoint,
|
||||||
|
ID: config.ID,
|
||||||
|
Secret: config.Secret,
|
||||||
|
UserToken: config.UserToken,
|
||||||
|
MTU: config.MTU,
|
||||||
|
DNS: config.DNS,
|
||||||
|
InterfaceName: config.InterfaceName,
|
||||||
|
LogLevel: config.LogLevel,
|
||||||
|
EnableAPI: config.EnableAPI,
|
||||||
|
HTTPAddr: config.HTTPAddr,
|
||||||
|
SocketPath: config.SocketPath,
|
||||||
|
Holepunch: config.Holepunch,
|
||||||
|
TlsClientCert: config.TlsClientCert,
|
||||||
|
PingIntervalDuration: config.PingIntervalDuration,
|
||||||
|
PingTimeoutDuration: config.PingTimeoutDuration,
|
||||||
|
Version: config.Version,
|
||||||
|
OrgID: config.OrgID,
|
||||||
|
// DoNotCreateNewClient: config.DoNotCreateNewClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a context that will be cancelled on interrupt signals
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
olm.Run(ctx, olmConfig)
|
||||||
}
|
}
|
||||||
|
|||||||
1
olm-binary.REMOVED.git-id
Normal file
1
olm-binary.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
|||||||
|
573df1772c00fcb34ec68e575e973c460dc27ba8
|
||||||
1
olm-test.REMOVED.git-id
Normal file
1
olm-test.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
|||||||
|
ba2c118fd96937229ef54dcd0b82fe5d53d94a87
|
||||||
88
olm.iss
Normal file
88
olm.iss
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
; Script generated by the Inno Setup Script Wizard.
|
||||||
|
; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES!
|
||||||
|
|
||||||
|
#define MyAppName "olm"
|
||||||
|
#define MyAppVersion "1.0.0"
|
||||||
|
#define MyAppPublisher "Fossorial Inc."
|
||||||
|
#define MyAppURL "https://pangolin.net"
|
||||||
|
#define MyAppExeName "olm.exe"
|
||||||
|
|
||||||
|
[Setup]
|
||||||
|
; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications.
|
||||||
|
; (To generate a new GUID, click Tools | Generate GUID inside the IDE.)
|
||||||
|
AppId={{44A24E4C-B616-476F-ADE7-8D56B930959E}
|
||||||
|
AppName={#MyAppName}
|
||||||
|
AppVersion={#MyAppVersion}
|
||||||
|
;AppVerName={#MyAppName} {#MyAppVersion}
|
||||||
|
AppPublisher={#MyAppPublisher}
|
||||||
|
AppPublisherURL={#MyAppURL}
|
||||||
|
AppSupportURL={#MyAppURL}
|
||||||
|
AppUpdatesURL={#MyAppURL}
|
||||||
|
DefaultDirName={autopf}\{#MyAppName}
|
||||||
|
UninstallDisplayIcon={app}\{#MyAppExeName}
|
||||||
|
; "ArchitecturesAllowed=x64compatible" specifies that Setup cannot run
|
||||||
|
; on anything but x64 and Windows 11 on Arm.
|
||||||
|
ArchitecturesAllowed=x64compatible
|
||||||
|
; "ArchitecturesInstallIn64BitMode=x64compatible" requests that the
|
||||||
|
; install be done in "64-bit mode" on x64 or Windows 11 on Arm,
|
||||||
|
; meaning it should use the native 64-bit Program Files directory and
|
||||||
|
; the 64-bit view of the registry.
|
||||||
|
ArchitecturesInstallIn64BitMode=x64compatible
|
||||||
|
DefaultGroupName={#MyAppName}
|
||||||
|
DisableProgramGroupPage=yes
|
||||||
|
; Uncomment the following line to run in non administrative install mode (install for current user only).
|
||||||
|
;PrivilegesRequired=lowest
|
||||||
|
OutputBaseFilename=mysetup
|
||||||
|
SolidCompression=yes
|
||||||
|
WizardStyle=modern
|
||||||
|
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
|
||||||
|
RestartIfNeededByRun=no
|
||||||
|
ChangesEnvironment=true
|
||||||
|
|
||||||
|
[Languages]
|
||||||
|
Name: "english"; MessagesFile: "compiler:Default.isl"
|
||||||
|
|
||||||
|
[Files]
|
||||||
|
; The 'DestName' flag ensures that 'olm_windows_amd64.exe' is installed as 'olm.exe'
|
||||||
|
Source: "C:\Users\Administrator\Downloads\olm_windows_amd64.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}"; Flags: ignoreversion
|
||||||
|
Source: "C:\Users\Administrator\Downloads\wintun.dll"; DestDir: "{app}"; Flags: ignoreversion
|
||||||
|
; NOTE: Don't use "Flags: ignoreversion" on any shared system files
|
||||||
|
|
||||||
|
[Icons]
|
||||||
|
Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"
|
||||||
|
|
||||||
|
[Registry]
|
||||||
|
; Add the application's installation directory to the system PATH environment variable.
|
||||||
|
; HKLM (HKEY_LOCAL_MACHINE) is used for system-wide changes.
|
||||||
|
; The 'Path' variable is located under 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'.
|
||||||
|
; ValueType: expandsz allows for environment variables (like %ProgramFiles%) in the path.
|
||||||
|
; ValueData: "{olddata};{app}" appends the current application directory to the existing PATH.
|
||||||
|
; Flags: uninsdeletevalue ensures the entry is removed upon uninstallation.
|
||||||
|
; Check: IsWin64 ensures this is applied on 64-bit systems, which matches ArchitecturesAllowed.
|
||||||
|
[Registry]
|
||||||
|
; Add the application's installation directory to the system PATH.
|
||||||
|
Root: HKLM; Subkey: "SYSTEM\CurrentControlSet\Control\Session Manager\Environment"; \
|
||||||
|
ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \
|
||||||
|
Flags: uninsdeletevalue; Check: NeedsAddPath(ExpandConstant('{app}'))
|
||||||
|
|
||||||
|
[Code]
|
||||||
|
function NeedsAddPath(Path: string): boolean;
|
||||||
|
var
|
||||||
|
OrigPath: string;
|
||||||
|
begin
|
||||||
|
if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
|
||||||
|
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||||
|
'Path', OrigPath)
|
||||||
|
then begin
|
||||||
|
// Path variable doesn't exist at all, so we definitely need to add it.
|
||||||
|
Result := True;
|
||||||
|
exit;
|
||||||
|
end;
|
||||||
|
|
||||||
|
// Perform a case-insensitive check to see if the path is already present.
|
||||||
|
// We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
|
||||||
|
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
|
||||||
|
Result := False
|
||||||
|
else
|
||||||
|
Result := True;
|
||||||
|
end;
|
||||||
885
olm/common.go
Normal file
885
olm/common.go
Normal file
@@ -0,0 +1,885 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/peermonitor"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WgData struct {
|
||||||
|
Sites []SiteConfig `json:"sites"`
|
||||||
|
TunnelIP string `json:"tunnelIP"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SiteConfig struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
ServerIP string `json:"serverIP"`
|
||||||
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
|
}
|
||||||
|
|
||||||
|
type TargetsByType struct {
|
||||||
|
UDP []string `json:"udp"`
|
||||||
|
TCP []string `json:"tcp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TargetData struct {
|
||||||
|
Targets []string `json:"targets"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HolePunchMessage struct {
|
||||||
|
NewtID string `json:"newtId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExitNode struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HolePunchData struct {
|
||||||
|
ExitNodes []ExitNode `json:"exitNodes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncryptedHolePunchMessage struct {
|
||||||
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||||
|
Nonce []byte `json:"nonce"`
|
||||||
|
Ciphertext []byte `json:"ciphertext"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
peerMonitor *peermonitor.PeerMonitor
|
||||||
|
stopHolepunch chan struct{}
|
||||||
|
stopRegister func()
|
||||||
|
stopPing chan struct{}
|
||||||
|
olmToken string
|
||||||
|
holePunchRunning bool
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ENV_WG_TUN_FD = "WG_TUN_FD"
|
||||||
|
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
||||||
|
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PeerAction represents a request to add, update, or remove a peer
|
||||||
|
type PeerAction struct {
|
||||||
|
Action string `json:"action"` // "add", "update", or "remove"
|
||||||
|
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerData represents the data needed to update a peer
|
||||||
|
type UpdatePeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
ServerIP string `json:"serverIP"`
|
||||||
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeerData represents the data needed to add a peer
|
||||||
|
type AddPeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
ServerIP string `json:"serverIP"`
|
||||||
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeerData represents the data needed to remove a peer
|
||||||
|
type RemovePeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayPeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to format endpoints correctly
|
||||||
|
func formatEndpoint(endpoint string) string {
|
||||||
|
if endpoint == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
|
||||||
|
_, _, err := net.SplitHostPort(endpoint)
|
||||||
|
if err == nil {
|
||||||
|
return endpoint // Already valid, no change needed
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
|
||||||
|
lastColon := strings.LastIndex(endpoint, ":")
|
||||||
|
if lastColon > 0 { // Ensure there is a colon and it's not the first character
|
||||||
|
hostPart := endpoint[:lastColon]
|
||||||
|
// Check if the host part is a literal IPv6 address
|
||||||
|
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
|
||||||
|
// It is! Reformat it with brackets.
|
||||||
|
portPart := endpoint[lastColon+1:]
|
||||||
|
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's not the specific malformed case, return it as is.
|
||||||
|
return endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func fixKey(key string) string {
|
||||||
|
// Remove any whitespace
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
|
||||||
|
// Decode from base64
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(key)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Error decoding base64")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to hex
|
||||||
|
return hex.EncodeToString(decoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLogLevel(level string) logger.LogLevel {
|
||||||
|
switch strings.ToUpper(level) {
|
||||||
|
case "DEBUG":
|
||||||
|
return logger.DEBUG
|
||||||
|
case "INFO":
|
||||||
|
return logger.INFO
|
||||||
|
case "WARN":
|
||||||
|
return logger.WARN
|
||||||
|
case "ERROR":
|
||||||
|
return logger.ERROR
|
||||||
|
case "FATAL":
|
||||||
|
return logger.FATAL
|
||||||
|
default:
|
||||||
|
return logger.INFO // default to INFO if invalid level provided
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
||||||
|
switch level {
|
||||||
|
case logger.DEBUG:
|
||||||
|
return device.LogLevelVerbose
|
||||||
|
// case logger.INFO:
|
||||||
|
// return device.LogLevel
|
||||||
|
case logger.WARN:
|
||||||
|
return device.LogLevelError
|
||||||
|
case logger.ERROR, logger.FATAL:
|
||||||
|
return device.LogLevelSilent
|
||||||
|
default:
|
||||||
|
return device.LogLevelSilent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveDomain(domain string) (string, error) {
|
||||||
|
// First handle any protocol prefix
|
||||||
|
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
|
||||||
|
|
||||||
|
// if there are any trailing slashes, remove them
|
||||||
|
domain = strings.TrimSuffix(domain, "/")
|
||||||
|
|
||||||
|
// Now split host and port
|
||||||
|
host, port, err := net.SplitHostPort(domain)
|
||||||
|
if err != nil {
|
||||||
|
// No port found, use the domain as is
|
||||||
|
host = domain
|
||||||
|
port = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup IP addresses
|
||||||
|
ips, err := net.LookupIP(host)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the first IPv4 address if available
|
||||||
|
var ipAddr string
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
ipAddr = ipv4.String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no IPv4 found, use the first IP (might be IPv6)
|
||||||
|
if ipAddr == "" {
|
||||||
|
ipAddr = ips[0].String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add port back if it existed
|
||||||
|
if port != "" {
|
||||||
|
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||||
|
if maxPort < minPort {
|
||||||
|
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a slice of all ports in the range
|
||||||
|
portRange := make([]uint16, maxPort-minPort+1)
|
||||||
|
for i := range portRange {
|
||||||
|
portRange[i] = minPort + uint16(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fisher-Yates shuffle to randomize the port order
|
||||||
|
rand.Seed(uint64(time.Now().UnixNano()))
|
||||||
|
for i := len(portRange) - 1; i > 0; i-- {
|
||||||
|
j := rand.Intn(i + 1)
|
||||||
|
portRange[i], portRange[j] = portRange[j], portRange[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try each port in the randomized order
|
||||||
|
for _, port := range portRange {
|
||||||
|
addr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: int(port),
|
||||||
|
}
|
||||||
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
continue // Port is in use or there was an error, try next port
|
||||||
|
}
|
||||||
|
_ = conn.SetDeadline(time.Now())
|
||||||
|
conn.Close()
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendPing(olm *websocket.Client) error {
|
||||||
|
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
||||||
|
"timestamp": time.Now().Unix(),
|
||||||
|
"userToken": olm.GetConfig().UserToken,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send ping message: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Debug("Sent ping message")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func keepSendingPing(olm *websocket.Client) {
|
||||||
|
// Send ping immediately on startup
|
||||||
|
if err := sendPing(olm); err != nil {
|
||||||
|
logger.Error("Failed to send initial ping: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Sent initial ping message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up ticker for one minute intervals
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopPing:
|
||||||
|
logger.Info("Stopping ping messages")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := sendPing(olm); err != nil {
|
||||||
|
logger.Error("Failed to send periodic ping: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||||
|
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
||||||
|
siteHost, err := ResolveDomain(siteConfig.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
|
||||||
|
allowedIp := strings.Split(siteConfig.ServerIP, "/")
|
||||||
|
if len(allowedIp) > 1 {
|
||||||
|
allowedIp[1] = "32"
|
||||||
|
} else {
|
||||||
|
allowedIp = append(allowedIp, "32")
|
||||||
|
}
|
||||||
|
allowedIpStr := strings.Join(allowedIp, "/")
|
||||||
|
|
||||||
|
// Collect all allowed IPs in a slice
|
||||||
|
var allowedIPs []string
|
||||||
|
allowedIPs = append(allowedIPs, allowedIpStr)
|
||||||
|
|
||||||
|
// If we have anything in remoteSubnets, add those as well
|
||||||
|
if siteConfig.RemoteSubnets != "" {
|
||||||
|
// Split remote subnets by comma and add each one
|
||||||
|
remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",")
|
||||||
|
for _, subnet := range remoteSubnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet != "" {
|
||||||
|
allowedIPs = append(allowedIPs, subnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct WireGuard config for this peer
|
||||||
|
var configBuilder strings.Builder
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String())))
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey)))
|
||||||
|
|
||||||
|
// Add each allowed IP separately
|
||||||
|
for _, allowedIP := range allowedIPs {
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||||
|
configBuilder.WriteString("persistent_keepalive_interval=1\n")
|
||||||
|
|
||||||
|
config := configBuilder.String()
|
||||||
|
logger.Debug("Configuring peer with config: %s", config)
|
||||||
|
|
||||||
|
err = dev.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up peer monitoring
|
||||||
|
if peerMonitor != nil {
|
||||||
|
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||||
|
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||||
|
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||||
|
|
||||||
|
primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgConfig := &peermonitor.WireGuardConfig{
|
||||||
|
SiteID: siteConfig.SiteId,
|
||||||
|
PublicKey: fixKey(siteConfig.PublicKey),
|
||||||
|
ServerIP: strings.Split(siteConfig.ServerIP, "/")[0],
|
||||||
|
Endpoint: siteConfig.Endpoint,
|
||||||
|
PrimaryRelay: primaryRelay,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer removes a peer from the WireGuard device
|
||||||
|
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
|
||||||
|
// Construct WireGuard config to remove the peer
|
||||||
|
var configBuilder strings.Builder
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(publicKey)))
|
||||||
|
configBuilder.WriteString("remove=true\n")
|
||||||
|
|
||||||
|
config := configBuilder.String()
|
||||||
|
logger.Debug("Removing peer with config: %s", config)
|
||||||
|
|
||||||
|
err := dev.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop monitoring this peer
|
||||||
|
if peerMonitor != nil {
|
||||||
|
peerMonitor.RemovePeer(siteId)
|
||||||
|
logger.Info("Stopped monitoring for site %d", siteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||||
|
func ConfigureInterface(interfaceName string, wgData WgData) error {
|
||||||
|
var ipAddr string = wgData.TunnelIP
|
||||||
|
|
||||||
|
// Parse the IP address and network
|
||||||
|
ip, ipNet, err := net.ParseCIDR(ipAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return configureLinux(interfaceName, ip, ipNet)
|
||||||
|
case "darwin":
|
||||||
|
return configureDarwin(interfaceName, ip, ipNet)
|
||||||
|
case "windows":
|
||||||
|
return configureWindows(interfaceName, ip, ipNet)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
logger.Info("Configuring Windows interface: %s", interfaceName)
|
||||||
|
|
||||||
|
// Calculate mask string (e.g., 255.255.255.0)
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
mask := net.CIDRMask(maskBits, 32)
|
||||||
|
maskIP := net.IP(mask)
|
||||||
|
|
||||||
|
// Set the IP address using netsh
|
||||||
|
cmd := exec.Command("netsh", "interface", "ipv4", "set", "address",
|
||||||
|
fmt.Sprintf("name=%s", interfaceName),
|
||||||
|
"source=static",
|
||||||
|
fmt.Sprintf("addr=%s", ip.String()),
|
||||||
|
fmt.Sprintf("mask=%s", maskIP.String()))
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("netsh command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface if needed (in Windows, setting the IP usually brings it up)
|
||||||
|
// But we'll explicitly enable it to be sure
|
||||||
|
cmd = exec.Command("netsh", "interface", "set", "interface",
|
||||||
|
interfaceName,
|
||||||
|
"admin=enable")
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err = cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delay 2 seconds
|
||||||
|
time.Sleep(8 * time.Second)
|
||||||
|
|
||||||
|
// Wait for the interface to be up and have the correct IP
|
||||||
|
err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("interface did not come up within timeout: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForInterfaceUp polls the network interface until it's up or times out
|
||||||
|
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
||||||
|
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
pollInterval := 500 * time.Millisecond
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
// Check if interface exists and is up
|
||||||
|
iface, err := net.InterfaceByName(interfaceName)
|
||||||
|
if err == nil {
|
||||||
|
// Check if interface is up
|
||||||
|
if iface.Flags&net.FlagUp != 0 {
|
||||||
|
// Check if it has the expected IP
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err == nil {
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ipNet, ok := addr.(*net.IPNet)
|
||||||
|
if ok && ipNet.IP.Equal(expectedIP) {
|
||||||
|
logger.Info("Interface %s is up with correct IP", interfaceName)
|
||||||
|
return nil // Interface is up with correct IP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("Interface %s exists but is not up yet", interfaceName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("Interface %s not found yet: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait before next check
|
||||||
|
time.Sleep(pollInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
// Parse destination to get the IP and subnet
|
||||||
|
ip, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the subnet mask
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
mask := net.CIDRMask(maskBits, 32)
|
||||||
|
maskIP := net.IP(mask)
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("route", "add",
|
||||||
|
ip.String(),
|
||||||
|
"mask", maskIP.String(),
|
||||||
|
gateway,
|
||||||
|
"metric", "1")
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// First, get the interface index
|
||||||
|
indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces")
|
||||||
|
output, err := indexCmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface index: %v, output: %s", err, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the output to find the interface index
|
||||||
|
lines := strings.Split(string(output), "\n")
|
||||||
|
var ifIndex string
|
||||||
|
for _, line := range lines {
|
||||||
|
if strings.Contains(line, interfaceName) {
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) > 0 {
|
||||||
|
ifIndex = fields[0]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ifIndex == "" {
|
||||||
|
return fmt.Errorf("could not find index for interface %s", interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to integer to validate
|
||||||
|
idx, err := strconv.Atoi(ifIndex)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid interface index: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route via interface using the index
|
||||||
|
cmd = exec.Command("route", "add",
|
||||||
|
ip.String(),
|
||||||
|
"mask", maskIP.String(),
|
||||||
|
"0.0.0.0",
|
||||||
|
"if", strconv.Itoa(idx))
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WindowsRemoveRoute(destination string) error {
|
||||||
|
// Parse destination to get the IP
|
||||||
|
ip, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the subnet mask
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
mask := net.CIDRMask(maskBits, 32)
|
||||||
|
maskIP := net.IP(mask)
|
||||||
|
|
||||||
|
cmd := exec.Command("route", "delete",
|
||||||
|
ip.String(),
|
||||||
|
"mask", maskIP.String())
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findUnusedUTUN() (string, error) {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
||||||
|
}
|
||||||
|
used := make(map[int]bool)
|
||||||
|
re := regexp.MustCompile(`^utun(\d+)$`)
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
|
||||||
|
if num, err := strconv.Atoi(matches[1]); err == nil {
|
||||||
|
used[num] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try utun0 up to utun255.
|
||||||
|
for i := 0; i < 256; i++ {
|
||||||
|
if !used[i] {
|
||||||
|
return fmt.Sprintf("utun%d", i), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no unused utun interface found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
logger.Info("Configuring darwin interface: %s", interfaceName)
|
||||||
|
|
||||||
|
prefix, _ := ipNet.Mask.Size()
|
||||||
|
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
||||||
|
|
||||||
|
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface
|
||||||
|
cmd = exec.Command("ifconfig", interfaceName, "up")
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err = cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
// Get the interface
|
||||||
|
link, err := netlink.LinkByName(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the IP address attributes
|
||||||
|
addr := &netlink.Addr{
|
||||||
|
IPNet: &net.IPNet{
|
||||||
|
IP: ip,
|
||||||
|
Mask: ipNet.Mask,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the IP address to the interface
|
||||||
|
if err := netlink.AddrAdd(link, addr); err != nil {
|
||||||
|
return fmt.Errorf("failed to add IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface
|
||||||
|
if err := netlink.LinkSetUp(link); err != nil {
|
||||||
|
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface
|
||||||
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DarwinRemoveRoute(destination string) error {
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("ip", "route", "add", destination, "via", gateway)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface
|
||||||
|
cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ip route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxRemoveRoute(destination string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("ip", "route", "del", destination)
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||||
|
func addRouteForServerIP(serverIP, interfaceName string) error {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinAddRoute(serverIP, "", interfaceName)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsAddRoute(serverIP, "", interfaceName)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxAddRoute(serverIP, "", interfaceName)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRouteForServerIP removes an OS-specific route for the server IP
|
||||||
|
func removeRouteForServerIP(serverIP string) error {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinRemoveRoute(serverIP)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsRemoveRoute(serverIP)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxRemoveRoute(serverIP)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets
|
||||||
|
func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error {
|
||||||
|
if remoteSubnets == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split remote subnets by comma and add routes for each one
|
||||||
|
subnets := strings.Split(remoteSubnets, ",")
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add route based on operating system
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Added route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets
|
||||||
|
func removeRoutesForRemoteSubnets(remoteSubnets string) error {
|
||||||
|
if remoteSubnets == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split remote subnets by comma and remove routes for each one
|
||||||
|
subnets := strings.Split(remoteSubnets, ",")
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove route based on operating system
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Removed route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
894
olm/olm.go
Normal file
894
olm/olm.go
Normal file
@@ -0,0 +1,894 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/updates"
|
||||||
|
"github.com/fosrl/olm/api"
|
||||||
|
"github.com/fosrl/olm/bind"
|
||||||
|
"github.com/fosrl/olm/holepunch"
|
||||||
|
"github.com/fosrl/olm/peermonitor"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
// Connection settings
|
||||||
|
Endpoint string
|
||||||
|
ID string
|
||||||
|
Secret string
|
||||||
|
UserToken string
|
||||||
|
|
||||||
|
// Network settings
|
||||||
|
MTU int
|
||||||
|
DNS string
|
||||||
|
InterfaceName string
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
LogLevel string
|
||||||
|
|
||||||
|
// HTTP server
|
||||||
|
EnableAPI bool
|
||||||
|
HTTPAddr string
|
||||||
|
SocketPath string
|
||||||
|
|
||||||
|
// Advanced
|
||||||
|
Holepunch bool
|
||||||
|
TlsClientCert string
|
||||||
|
|
||||||
|
// Parsed values (not in JSON)
|
||||||
|
PingIntervalDuration time.Duration
|
||||||
|
PingTimeoutDuration time.Duration
|
||||||
|
|
||||||
|
// Source tracking (not in JSON)
|
||||||
|
sources map[string]string
|
||||||
|
|
||||||
|
Version string
|
||||||
|
OrgID string
|
||||||
|
DoNotCreateNewClient bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
privateKey wgtypes.Key
|
||||||
|
connected bool
|
||||||
|
dev *device.Device
|
||||||
|
wgData WgData
|
||||||
|
holePunchData HolePunchData
|
||||||
|
uapiListener net.Listener
|
||||||
|
tdev tun.Device
|
||||||
|
apiServer *api.API
|
||||||
|
olmClient *websocket.Client
|
||||||
|
tunnelCancel context.CancelFunc
|
||||||
|
tunnelRunning bool
|
||||||
|
sharedBind *bind.SharedBind
|
||||||
|
holePunchManager *holepunch.Manager
|
||||||
|
)
|
||||||
|
|
||||||
|
func Run(ctx context.Context, config Config) {
|
||||||
|
// Create a cancellable context for internal shutdown control
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel))
|
||||||
|
|
||||||
|
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
|
||||||
|
logger.Debug("Failed to check for updates: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Holepunch {
|
||||||
|
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.HTTPAddr != "" {
|
||||||
|
apiServer = api.NewAPI(config.HTTPAddr)
|
||||||
|
} else if config.SocketPath != "" {
|
||||||
|
apiServer = api.NewAPISocket(config.SocketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
apiServer.SetVersion(config.Version)
|
||||||
|
apiServer.SetOrgID(config.OrgID)
|
||||||
|
|
||||||
|
if err := apiServer.Start(); err != nil {
|
||||||
|
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen for shutdown requests from the API
|
||||||
|
go func() {
|
||||||
|
<-apiServer.GetShutdownChannel()
|
||||||
|
logger.Info("Shutdown requested via API")
|
||||||
|
// Cancel the context to trigger graceful shutdown
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var (
|
||||||
|
id = config.ID
|
||||||
|
secret = config.Secret
|
||||||
|
endpoint = config.Endpoint
|
||||||
|
userToken = config.UserToken
|
||||||
|
)
|
||||||
|
|
||||||
|
// Main event loop that handles connect, disconnect, and reconnect
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Info("Context cancelled while waiting for credentials")
|
||||||
|
goto shutdown
|
||||||
|
|
||||||
|
case req := <-apiServer.GetConnectionChannel():
|
||||||
|
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
||||||
|
|
||||||
|
// Stop any existing tunnel before starting a new one
|
||||||
|
if olmClient != nil {
|
||||||
|
logger.Info("Stopping existing tunnel before starting new connection")
|
||||||
|
StopTunnel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the connection parameters
|
||||||
|
id = req.ID
|
||||||
|
secret = req.Secret
|
||||||
|
endpoint = req.Endpoint
|
||||||
|
userToken := req.UserToken
|
||||||
|
|
||||||
|
// Start the tunnel process with the new credentials
|
||||||
|
if id != "" && secret != "" && endpoint != "" {
|
||||||
|
logger.Info("Starting tunnel with new credentials")
|
||||||
|
tunnelRunning = true
|
||||||
|
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-apiServer.GetDisconnectChannel():
|
||||||
|
logger.Info("Received disconnect request via API")
|
||||||
|
StopTunnel()
|
||||||
|
// Clear credentials so we wait for new connect call
|
||||||
|
id = ""
|
||||||
|
secret = ""
|
||||||
|
endpoint = ""
|
||||||
|
userToken = ""
|
||||||
|
|
||||||
|
default:
|
||||||
|
// If we have credentials and no tunnel is running, start it
|
||||||
|
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
|
||||||
|
logger.Info("Starting tunnel process with initial credentials")
|
||||||
|
tunnelRunning = true
|
||||||
|
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
|
||||||
|
} else if id == "" || secret == "" || endpoint == "" {
|
||||||
|
// If we don't have credentials, check if API is enabled
|
||||||
|
if !config.EnableAPI {
|
||||||
|
missing := []string{}
|
||||||
|
if id == "" {
|
||||||
|
missing = append(missing, "id")
|
||||||
|
}
|
||||||
|
if secret == "" {
|
||||||
|
missing = append(missing, "secret")
|
||||||
|
}
|
||||||
|
if endpoint == "" {
|
||||||
|
missing = append(missing, "endpoint")
|
||||||
|
}
|
||||||
|
// exit the application because there is no way to provide the missing parameters
|
||||||
|
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
|
||||||
|
goto shutdown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sleep briefly to prevent tight loop
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdown:
|
||||||
|
Stop()
|
||||||
|
apiServer.Stop()
|
||||||
|
logger.Info("Olm service shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TunnelProcess(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
|
||||||
|
// Create a cancellable context for this tunnel process
|
||||||
|
tunnelCtx, cancel := context.WithCancel(ctx)
|
||||||
|
tunnelCancel = cancel
|
||||||
|
defer func() {
|
||||||
|
tunnelCancel = nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Recreate channels for this tunnel session
|
||||||
|
stopPing = make(chan struct{})
|
||||||
|
|
||||||
|
var (
|
||||||
|
interfaceName = config.InterfaceName
|
||||||
|
loggerLevel = parseLogLevel(config.LogLevel)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create a new olm client using the provided credentials
|
||||||
|
olm, err := websocket.NewClient(
|
||||||
|
id, // Use provided ID
|
||||||
|
secret, // Use provided secret
|
||||||
|
userToken, // Use provided user token OPTIONAL
|
||||||
|
endpoint, // Use provided endpoint
|
||||||
|
config.PingIntervalDuration,
|
||||||
|
config.PingTimeoutDuration,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create olm: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the client reference globally
|
||||||
|
olmClient = olm
|
||||||
|
|
||||||
|
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to generate private key: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create shared UDP socket for both holepunch and WireGuard
|
||||||
|
if sharedBind == nil {
|
||||||
|
sourcePort, err := FindAvailableUDPPort(49152, 65535)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error finding available port: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := &net.UDPAddr{
|
||||||
|
Port: int(sourcePort),
|
||||||
|
IP: net.IPv4zero,
|
||||||
|
}
|
||||||
|
|
||||||
|
udpConn, err := net.ListenUDP("udp", localAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create shared UDP socket: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedBind, err = bind.New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create shared bind: %v", err)
|
||||||
|
udpConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
|
||||||
|
sharedBind.AddRef()
|
||||||
|
|
||||||
|
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the holepunch manager
|
||||||
|
if holePunchManager == nil {
|
||||||
|
holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
|
||||||
|
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
|
||||||
|
for i, node := range holePunchData.ExitNodes {
|
||||||
|
exitNodes[i] = holepunch.ExitNode{
|
||||||
|
Endpoint: node.Endpoint,
|
||||||
|
PublicKey: node.PublicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start hole punching using the manager
|
||||||
|
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
||||||
|
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
|
||||||
|
logger.Warn("Failed to start hole punch: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
||||||
|
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
type LegacyHolePunchData struct {
|
||||||
|
ServerPubKey string `json:"serverPubKey"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var legacyHolePunchData LegacyHolePunchData
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop any existing hole punch operations
|
||||||
|
if holePunchManager != nil {
|
||||||
|
holePunchManager.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start hole punching for the exit node
|
||||||
|
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
||||||
|
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
|
||||||
|
logger.Warn("Failed to start hole punch: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
if connected {
|
||||||
|
logger.Info("Already connected. Ignoring new connection request.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopRegister != nil {
|
||||||
|
stopRegister()
|
||||||
|
stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// close(stopHolepunch)
|
||||||
|
|
||||||
|
// wait 10 milliseconds to ensure the previous connection is closed
|
||||||
|
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
// if there is an existing tunnel then close it
|
||||||
|
if dev != nil {
|
||||||
|
logger.Info("Got new message. Closing existing tunnel!")
|
||||||
|
dev.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tdev, err = func() (tun.Device, error) {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
interfaceName, err := findUnusedUTUN()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tun.CreateTUN(interfaceName, config.MTU)
|
||||||
|
}
|
||||||
|
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
|
||||||
|
return createTUNFromFD(tunFdStr, config.MTU)
|
||||||
|
}
|
||||||
|
return tun.CreateTUN(interfaceName, config.MTU)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create TUN device: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||||
|
interfaceName = realInterfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
|
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
|
||||||
|
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
|
}
|
||||||
|
return uapiOpen(interfaceName)
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("UAPI listen error: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
|
|
||||||
|
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := uapiListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go dev.IpcHandle(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
logger.Info("UAPI listener started")
|
||||||
|
|
||||||
|
if err = dev.Up(); err != nil {
|
||||||
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||||
|
logger.Error("Failed to configure interface: %v", err)
|
||||||
|
}
|
||||||
|
apiServer.SetTunnelIP(wgData.TunnelIP)
|
||||||
|
|
||||||
|
peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
|
// Find the site config to get endpoint information
|
||||||
|
var endpoint string
|
||||||
|
var isRelay bool
|
||||||
|
for _, site := range wgData.Sites {
|
||||||
|
if site.SiteId == siteID {
|
||||||
|
endpoint = site.Endpoint
|
||||||
|
// TODO: We'll need to track relay status separately
|
||||||
|
// For now, assume not using relay unless we get relay data
|
||||||
|
isRelay = !config.Holepunch
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
|
||||||
|
if connected {
|
||||||
|
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Peer %d is disconnected", siteID)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
fixKey(privateKey.String()),
|
||||||
|
olm,
|
||||||
|
dev,
|
||||||
|
config.Holepunch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := range wgData.Sites {
|
||||||
|
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
|
||||||
|
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
|
||||||
|
|
||||||
|
// Format the endpoint before configuring the peer.
|
||||||
|
site.Endpoint = formatEndpoint(site.Endpoint)
|
||||||
|
|
||||||
|
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to configure peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add route for peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerMonitor.Start()
|
||||||
|
|
||||||
|
apiServer.SetRegistered(true)
|
||||||
|
|
||||||
|
connected = true
|
||||||
|
|
||||||
|
logger.Info("WireGuard device created.")
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received update-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var updateData UpdatePeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &updateData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling update data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to SiteConfig
|
||||||
|
siteConfig := SiteConfig{
|
||||||
|
SiteId: updateData.SiteId,
|
||||||
|
Endpoint: updateData.Endpoint,
|
||||||
|
PublicKey: updateData.PublicKey,
|
||||||
|
ServerIP: updateData.ServerIP,
|
||||||
|
ServerPort: updateData.ServerPort,
|
||||||
|
RemoteSubnets: updateData.RemoteSubnets,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the peer in WireGuard
|
||||||
|
if dev != nil {
|
||||||
|
// Find the existing peer to get old data
|
||||||
|
var oldRemoteSubnets string
|
||||||
|
var oldPublicKey string
|
||||||
|
for _, site := range wgData.Sites {
|
||||||
|
if site.SiteId == updateData.SiteId {
|
||||||
|
oldRemoteSubnets = site.RemoteSubnets
|
||||||
|
oldPublicKey = site.PublicKey
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the public key has changed, remove the old peer first
|
||||||
|
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
||||||
|
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
||||||
|
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||||
|
logger.Error("Failed to remove old peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the endpoint before updating the peer.
|
||||||
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to update peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old remote subnet routes if they changed
|
||||||
|
if oldRemoteSubnets != siteConfig.RemoteSubnets {
|
||||||
|
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
|
||||||
|
logger.Error("Failed to remove old remote subnet routes: %v", err)
|
||||||
|
// Continue anyway to add new routes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new remote subnet routes
|
||||||
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add new remote subnet routes: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update successful
|
||||||
|
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||||
|
for i := range wgData.Sites {
|
||||||
|
if wgData.Sites[i].SiteId == updateData.SiteId {
|
||||||
|
wgData.Sites[i] = siteConfig
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Error("WireGuard device not initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handler for adding a new peer
|
||||||
|
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var addData AddPeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &addData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling add data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to SiteConfig
|
||||||
|
siteConfig := SiteConfig{
|
||||||
|
SiteId: addData.SiteId,
|
||||||
|
Endpoint: addData.Endpoint,
|
||||||
|
PublicKey: addData.PublicKey,
|
||||||
|
ServerIP: addData.ServerIP,
|
||||||
|
ServerPort: addData.ServerPort,
|
||||||
|
RemoteSubnets: addData.RemoteSubnets,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the peer to WireGuard
|
||||||
|
if dev != nil {
|
||||||
|
// Format the endpoint before adding the new peer.
|
||||||
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to add peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add route for new peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add successful
|
||||||
|
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
||||||
|
|
||||||
|
// Update WgData with the new peer
|
||||||
|
wgData.Sites = append(wgData.Sites, siteConfig)
|
||||||
|
} else {
|
||||||
|
logger.Error("WireGuard device not initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handler for removing a peer
|
||||||
|
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received remove-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var removeData RemovePeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &removeData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling remove data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the peer to remove
|
||||||
|
var peerToRemove *SiteConfig
|
||||||
|
var newSites []SiteConfig
|
||||||
|
|
||||||
|
for _, site := range wgData.Sites {
|
||||||
|
if site.SiteId == removeData.SiteId {
|
||||||
|
peerToRemove = &site
|
||||||
|
} else {
|
||||||
|
newSites = append(newSites, site)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerToRemove == nil {
|
||||||
|
logger.Error("Peer with site ID %d not found", removeData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the peer from WireGuard
|
||||||
|
if dev != nil {
|
||||||
|
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||||
|
logger.Error("Failed to remove peer: %v", err)
|
||||||
|
// Send error response if needed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove route for the peer
|
||||||
|
err = removeRouteForServerIP(peerToRemove.ServerIP)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to remove route for peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove routes for remote subnets
|
||||||
|
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
|
||||||
|
logger.Error("Failed to remove routes for remote subnets: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove successful
|
||||||
|
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
|
||||||
|
|
||||||
|
// Update WgData to remove the peer
|
||||||
|
wgData.Sites = newSites
|
||||||
|
} else {
|
||||||
|
logger.Error("WireGuard device not initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received relay-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var relayData RelayPeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling relay data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryRelay, err := ResolveDomain(relayData.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update HTTP server to mark this peer as using relay
|
||||||
|
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
||||||
|
|
||||||
|
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received no-sites message - no sites available for connection")
|
||||||
|
|
||||||
|
if stopRegister != nil {
|
||||||
|
stopRegister()
|
||||||
|
stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("No sites available - stopped registration and holepunch processes")
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received terminate message")
|
||||||
|
olm.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.OnConnect(func() error {
|
||||||
|
logger.Info("Websocket Connected")
|
||||||
|
|
||||||
|
apiServer.SetConnectionStatus(true)
|
||||||
|
|
||||||
|
if connected {
|
||||||
|
logger.Debug("Already connected, skipping registration")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := privateKey.PublicKey()
|
||||||
|
|
||||||
|
if stopRegister == nil {
|
||||||
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
||||||
|
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": publicKey.String(),
|
||||||
|
"relay": !config.Holepunch,
|
||||||
|
"olmVersion": config.Version,
|
||||||
|
"orgId": config.OrgID,
|
||||||
|
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||||
|
}, 1*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
go keepSendingPing(olm)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.OnTokenUpdate(func(token string) {
|
||||||
|
if holePunchManager != nil {
|
||||||
|
holePunchManager.SetToken(token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Connect to the WebSocket server
|
||||||
|
if err := olm.Connect(); err != nil {
|
||||||
|
logger.Error("Failed to connect to server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer olm.Close()
|
||||||
|
|
||||||
|
// Listen for org switch requests from the API
|
||||||
|
go func() {
|
||||||
|
for req := range apiServer.GetSwitchOrgChannel() {
|
||||||
|
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
||||||
|
|
||||||
|
// Update the config with the new orgId
|
||||||
|
config.OrgID = req.OrgID
|
||||||
|
|
||||||
|
// Mark as not connected to trigger re-registration
|
||||||
|
connected = false
|
||||||
|
|
||||||
|
Stop()
|
||||||
|
|
||||||
|
// Clear peer statuses in API
|
||||||
|
apiServer.SetRegistered(false)
|
||||||
|
apiServer.SetTunnelIP("")
|
||||||
|
apiServer.SetOrgID(config.OrgID)
|
||||||
|
|
||||||
|
// Trigger re-registration with new orgId
|
||||||
|
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
||||||
|
publicKey := privateKey.PublicKey()
|
||||||
|
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": publicKey.String(),
|
||||||
|
"relay": !config.Holepunch,
|
||||||
|
"olmVersion": config.Version,
|
||||||
|
"orgId": config.OrgID,
|
||||||
|
}, 1*time.Second)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for context cancellation
|
||||||
|
<-tunnelCtx.Done()
|
||||||
|
logger.Info("Tunnel process context cancelled, cleaning up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Stop() {
|
||||||
|
// Stop hole punch manager
|
||||||
|
if holePunchManager != nil {
|
||||||
|
holePunchManager.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopPing != nil {
|
||||||
|
select {
|
||||||
|
case <-stopPing:
|
||||||
|
// Channel already closed
|
||||||
|
default:
|
||||||
|
close(stopPing)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopRegister != nil {
|
||||||
|
stopRegister()
|
||||||
|
stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerMonitor != nil {
|
||||||
|
peerMonitor.Stop()
|
||||||
|
peerMonitor = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if uapiListener != nil {
|
||||||
|
uapiListener.Close()
|
||||||
|
uapiListener = nil
|
||||||
|
}
|
||||||
|
if dev != nil {
|
||||||
|
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
|
||||||
|
dev = nil
|
||||||
|
}
|
||||||
|
// Close TUN device
|
||||||
|
if tdev != nil {
|
||||||
|
tdev.Close()
|
||||||
|
tdev = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release the hole punch reference to the shared bind
|
||||||
|
if sharedBind != nil {
|
||||||
|
// Release hole punch reference (WireGuard already released its reference via dev.Close())
|
||||||
|
logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount())
|
||||||
|
sharedBind.Release()
|
||||||
|
sharedBind = nil
|
||||||
|
logger.Info("Released shared UDP bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Olm service stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopTunnel stops just the tunnel process and websocket connection
|
||||||
|
// without shutting down the entire application
|
||||||
|
func StopTunnel() {
|
||||||
|
logger.Info("Stopping tunnel process")
|
||||||
|
|
||||||
|
// Cancel the tunnel context if it exists
|
||||||
|
if tunnelCancel != nil {
|
||||||
|
tunnelCancel()
|
||||||
|
// Give it a moment to clean up
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the websocket connection
|
||||||
|
if olmClient != nil {
|
||||||
|
olmClient.Close()
|
||||||
|
olmClient = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
Stop()
|
||||||
|
|
||||||
|
// Reset the connected state
|
||||||
|
connected = false
|
||||||
|
tunnelRunning = false
|
||||||
|
|
||||||
|
// Update API server status
|
||||||
|
apiServer.SetConnectionStatus(false)
|
||||||
|
apiServer.SetRegistered(false)
|
||||||
|
apiServer.SetTunnelIP("")
|
||||||
|
|
||||||
|
logger.Info("Tunnel process stopped")
|
||||||
|
}
|
||||||
35
olm/unix.go
Normal file
35
olm/unix.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||||
|
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = unix.SetNonblock(int(fd), true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(fd), "")
|
||||||
|
return tun.CreateTUNFromFile(file, mtuInt)
|
||||||
|
}
|
||||||
|
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||||
|
return ipc.UAPIOpen(interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||||
|
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||||
|
}
|
||||||
25
olm/windows.go
Normal file
25
olm/windows.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||||
|
return nil, errors.New("CreateTUNFromFile not supported on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||||
|
// On Windows, UAPIListen only takes one parameter
|
||||||
|
return ipc.UAPIListen(interfaceName)
|
||||||
|
}
|
||||||
331
peermonitor/peermonitor.go
Normal file
331
peermonitor/peermonitor.go
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
package peermonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
"github.com/fosrl/olm/wgtester"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PeerMonitorCallback is the function type for connection status change callbacks
|
||||||
|
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
|
||||||
|
|
||||||
|
// WireGuardConfig holds the WireGuard configuration for a peer
|
||||||
|
type WireGuardConfig struct {
|
||||||
|
SiteID int
|
||||||
|
PublicKey string
|
||||||
|
ServerIP string
|
||||||
|
Endpoint string
|
||||||
|
PrimaryRelay string // The primary relay endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
|
||||||
|
type PeerMonitor struct {
|
||||||
|
monitors map[int]*wgtester.Client
|
||||||
|
configs map[int]*WireGuardConfig
|
||||||
|
callback PeerMonitorCallback
|
||||||
|
mutex sync.Mutex
|
||||||
|
running bool
|
||||||
|
interval time.Duration
|
||||||
|
timeout time.Duration
|
||||||
|
maxAttempts int
|
||||||
|
privateKey string
|
||||||
|
wsClient *websocket.Client
|
||||||
|
device *device.Device
|
||||||
|
handleRelaySwitch bool // Whether to handle relay switching
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||||
|
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
|
||||||
|
return &PeerMonitor{
|
||||||
|
monitors: make(map[int]*wgtester.Client),
|
||||||
|
configs: make(map[int]*WireGuardConfig),
|
||||||
|
callback: callback,
|
||||||
|
interval: 1 * time.Second, // Default check interval
|
||||||
|
timeout: 2500 * time.Millisecond,
|
||||||
|
maxAttempts: 8,
|
||||||
|
privateKey: privateKey,
|
||||||
|
wsClient: wsClient,
|
||||||
|
device: device,
|
||||||
|
handleRelaySwitch: handleRelaySwitch,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetInterval changes how frequently peers are checked
|
||||||
|
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
pm.interval = interval
|
||||||
|
|
||||||
|
// Update interval for all existing monitors
|
||||||
|
for _, client := range pm.monitors {
|
||||||
|
client.SetPacketInterval(interval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTimeout changes the timeout for waiting for responses
|
||||||
|
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
pm.timeout = timeout
|
||||||
|
|
||||||
|
// Update timeout for all existing monitors
|
||||||
|
for _, client := range pm.monitors {
|
||||||
|
client.SetTimeout(timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||||
|
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
pm.maxAttempts = attempts
|
||||||
|
|
||||||
|
// Update max attempts for all existing monitors
|
||||||
|
for _, client := range pm.monitors {
|
||||||
|
client.SetMaxAttempts(attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeer adds a new peer to monitor
|
||||||
|
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
// Check if we're already monitoring this peer
|
||||||
|
if _, exists := pm.monitors[siteID]; exists {
|
||||||
|
// Update the endpoint instead of creating a new monitor
|
||||||
|
pm.removePeerUnlocked(siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := wgtester.NewClient(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure the client with our settings
|
||||||
|
client.SetPacketInterval(pm.interval)
|
||||||
|
client.SetTimeout(pm.timeout)
|
||||||
|
client.SetMaxAttempts(pm.maxAttempts)
|
||||||
|
|
||||||
|
// Store the client and config
|
||||||
|
pm.monitors[siteID] = client
|
||||||
|
pm.configs[siteID] = wgConfig
|
||||||
|
|
||||||
|
// If monitor is already running, start monitoring this peer
|
||||||
|
if pm.running {
|
||||||
|
siteIDCopy := siteID // Create a copy for the closure
|
||||||
|
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||||
|
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
||||||
|
// This function assumes the mutex is already held by the caller
|
||||||
|
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
|
||||||
|
client, exists := pm.monitors[siteID]
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client.StopMonitor()
|
||||||
|
client.Close()
|
||||||
|
delete(pm.monitors, siteID)
|
||||||
|
delete(pm.configs, siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer stops monitoring a peer and removes it from the monitor
|
||||||
|
func (pm *PeerMonitor) RemovePeer(siteID int) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
pm.removePeerUnlocked(siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins monitoring all peers
|
||||||
|
func (pm *PeerMonitor) Start() {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
if pm.running {
|
||||||
|
return // Already running
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.running = true
|
||||||
|
|
||||||
|
// Start monitoring all peers
|
||||||
|
for siteID, client := range pm.monitors {
|
||||||
|
siteIDCopy := siteID // Create a copy for the closure
|
||||||
|
err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||||
|
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Info("Started monitoring peer %d\n", siteID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConnectionStatusChange is called when a peer's connection status changes
|
||||||
|
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) {
|
||||||
|
// Call the user-provided callback first
|
||||||
|
if pm.callback != nil {
|
||||||
|
pm.callback(siteID, status.Connected, status.RTT)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If disconnected, handle failover
|
||||||
|
if !status.Connected {
|
||||||
|
// Send relay message to the server
|
||||||
|
if pm.wsClient != nil {
|
||||||
|
pm.sendRelay(siteID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFailover handles failover to the relay server when a peer is disconnected
|
||||||
|
func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
config, exists := pm.configs[siteID]
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for IPv6 and format the endpoint correctly
|
||||||
|
formattedEndpoint := relayEndpoint
|
||||||
|
if strings.Contains(relayEndpoint, ":") {
|
||||||
|
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure WireGuard to use the relay
|
||||||
|
wgConfig := fmt.Sprintf(`private_key=%s
|
||||||
|
public_key=%s
|
||||||
|
allowed_ip=%s/32
|
||||||
|
endpoint=%s:21820
|
||||||
|
persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, formattedEndpoint)
|
||||||
|
|
||||||
|
err := pm.device.IpcSet(wgConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to configure WireGuard device: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Adjusted peer %d to point to relay!\n", siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendRelay sends a relay message to the server
|
||||||
|
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||||
|
if !pm.handleRelaySwitch {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if pm.wsClient == nil {
|
||||||
|
return fmt.Errorf("websocket client is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
||||||
|
"siteId": siteID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Info("Sent relay message")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops monitoring all peers
|
||||||
|
func (pm *PeerMonitor) Stop() {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
if !pm.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.running = false
|
||||||
|
|
||||||
|
// Stop all monitors
|
||||||
|
for _, client := range pm.monitors {
|
||||||
|
client.StopMonitor()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops monitoring and cleans up resources
|
||||||
|
func (pm *PeerMonitor) Close() {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
// Stop and close all clients
|
||||||
|
for siteID, client := range pm.monitors {
|
||||||
|
client.StopMonitor()
|
||||||
|
client.Close()
|
||||||
|
delete(pm.monitors, siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.running = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeer tests connectivity to a specific peer
|
||||||
|
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
client, exists := pm.monitors[siteID]
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
connected, rtt := client.TestConnection(ctx)
|
||||||
|
return connected, rtt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAllPeers tests connectivity to all peers
|
||||||
|
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||||
|
Connected bool
|
||||||
|
RTT time.Duration
|
||||||
|
} {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
peers := make(map[int]*wgtester.Client, len(pm.monitors))
|
||||||
|
for siteID, client := range pm.monitors {
|
||||||
|
peers[siteID] = client
|
||||||
|
}
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
results := make(map[int]struct {
|
||||||
|
Connected bool
|
||||||
|
RTT time.Duration
|
||||||
|
})
|
||||||
|
for siteID, client := range peers {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||||
|
connected, rtt := client.TestConnection(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
results[siteID] = struct {
|
||||||
|
Connected bool
|
||||||
|
RTT time.Duration
|
||||||
|
}{
|
||||||
|
Connected: connected,
|
||||||
|
RTT: rtt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
335
proxy/manager.go
335
proxy/manager.go
@@ -1,335 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Target represents a proxy target with its address and port
|
|
||||||
type Target struct {
|
|
||||||
Address string
|
|
||||||
Port int
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProxyManager handles the creation and management of proxy connections
|
|
||||||
type ProxyManager struct {
|
|
||||||
tnet *netstack.Net
|
|
||||||
tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress
|
|
||||||
udpTargets map[string]map[int]string
|
|
||||||
listeners []*gonet.TCPListener
|
|
||||||
udpConns []*gonet.UDPConn
|
|
||||||
running bool
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewProxyManager creates a new proxy manager instance
|
|
||||||
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
|
||||||
return &ProxyManager{
|
|
||||||
tnet: tnet,
|
|
||||||
tcpTargets: make(map[string]map[int]string),
|
|
||||||
udpTargets: make(map[string]map[int]string),
|
|
||||||
listeners: make([]*gonet.TCPListener, 0),
|
|
||||||
udpConns: make([]*gonet.UDPConn, 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddTarget adds a new target for proxying
|
|
||||||
func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error {
|
|
||||||
pm.mutex.Lock()
|
|
||||||
defer pm.mutex.Unlock()
|
|
||||||
|
|
||||||
switch proto {
|
|
||||||
case "tcp":
|
|
||||||
if pm.tcpTargets[listenIP] == nil {
|
|
||||||
pm.tcpTargets[listenIP] = make(map[int]string)
|
|
||||||
}
|
|
||||||
pm.tcpTargets[listenIP][port] = targetAddr
|
|
||||||
case "udp":
|
|
||||||
if pm.udpTargets[listenIP] == nil {
|
|
||||||
pm.udpTargets[listenIP] = make(map[int]string)
|
|
||||||
}
|
|
||||||
pm.udpTargets[listenIP][port] = targetAddr
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
|
||||||
|
|
||||||
if pm.running {
|
|
||||||
return pm.startTarget(proto, listenIP, port, targetAddr)
|
|
||||||
} else {
|
|
||||||
logger.Info("Not adding target because not running")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
|
|
||||||
pm.mutex.Lock()
|
|
||||||
defer pm.mutex.Unlock()
|
|
||||||
|
|
||||||
switch proto {
|
|
||||||
case "tcp":
|
|
||||||
if targets, ok := pm.tcpTargets[listenIP]; ok {
|
|
||||||
delete(targets, port)
|
|
||||||
// Remove and close the corresponding TCP listener
|
|
||||||
for i, listener := range pm.listeners {
|
|
||||||
if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port {
|
|
||||||
listener.Close()
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
// Remove from slice
|
|
||||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
|
||||||
}
|
|
||||||
case "udp":
|
|
||||||
if targets, ok := pm.udpTargets[listenIP]; ok {
|
|
||||||
delete(targets, port)
|
|
||||||
// Remove and close the corresponding UDP connection
|
|
||||||
for i, conn := range pm.udpConns {
|
|
||||||
if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port {
|
|
||||||
conn.Close()
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
// Remove from slice
|
|
||||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start begins listening for all configured proxy targets
|
|
||||||
func (pm *ProxyManager) Start() error {
|
|
||||||
pm.mutex.Lock()
|
|
||||||
defer pm.mutex.Unlock()
|
|
||||||
|
|
||||||
if pm.running {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start TCP targets
|
|
||||||
for listenIP, targets := range pm.tcpTargets {
|
|
||||||
for port, targetAddr := range targets {
|
|
||||||
if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil {
|
|
||||||
return fmt.Errorf("failed to start TCP target: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start UDP targets
|
|
||||||
for listenIP, targets := range pm.udpTargets {
|
|
||||||
for port, targetAddr := range targets {
|
|
||||||
if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil {
|
|
||||||
return fmt.Errorf("failed to start UDP target: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.running = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) Stop() error {
|
|
||||||
pm.mutex.Lock()
|
|
||||||
defer pm.mutex.Unlock()
|
|
||||||
|
|
||||||
if !pm.running {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set running to false first to signal handlers to stop
|
|
||||||
pm.running = false
|
|
||||||
|
|
||||||
// Close TCP listeners
|
|
||||||
for i := len(pm.listeners) - 1; i >= 0; i-- {
|
|
||||||
listener := pm.listeners[i]
|
|
||||||
if err := listener.Close(); err != nil {
|
|
||||||
logger.Error("Error closing TCP listener: %v", err)
|
|
||||||
}
|
|
||||||
// Remove from slice
|
|
||||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close UDP connections
|
|
||||||
for i := len(pm.udpConns) - 1; i >= 0; i-- {
|
|
||||||
conn := pm.udpConns[i]
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
logger.Error("Error closing UDP connection: %v", err)
|
|
||||||
}
|
|
||||||
// Remove from slice
|
|
||||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear the target maps
|
|
||||||
for k := range pm.tcpTargets {
|
|
||||||
delete(pm.tcpTargets, k)
|
|
||||||
}
|
|
||||||
for k := range pm.udpTargets {
|
|
||||||
delete(pm.udpTargets, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Give active connections a chance to close gracefully
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error {
|
|
||||||
switch proto {
|
|
||||||
case "tcp":
|
|
||||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create TCP listener: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.listeners = append(pm.listeners, listener)
|
|
||||||
go pm.handleTCPProxy(listener, targetAddr)
|
|
||||||
|
|
||||||
case "udp":
|
|
||||||
addr := &net.UDPAddr{Port: port}
|
|
||||||
conn, err := pm.tnet.ListenUDP(addr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create UDP listener: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.udpConns = append(pm.udpConns, conn)
|
|
||||||
go pm.handleUDPProxy(conn, targetAddr)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
// Check if we're shutting down or the listener was closed
|
|
||||||
if !pm.running {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for specific network errors that indicate the listener is closed
|
|
||||||
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
|
||||||
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Error("Error accepting TCP connection: %v", err)
|
|
||||||
// Don't hammer the CPU if we hit a temporary error
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
target, err := net.Dial("tcp", targetAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error connecting to target: %v", err)
|
|
||||||
conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a WaitGroup to ensure both copy operations complete
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
io.Copy(target, conn)
|
|
||||||
target.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
io.Copy(conn, target)
|
|
||||||
conn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for both copies to complete
|
|
||||||
wg.Wait()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
|
||||||
buffer := make([]byte, 65507) // Max UDP packet size
|
|
||||||
clientConns := make(map[string]*net.UDPConn)
|
|
||||||
var clientsMutex sync.RWMutex
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
|
||||||
if err != nil {
|
|
||||||
if !pm.running {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Error("Error reading UDP packet: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
clientKey := remoteAddr.String()
|
|
||||||
clientsMutex.RLock()
|
|
||||||
targetConn, exists := clientConns[clientKey]
|
|
||||||
clientsMutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error resolving target address: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error connecting to target: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
clientsMutex.Lock()
|
|
||||||
clientConns[clientKey] = targetConn
|
|
||||||
clientsMutex.Unlock()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
buffer := make([]byte, 65507)
|
|
||||||
for {
|
|
||||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error reading from target: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.WriteTo(buffer[:n], remoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error writing to client: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = targetConn.Write(buffer[:n])
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error writing to target: %v", err)
|
|
||||||
targetConn.Close()
|
|
||||||
clientsMutex.Lock()
|
|
||||||
delete(clientConns, clientKey)
|
|
||||||
clientsMutex.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 774 KiB |
54
service_unix.go
Normal file
54
service_unix.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Service management functions are not available on non-Windows platforms
|
||||||
|
func installService() error {
|
||||||
|
return fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeService() error {
|
||||||
|
return fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func startService(args []string) error {
|
||||||
|
_ = args // unused on Unix platforms
|
||||||
|
return fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopService() error {
|
||||||
|
return fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getServiceStatus() (string, error) {
|
||||||
|
return "", fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func debugService(args []string) error {
|
||||||
|
_ = args // unused on Unix platforms
|
||||||
|
return fmt.Errorf("debug service is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWindowsService() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func runService(name string, isDebug bool, args []string) {
|
||||||
|
// No-op on non-Windows platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupWindowsEventLog() {
|
||||||
|
// No-op on non-Windows platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func watchLogFile(end bool) error {
|
||||||
|
return fmt.Errorf("watching log file is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func showServiceConfig() {
|
||||||
|
fmt.Println("Service configuration is only available on Windows")
|
||||||
|
}
|
||||||
631
service_windows.go
Normal file
631
service_windows.go
Normal file
@@ -0,0 +1,631 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"golang.org/x/sys/windows/svc"
|
||||||
|
"golang.org/x/sys/windows/svc/debug"
|
||||||
|
"golang.org/x/sys/windows/svc/eventlog"
|
||||||
|
"golang.org/x/sys/windows/svc/mgr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
serviceName = "OlmWireguardService"
|
||||||
|
serviceDisplayName = "Olm WireGuard VPN Service"
|
||||||
|
serviceDescription = "Olm WireGuard VPN client service for secure network connectivity"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Global variable to store service arguments
|
||||||
|
var serviceArgs []string
|
||||||
|
|
||||||
|
// getServiceArgsPath returns the path where service arguments are stored
|
||||||
|
func getServiceArgsPath() string {
|
||||||
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm")
|
||||||
|
return filepath.Join(logDir, "service_args.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveServiceArgs saves the service arguments to a file
|
||||||
|
func saveServiceArgs(args []string) error {
|
||||||
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm")
|
||||||
|
err := os.MkdirAll(logDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create config directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
argsPath := getServiceArgsPath()
|
||||||
|
data, err := json.Marshal(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal service args: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.WriteFile(argsPath, data, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write service args: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadServiceArgs loads the service arguments from a file
|
||||||
|
func loadServiceArgs() ([]string, error) {
|
||||||
|
argsPath := getServiceArgsPath()
|
||||||
|
data, err := os.ReadFile(argsPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return []string{}, nil // Return empty args if file doesn't exist
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to read service args: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var args []string
|
||||||
|
err = json.Unmarshal(data, &args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal service args: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type olmService struct {
|
||||||
|
elog debug.Log
|
||||||
|
ctx context.Context
|
||||||
|
stop context.CancelFunc
|
||||||
|
args []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
|
||||||
|
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
||||||
|
changes <- svc.Status{State: svc.StartPending}
|
||||||
|
|
||||||
|
s.elog.Info(1, fmt.Sprintf("Service Execute called with args: %v", args))
|
||||||
|
|
||||||
|
// Load saved service arguments
|
||||||
|
savedArgs, err := loadServiceArgs()
|
||||||
|
if err != nil {
|
||||||
|
s.elog.Error(1, fmt.Sprintf("Failed to load service args: %v", err))
|
||||||
|
// Continue with empty args if loading fails
|
||||||
|
savedArgs = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine service start args with saved args, giving priority to service start args
|
||||||
|
finalArgs := []string{}
|
||||||
|
if len(args) > 0 {
|
||||||
|
// Skip the first arg which is typically the service name
|
||||||
|
if len(args) > 1 {
|
||||||
|
finalArgs = append(finalArgs, args[1:]...)
|
||||||
|
}
|
||||||
|
s.elog.Info(1, fmt.Sprintf("Using service start parameters: %v", finalArgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no service start parameters, use saved args
|
||||||
|
if len(finalArgs) == 0 && len(savedArgs) > 0 {
|
||||||
|
finalArgs = savedArgs
|
||||||
|
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
s.args = finalArgs
|
||||||
|
|
||||||
|
// Start the main olm functionality
|
||||||
|
olmDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.runOlm()
|
||||||
|
close(olmDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
|
||||||
|
s.elog.Info(1, "Service status set to Running")
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case c := <-r:
|
||||||
|
switch c.Cmd {
|
||||||
|
case svc.Interrogate:
|
||||||
|
changes <- c.CurrentStatus
|
||||||
|
case svc.Stop, svc.Shutdown:
|
||||||
|
s.elog.Info(1, "Service stopping")
|
||||||
|
changes <- svc.Status{State: svc.StopPending}
|
||||||
|
if s.stop != nil {
|
||||||
|
s.stop()
|
||||||
|
}
|
||||||
|
// Wait for main logic to finish or timeout
|
||||||
|
select {
|
||||||
|
case <-olmDone:
|
||||||
|
s.elog.Info(1, "Main logic finished gracefully")
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
s.elog.Info(1, "Timeout waiting for main logic to finish")
|
||||||
|
}
|
||||||
|
return false, 0
|
||||||
|
default:
|
||||||
|
s.elog.Error(1, fmt.Sprintf("Unexpected control request #%d", c))
|
||||||
|
}
|
||||||
|
case <-olmDone:
|
||||||
|
s.elog.Info(1, "Main olm logic completed, stopping service")
|
||||||
|
changes <- svc.Status{State: svc.StopPending}
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *olmService) runOlm() {
|
||||||
|
// Create a context that can be cancelled when the service stops
|
||||||
|
s.ctx, s.stop = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// Setup logging for service mode
|
||||||
|
s.elog.Info(1, "Starting Olm main logic")
|
||||||
|
|
||||||
|
// Run the main olm logic and wait for it to complete
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
s.elog.Error(1, fmt.Sprintf("Olm panic: %v", r))
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Call the main olm function with stored arguments
|
||||||
|
runOlmMainWithArgs(s.ctx, s.args)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for either context cancellation or main logic completion
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
s.elog.Info(1, "Olm service context cancelled")
|
||||||
|
case <-done:
|
||||||
|
s.elog.Info(1, "Olm main logic completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runService(name string, isDebug bool, args []string) {
|
||||||
|
var err error
|
||||||
|
var elog debug.Log
|
||||||
|
|
||||||
|
if isDebug {
|
||||||
|
elog = debug.New(name)
|
||||||
|
fmt.Printf("Starting %s service in debug mode\n", name)
|
||||||
|
} else {
|
||||||
|
elog, err = eventlog.Open(name)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to open event log: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer elog.Close()
|
||||||
|
|
||||||
|
elog.Info(1, fmt.Sprintf("Starting %s service", name))
|
||||||
|
run := svc.Run
|
||||||
|
if isDebug {
|
||||||
|
run = debug.Run
|
||||||
|
}
|
||||||
|
|
||||||
|
service := &olmService{elog: elog, args: args}
|
||||||
|
err = run(name, service)
|
||||||
|
if err != nil {
|
||||||
|
elog.Error(1, fmt.Sprintf("%s service failed: %v", name, err))
|
||||||
|
if isDebug {
|
||||||
|
fmt.Printf("Service failed: %v\n", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
elog.Info(1, fmt.Sprintf("%s service stopped", name))
|
||||||
|
if isDebug {
|
||||||
|
fmt.Printf("%s service stopped\n", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func installService() error {
|
||||||
|
exepath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get executable path: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err == nil {
|
||||||
|
s.Close()
|
||||||
|
return fmt.Errorf("service %s already exists", serviceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := mgr.Config{
|
||||||
|
ServiceType: 0x10, // SERVICE_WIN32_OWN_PROCESS
|
||||||
|
StartType: mgr.StartManual,
|
||||||
|
ErrorControl: mgr.ErrorNormal,
|
||||||
|
DisplayName: serviceDisplayName,
|
||||||
|
Description: serviceDescription,
|
||||||
|
BinaryPathName: exepath,
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err = m.CreateService(serviceName, exepath, config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
err = eventlog.InstallAsEventCreate(serviceName, eventlog.Error|eventlog.Warning|eventlog.Info)
|
||||||
|
if err != nil {
|
||||||
|
s.Delete()
|
||||||
|
return fmt.Errorf("failed to install event log: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeService() error {
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("service %s is not installed", serviceName)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
// Stop the service if it's running
|
||||||
|
status, err := s.Query()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query service status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.State != svc.Stopped {
|
||||||
|
_, err = s.Control(svc.Stop)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to stop service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for service to stop
|
||||||
|
timeout := time.Now().Add(30 * time.Second)
|
||||||
|
for status.State != svc.Stopped {
|
||||||
|
if timeout.Before(time.Now()) {
|
||||||
|
return fmt.Errorf("timeout waiting for service to stop")
|
||||||
|
}
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
status, err = s.Query()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query service status: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.Delete()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = eventlog.Remove(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove event log: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startService(args []string) error {
|
||||||
|
// Save the service arguments as backup
|
||||||
|
if len(args) > 0 {
|
||||||
|
err := saveServiceArgs(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to save service args: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("service %s is not installed", serviceName)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
// Pass arguments directly to the service start call
|
||||||
|
err = s.Start(args...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopService() error {
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("service %s is not installed", serviceName)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
status, err := s.Control(svc.Stop)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to stop service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.Now().Add(30 * time.Second)
|
||||||
|
for status.State != svc.Stopped {
|
||||||
|
if timeout.Before(time.Now()) {
|
||||||
|
return fmt.Errorf("timeout waiting for service to stop")
|
||||||
|
}
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
status, err = s.Query()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query service status: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func debugService(args []string) error {
|
||||||
|
// Save the service arguments before starting
|
||||||
|
if len(args) > 0 {
|
||||||
|
err := saveServiceArgs(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to save service args: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the service with the provided arguments
|
||||||
|
err := startService(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watch the log file
|
||||||
|
return watchLogFile(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func watchLogFile(end bool) error {
|
||||||
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
||||||
|
logPath := filepath.Join(logDir, "olm.log")
|
||||||
|
|
||||||
|
// Ensure the log directory exists
|
||||||
|
err := os.MkdirAll(logDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create log directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the log file to be created if it doesn't exist
|
||||||
|
var file *os.File
|
||||||
|
for i := 0; i < 30; i++ { // Wait up to 15 seconds
|
||||||
|
file, err = os.Open(logPath)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
fmt.Printf("Waiting for log file to be created...\n")
|
||||||
|
}
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open log file after waiting: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Seek to the end of the file to only show new logs
|
||||||
|
_, err = file.Seek(0, 2)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to seek to end of file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up signal handling for graceful exit
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Create a ticker to check for new content
|
||||||
|
ticker := time.NewTicker(500 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
buffer := make([]byte, 4096)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-sigCh:
|
||||||
|
fmt.Printf("\n\nStopping log watch...\n")
|
||||||
|
// stop the service if needed
|
||||||
|
if end {
|
||||||
|
if err := stopService(); err != nil {
|
||||||
|
fmt.Printf("Failed to stop service: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf("Log watch stopped.\n")
|
||||||
|
return nil
|
||||||
|
case <-ticker.C:
|
||||||
|
// Read new content
|
||||||
|
n, err := file.Read(buffer)
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
// Try to reopen the file in case it was recreated
|
||||||
|
file.Close()
|
||||||
|
file, err = os.Open(logPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error reopening log file: %v", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > 0 {
|
||||||
|
// Print the new content
|
||||||
|
fmt.Print(string(buffer[:n]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getServiceStatus() (string, error) {
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return "Not Installed", nil
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
status, err := s.Query()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to query service status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch status.State {
|
||||||
|
case svc.Stopped:
|
||||||
|
return "Stopped", nil
|
||||||
|
case svc.StartPending:
|
||||||
|
return "Starting", nil
|
||||||
|
case svc.StopPending:
|
||||||
|
return "Stopping", nil
|
||||||
|
case svc.Running:
|
||||||
|
return "Running", nil
|
||||||
|
case svc.ContinuePending:
|
||||||
|
return "Continue Pending", nil
|
||||||
|
case svc.PausePending:
|
||||||
|
return "Pause Pending", nil
|
||||||
|
case svc.Paused:
|
||||||
|
return "Paused", nil
|
||||||
|
default:
|
||||||
|
return "Unknown", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// showServiceConfig displays current saved service configuration
|
||||||
|
func showServiceConfig() {
|
||||||
|
configPath := getServiceArgsPath()
|
||||||
|
fmt.Printf("Service configuration file: %s\n", configPath)
|
||||||
|
|
||||||
|
args, err := loadServiceArgs()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("No saved configuration found or error loading: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) == 0 {
|
||||||
|
fmt.Println("No saved service arguments found")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Saved service arguments: %v\n", args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWindowsService() bool {
|
||||||
|
isWindowsService, err := svc.IsWindowsService()
|
||||||
|
return err == nil && isWindowsService
|
||||||
|
}
|
||||||
|
|
||||||
|
// rotateLogFile handles daily log rotation
|
||||||
|
func rotateLogFile(logDir string, logFile string) error {
|
||||||
|
// Get current log file info
|
||||||
|
info, err := os.Stat(logFile)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil // No current log file to rotate
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to stat log file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if log file is from today
|
||||||
|
now := time.Now()
|
||||||
|
fileTime := info.ModTime()
|
||||||
|
|
||||||
|
// If the log file is from today, no rotation needed
|
||||||
|
if now.Year() == fileTime.Year() && now.YearDay() == fileTime.YearDay() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create rotated filename with date
|
||||||
|
rotatedName := fmt.Sprintf("olm-%s.log", fileTime.Format("2006-01-02"))
|
||||||
|
rotatedPath := filepath.Join(logDir, rotatedName)
|
||||||
|
|
||||||
|
// Rename current log file to dated filename
|
||||||
|
err = os.Rename(logFile, rotatedPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to rotate log file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up old log files (keep last 30 days)
|
||||||
|
cleanupOldLogFiles(logDir, 30)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupOldLogFiles removes log files older than specified days
|
||||||
|
func cleanupOldLogFiles(logDir string, daysToKeep int) {
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -daysToKeep)
|
||||||
|
|
||||||
|
files, err := os.ReadDir(logDir)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
if !file.IsDir() && strings.HasPrefix(file.Name(), "olm-") && strings.HasSuffix(file.Name(), ".log") {
|
||||||
|
filePath := filepath.Join(logDir, file.Name())
|
||||||
|
info, err := file.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.ModTime().Before(cutoff) {
|
||||||
|
os.Remove(filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupWindowsEventLog() {
|
||||||
|
// Create log directory if it doesn't exist
|
||||||
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
||||||
|
err := os.MkdirAll(logDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to create log directory: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logFile := filepath.Join(logDir, "olm.log")
|
||||||
|
|
||||||
|
// Rotate log file if needed
|
||||||
|
err = rotateLogFile(logDir, logFile)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to rotate log file: %v\n", err)
|
||||||
|
// Continue anyway to create new log file
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to open log file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the custom logger output
|
||||||
|
logger.GetLogger().SetOutput(file)
|
||||||
|
|
||||||
|
log.Printf("Olm service logging initialized - log file: %s", logFile)
|
||||||
|
}
|
||||||
@@ -2,38 +2,81 @@ package websocket
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"software.sslmate.com/src/go-pkcs12"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type TokenResponse struct {
|
||||||
conn *websocket.Conn
|
Data struct {
|
||||||
config *Config
|
Token string `json:"token"`
|
||||||
baseURL string
|
} `json:"data"`
|
||||||
handlers map[string]MessageHandler
|
Success bool `json:"success"`
|
||||||
done chan struct{}
|
Message string `json:"message"`
|
||||||
handlersMux sync.RWMutex
|
}
|
||||||
|
|
||||||
|
type WSMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is not json anymore
|
||||||
|
type Config struct {
|
||||||
|
ID string
|
||||||
|
Secret string
|
||||||
|
Endpoint string
|
||||||
|
TlsClientCert string // legacy PKCS12 file path
|
||||||
|
UserToken string // optional user token for websocket authentication
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
config *Config
|
||||||
|
conn *websocket.Conn
|
||||||
|
baseURL string
|
||||||
|
handlers map[string]MessageHandler
|
||||||
|
done chan struct{}
|
||||||
|
handlersMux sync.RWMutex
|
||||||
reconnectInterval time.Duration
|
reconnectInterval time.Duration
|
||||||
isConnected bool
|
isConnected bool
|
||||||
reconnectMux sync.RWMutex
|
reconnectMux sync.RWMutex
|
||||||
|
pingInterval time.Duration
|
||||||
onConnect func() error
|
pingTimeout time.Duration
|
||||||
|
onConnect func() error
|
||||||
|
onTokenUpdate func(token string)
|
||||||
|
writeMux sync.Mutex
|
||||||
|
clientType string // Type of client (e.g., "newt", "olm")
|
||||||
|
tlsConfig TLSConfig
|
||||||
|
configNeedsSave bool // Flag to track if config needs to be saved
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOption func(*Client)
|
type ClientOption func(*Client)
|
||||||
|
|
||||||
type MessageHandler func(message WSMessage)
|
type MessageHandler func(message WSMessage)
|
||||||
|
|
||||||
|
// TLSConfig holds TLS configuration options
|
||||||
|
type TLSConfig struct {
|
||||||
|
// New separate certificate support
|
||||||
|
ClientCertFile string
|
||||||
|
ClientKeyFile string
|
||||||
|
CAFiles []string
|
||||||
|
|
||||||
|
// Existing PKCS12 support (deprecated)
|
||||||
|
PKCS12File string
|
||||||
|
}
|
||||||
|
|
||||||
// WithBaseURL sets the base URL for the client
|
// WithBaseURL sets the base URL for the client
|
||||||
func WithBaseURL(url string) ClientOption {
|
func WithBaseURL(url string) ClientOption {
|
||||||
return func(c *Client) {
|
return func(c *Client) {
|
||||||
@@ -41,16 +84,32 @@ func WithBaseURL(url string) ClientOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithTLSConfig sets the TLS configuration for the client
|
||||||
|
func WithTLSConfig(config TLSConfig) ClientOption {
|
||||||
|
return func(c *Client) {
|
||||||
|
c.tlsConfig = config
|
||||||
|
// For backward compatibility, also set the legacy field
|
||||||
|
if config.PKCS12File != "" {
|
||||||
|
c.config.TlsClientCert = config.PKCS12File
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) OnConnect(callback func() error) {
|
func (c *Client) OnConnect(callback func() error) {
|
||||||
c.onConnect = callback
|
c.onConnect = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new Newt client
|
func (c *Client) OnTokenUpdate(callback func(token string)) {
|
||||||
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
|
c.onTokenUpdate = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new websocket client
|
||||||
|
func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
NewtID: newtID,
|
ID: ID,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
|
UserToken: userToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &Client{
|
client := &Client{
|
||||||
@@ -58,39 +117,59 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
|
|||||||
baseURL: endpoint, // default value
|
baseURL: endpoint, // default value
|
||||||
handlers: make(map[string]MessageHandler),
|
handlers: make(map[string]MessageHandler),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
reconnectInterval: 10 * time.Second,
|
reconnectInterval: 3 * time.Second,
|
||||||
isConnected: false,
|
isConnected: false,
|
||||||
|
pingInterval: pingInterval,
|
||||||
|
pingTimeout: pingTimeout,
|
||||||
|
clientType: "olm",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options before loading config
|
// Apply options before loading config
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
opt(client)
|
opt(client)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load existing config if available
|
|
||||||
if err := client.loadConfig(); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) GetConfig() *Config {
|
||||||
|
return c.config
|
||||||
|
}
|
||||||
|
|
||||||
// Connect establishes the WebSocket connection
|
// Connect establishes the WebSocket connection
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
go c.connectWithRetry()
|
go c.connectWithRetry()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the WebSocket connection
|
// Close closes the WebSocket connection gracefully
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
close(c.done)
|
// Signal shutdown to all goroutines first
|
||||||
if c.conn != nil {
|
select {
|
||||||
return c.conn.Close()
|
case <-c.done:
|
||||||
|
// Already closed
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
close(c.done)
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop the ping monitor
|
// Set connection status to false
|
||||||
c.setConnected(false)
|
c.setConnected(false)
|
||||||
|
|
||||||
|
// Close the WebSocket connection gracefully
|
||||||
|
if c.conn != nil {
|
||||||
|
// Send close message
|
||||||
|
c.writeMux.Lock()
|
||||||
|
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||||
|
c.writeMux.Unlock()
|
||||||
|
|
||||||
|
// Close the connection
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,9 +184,49 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
|||||||
Data: data,
|
Data: data,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Debug("Sending message: %s, data: %+v", messageType, data)
|
||||||
|
|
||||||
|
c.writeMux.Lock()
|
||||||
|
defer c.writeMux.Unlock()
|
||||||
return c.conn.WriteJSON(msg)
|
return c.conn.WriteJSON(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
|
||||||
|
stopChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
count := 0
|
||||||
|
maxAttempts := 10
|
||||||
|
|
||||||
|
err := c.SendMessage(messageType, data) // Send immediately
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send initial message: %v", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
if count >= maxAttempts {
|
||||||
|
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = c.SendMessage(messageType, data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send message: %v", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
case <-stopChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
close(stopChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterHandler registers a handler for a specific message type
|
// RegisterHandler registers a handler for a specific message type
|
||||||
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
||||||
c.handlersMux.Lock()
|
c.handlersMux.Lock()
|
||||||
@@ -115,30 +234,6 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
|||||||
c.handlers[messageType] = handler
|
c.handlers[messageType] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// readPump pumps messages from the WebSocket connection
|
|
||||||
func (c *Client) readPump() {
|
|
||||||
defer c.conn.Close()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
var msg WSMessage
|
|
||||||
err := c.conn.ReadJSON(&msg)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.handlersMux.RLock()
|
|
||||||
if handler, ok := c.handlers[msg.Type]; ok {
|
|
||||||
handler(msg)
|
|
||||||
}
|
|
||||||
c.handlersMux.RUnlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) getToken() (string, error) {
|
func (c *Client) getToken() (string, error) {
|
||||||
// Parse the base URL to ensure we have the correct hostname
|
// Parse the base URL to ensure we have the correct hostname
|
||||||
baseURL, err := url.Parse(c.baseURL)
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
@@ -149,57 +244,33 @@ func (c *Client) getToken() (string, error) {
|
|||||||
// Ensure we have the base URL without trailing slashes
|
// Ensure we have the base URL without trailing slashes
|
||||||
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
||||||
|
|
||||||
// If we already have a token, try to use it
|
var tlsConfig *tls.Config = nil
|
||||||
if c.config.Token != "" {
|
|
||||||
tokenCheckData := map[string]interface{}{
|
// Use new TLS configuration method
|
||||||
"newtId": c.config.NewtID,
|
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||||
"secret": c.config.Secret,
|
tlsConfig, err = c.setupTLS()
|
||||||
"token": c.config.Token,
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(tokenCheckData)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to marshal token check data: %w", err)
|
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new request
|
|
||||||
req, err := http.NewRequest(
|
|
||||||
"POST",
|
|
||||||
baseEndpoint+"/api/v1/auth/newt/get-token",
|
|
||||||
bytes.NewBuffer(jsonData),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to create request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set headers
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
|
||||||
|
|
||||||
// Make the request
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to check token validity: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var tokenResp TokenResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
|
||||||
return "", fmt.Errorf("failed to decode token check response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If token is still valid, return it
|
|
||||||
if tokenResp.Success && tokenResp.Message == "Token session already valid" {
|
|
||||||
return c.config.Token, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a new token
|
// Check for environment variable to skip TLS verification
|
||||||
tokenData := map[string]interface{}{
|
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||||
"newtId": c.config.NewtID,
|
if tlsConfig == nil {
|
||||||
|
tlsConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenData map[string]interface{}
|
||||||
|
|
||||||
|
tokenData = map[string]interface{}{
|
||||||
|
"olmId": c.config.ID,
|
||||||
"secret": c.config.Secret,
|
"secret": c.config.Secret,
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(tokenData)
|
jsonData, err := json.Marshal(tokenData)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to marshal token request data: %w", err)
|
return "", fmt.Errorf("failed to marshal token request data: %w", err)
|
||||||
}
|
}
|
||||||
@@ -207,7 +278,7 @@ func (c *Client) getToken() (string, error) {
|
|||||||
// Create a new request
|
// Create a new request
|
||||||
req, err := http.NewRequest(
|
req, err := http.NewRequest(
|
||||||
"POST",
|
"POST",
|
||||||
baseEndpoint+"/api/v1/auth/newt/get-token",
|
baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
|
||||||
bytes.NewBuffer(jsonData),
|
bytes.NewBuffer(jsonData),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -220,14 +291,26 @@ func (c *Client) getToken() (string, error) {
|
|||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
if tlsConfig != nil {
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to request new token: %w", err)
|
return "", fmt.Errorf("failed to request new token: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
|
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
var tokenResp TokenResponse
|
var tokenResp TokenResponse
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||||
|
logger.Error("Failed to decode token response.")
|
||||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,6 +322,8 @@ func (c *Client) getToken() (string, error) {
|
|||||||
return "", fmt.Errorf("received empty token from server")
|
return "", fmt.Errorf("received empty token from server")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
||||||
|
|
||||||
return tokenResp.Data.Token, nil
|
return tokenResp.Data.Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,6 +351,10 @@ func (c *Client) establishConnection() error {
|
|||||||
return fmt.Errorf("failed to get token: %w", err)
|
return fmt.Errorf("failed to get token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.onTokenUpdate != nil {
|
||||||
|
c.onTokenUpdate(token)
|
||||||
|
}
|
||||||
|
|
||||||
// Parse the base URL to determine protocol and hostname
|
// Parse the base URL to determine protocol and hostname
|
||||||
baseURL, err := url.Parse(c.baseURL)
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -288,10 +377,35 @@ func (c *Client) establishConnection() error {
|
|||||||
// Add token to query parameters
|
// Add token to query parameters
|
||||||
q := u.Query()
|
q := u.Query()
|
||||||
q.Set("token", token)
|
q.Set("token", token)
|
||||||
|
q.Set("clientType", c.clientType)
|
||||||
|
if c.config.UserToken != "" {
|
||||||
|
q.Set("userToken", c.config.UserToken)
|
||||||
|
}
|
||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
// Connect to WebSocket
|
// Connect to WebSocket
|
||||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
dialer := websocket.DefaultDialer
|
||||||
|
|
||||||
|
// Use new TLS configuration method
|
||||||
|
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||||
|
logger.Info("Setting up TLS configuration for WebSocket connection")
|
||||||
|
tlsConfig, err := c.setupTLS()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||||
|
}
|
||||||
|
dialer.TLSClientConfig = tlsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for environment variable to skip TLS verification for WebSocket connection
|
||||||
|
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||||
|
if dialer.TLSClientConfig == nil {
|
||||||
|
dialer.TLSClientConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
dialer.TLSClientConfig.InsecureSkipVerify = true
|
||||||
|
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := dialer.Dial(u.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||||
}
|
}
|
||||||
@@ -301,8 +415,8 @@ func (c *Client) establishConnection() error {
|
|||||||
|
|
||||||
// Start the ping monitor
|
// Start the ping monitor
|
||||||
go c.pingMonitor()
|
go c.pingMonitor()
|
||||||
// Start the read pump
|
// Start the read pump with disconnect detection
|
||||||
go c.readPump()
|
go c.readPumpWithDisconnectDetection()
|
||||||
|
|
||||||
if c.onConnect != nil {
|
if c.onConnect != nil {
|
||||||
if err := c.onConnect(); err != nil {
|
if err := c.onConnect(); err != nil {
|
||||||
@@ -313,8 +427,72 @@ func (c *Client) establishConnection() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupTLS configures TLS based on the TLS configuration
|
||||||
|
func (c *Client) setupTLS() (*tls.Config, error) {
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
// Handle new separate certificate configuration
|
||||||
|
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
|
||||||
|
logger.Info("Loading separate certificate files for mTLS")
|
||||||
|
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
|
||||||
|
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
|
||||||
|
|
||||||
|
// Load client certificate and key
|
||||||
|
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load client certificate pair: %w", err)
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
|
||||||
|
// Load CA certificates for remote validation if specified
|
||||||
|
if len(c.tlsConfig.CAFiles) > 0 {
|
||||||
|
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
for _, caFile := range c.tlsConfig.CAFiles {
|
||||||
|
caCert, err := os.ReadFile(caFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as PEM first, then DER
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
// If PEM parsing failed, try DER
|
||||||
|
cert, err := x509.ParseCertificate(caCert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err)
|
||||||
|
}
|
||||||
|
caCertPool.AddCert(cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to existing PKCS12 implementation for backward compatibility
|
||||||
|
if c.tlsConfig.PKCS12File != "" {
|
||||||
|
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
|
||||||
|
return c.setupPKCS12TLS()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy fallback using config.TlsClientCert
|
||||||
|
if c.config.TlsClientCert != "" {
|
||||||
|
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
||||||
|
return loadClientCertificate(c.config.TlsClientCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupPKCS12TLS loads TLS configuration from PKCS12 file
|
||||||
|
func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
|
||||||
|
return loadClientCertificate(c.tlsConfig.PKCS12File)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pingMonitor sends pings at a short interval and triggers reconnect on failure
|
||||||
func (c *Client) pingMonitor() {
|
func (c *Client) pingMonitor() {
|
||||||
ticker := time.NewTicker(30 * time.Second)
|
ticker := time.NewTicker(c.pingInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -322,11 +500,74 @@ func (c *Client) pingMonitor() {
|
|||||||
case <-c.done:
|
case <-c.done:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
|
if c.conn == nil {
|
||||||
logger.Error("Ping failed: %v", err)
|
|
||||||
c.reconnect()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.writeMux.Lock()
|
||||||
|
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
|
||||||
|
c.writeMux.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
// Check if we're shutting down before logging error and reconnecting
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Expected during shutdown
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
logger.Error("Ping failed: %v", err)
|
||||||
|
c.reconnect()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
|
||||||
|
func (c *Client) readPumpWithDisconnectDetection() {
|
||||||
|
defer func() {
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
// Only attempt reconnect if we're not shutting down
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Shutting down, don't reconnect
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
c.reconnect()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
var msg WSMessage
|
||||||
|
err := c.conn.ReadJSON(&msg)
|
||||||
|
if err != nil {
|
||||||
|
// Check if we're shutting down before logging error
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Expected during shutdown, don't log as error
|
||||||
|
logger.Debug("WebSocket connection closed during shutdown")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Unexpected error during normal operation
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
||||||
|
logger.Error("WebSocket read error: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("WebSocket connection closed: %v", err)
|
||||||
|
}
|
||||||
|
return // triggers reconnect via defer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.handlersMux.RLock()
|
||||||
|
if handler, ok := c.handlers[msg.Type]; ok {
|
||||||
|
handler(msg)
|
||||||
|
}
|
||||||
|
c.handlersMux.RUnlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -335,9 +576,16 @@ func (c *Client) reconnect() {
|
|||||||
c.setConnected(false)
|
c.setConnected(false)
|
||||||
if c.conn != nil {
|
if c.conn != nil {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
go c.connectWithRetry()
|
// Only reconnect if we're not shutting down
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
go c.connectWithRetry()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) setConnected(status bool) {
|
func (c *Client) setConnected(status bool) {
|
||||||
@@ -345,3 +593,42 @@ func (c *Client) setConnected(status bool) {
|
|||||||
defer c.reconnectMux.Unlock()
|
defer c.reconnectMux.Unlock()
|
||||||
c.isConnected = status
|
c.isConnected = status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
|
||||||
|
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||||
|
logger.Info("Loading tls-client-cert %s", p12Path)
|
||||||
|
// Read the PKCS12 file
|
||||||
|
p12Data, err := os.ReadFile(p12Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse PKCS12 with empty password for non-encrypted files
|
||||||
|
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate
|
||||||
|
cert := tls.Certificate{
|
||||||
|
Certificate: [][]byte{certificate.Raw},
|
||||||
|
PrivateKey: privateKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional: Add CA certificates if present
|
||||||
|
rootCAs, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
|
||||||
|
}
|
||||||
|
if len(caCerts) > 0 {
|
||||||
|
for _, caCert := range caCerts {
|
||||||
|
rootCAs.AddCert(caCert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS configuration
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getConfigPath() string {
|
|
||||||
var configDir string
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "newt-client")
|
|
||||||
case "windows":
|
|
||||||
configDir = filepath.Join(os.Getenv("APPDATA"), "newt-client")
|
|
||||||
default: // linux and others
|
|
||||||
configDir = filepath.Join(os.Getenv("HOME"), ".config", "newt-client")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
|
||||||
log.Printf("Failed to create config directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return filepath.Join(configDir, "config.json")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) loadConfig() error {
|
|
||||||
if c.config.NewtID != "" && c.config.Secret != "" && c.config.Endpoint != "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
configPath := getConfigPath()
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var config Config
|
|
||||||
if err := json.Unmarshal(data, &config); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.NewtID == "" {
|
|
||||||
c.config.NewtID = config.NewtID
|
|
||||||
}
|
|
||||||
if c.config.Token == "" {
|
|
||||||
c.config.Token = config.Token
|
|
||||||
}
|
|
||||||
if c.config.Secret == "" {
|
|
||||||
c.config.Secret = config.Secret
|
|
||||||
}
|
|
||||||
if c.config.Endpoint == "" {
|
|
||||||
c.config.Endpoint = config.Endpoint
|
|
||||||
c.baseURL = config.Endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) saveConfig() error {
|
|
||||||
configPath := getConfigPath()
|
|
||||||
data, err := json.MarshalIndent(c.config, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return os.WriteFile(configPath, data, 0644)
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package websocket
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
NewtID string `json:"newtId"`
|
|
||||||
Secret string `json:"secret"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TokenResponse struct {
|
|
||||||
Data struct {
|
|
||||||
Token string `json:"token"`
|
|
||||||
} `json:"data"`
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type WSMessage struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Data interface{} `json:"data"`
|
|
||||||
}
|
|
||||||
260
wgtester/wgtester.go
Normal file
260
wgtester/wgtester.go
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
package wgtester
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Magic bytes to identify our packets
|
||||||
|
magicHeader uint32 = 0xDEADBEEF
|
||||||
|
// Request packet type
|
||||||
|
packetTypeRequest uint8 = 1
|
||||||
|
// Response packet type
|
||||||
|
packetTypeResponse uint8 = 2
|
||||||
|
// Packet format:
|
||||||
|
// - 4 bytes: magic header (0xDEADBEEF)
|
||||||
|
// - 1 byte: packet type (1 = request, 2 = response)
|
||||||
|
// - 8 bytes: timestamp (for round-trip timing)
|
||||||
|
packetSize = 13
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client handles checking connectivity to a server
|
||||||
|
type Client struct {
|
||||||
|
conn *net.UDPConn
|
||||||
|
serverAddr string
|
||||||
|
monitorRunning bool
|
||||||
|
monitorLock sync.Mutex
|
||||||
|
connLock sync.Mutex // Protects connection operations
|
||||||
|
shutdownCh chan struct{}
|
||||||
|
packetInterval time.Duration
|
||||||
|
timeout time.Duration
|
||||||
|
maxAttempts int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionStatus represents the current connection state
|
||||||
|
type ConnectionStatus struct {
|
||||||
|
Connected bool
|
||||||
|
RTT time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new connection test client
|
||||||
|
func NewClient(serverAddr string) (*Client, error) {
|
||||||
|
return &Client{
|
||||||
|
serverAddr: serverAddr,
|
||||||
|
shutdownCh: make(chan struct{}),
|
||||||
|
packetInterval: 2 * time.Second,
|
||||||
|
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||||
|
maxAttempts: 3, // Default max attempts
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPacketInterval changes how frequently packets are sent in monitor mode
|
||||||
|
func (c *Client) SetPacketInterval(interval time.Duration) {
|
||||||
|
c.packetInterval = interval
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTimeout changes the timeout for waiting for responses
|
||||||
|
func (c *Client) SetTimeout(timeout time.Duration) {
|
||||||
|
c.timeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||||
|
func (c *Client) SetMaxAttempts(attempts int) {
|
||||||
|
c.maxAttempts = attempts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cleans up client resources
|
||||||
|
func (c *Client) Close() {
|
||||||
|
c.StopMonitor()
|
||||||
|
|
||||||
|
c.connLock.Lock()
|
||||||
|
defer c.connLock.Unlock()
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureConnection makes sure we have an active UDP connection
|
||||||
|
func (c *Client) ensureConnection() error {
|
||||||
|
c.connLock.Lock()
|
||||||
|
defer c.connLock.Unlock()
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.conn, err = net.DialUDP("udp", nil, serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnection checks if the connection to the server is working
|
||||||
|
// Returns true if connected, false otherwise
|
||||||
|
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||||
|
if err := c.ensureConnection(); err != nil {
|
||||||
|
logger.Warn("Failed to ensure connection: %v", err)
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare packet buffer
|
||||||
|
packet := make([]byte, packetSize)
|
||||||
|
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
|
||||||
|
packet[4] = packetTypeRequest
|
||||||
|
|
||||||
|
// Send multiple attempts as specified
|
||||||
|
for attempt := 0; attempt < c.maxAttempts; attempt++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, 0
|
||||||
|
default:
|
||||||
|
// Add current timestamp to packet
|
||||||
|
timestamp := time.Now().UnixNano()
|
||||||
|
binary.BigEndian.PutUint64(packet[5:13], uint64(timestamp))
|
||||||
|
|
||||||
|
// Lock the connection for the entire send/receive operation
|
||||||
|
c.connLock.Lock()
|
||||||
|
|
||||||
|
// Check if connection is still valid after acquiring lock
|
||||||
|
if c.conn == nil {
|
||||||
|
c.connLock.Unlock()
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
|
||||||
|
_, err := c.conn.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
c.connLock.Unlock()
|
||||||
|
logger.Info("Error sending packet: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Debug("Successfully sent monitor packet")
|
||||||
|
|
||||||
|
// Set read deadline
|
||||||
|
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||||
|
|
||||||
|
// Wait for response
|
||||||
|
responseBuffer := make([]byte, packetSize)
|
||||||
|
n, err := c.conn.Read(responseBuffer)
|
||||||
|
c.connLock.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
// Timeout, try next attempt
|
||||||
|
time.Sleep(100 * time.Millisecond) // Brief pause between attempts
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Error("Error reading response: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != packetSize {
|
||||||
|
continue // Malformed packet
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
magic := binary.BigEndian.Uint32(responseBuffer[0:4])
|
||||||
|
packetType := responseBuffer[4]
|
||||||
|
if magic != magicHeader || packetType != packetTypeResponse {
|
||||||
|
continue // Not our response
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the original timestamp and calculate RTT
|
||||||
|
sentTimestamp := int64(binary.BigEndian.Uint64(responseBuffer[5:13]))
|
||||||
|
rtt := time.Duration(time.Now().UnixNano() - sentTimestamp)
|
||||||
|
|
||||||
|
return true, rtt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnectionWithTimeout tries to test connection with a timeout
|
||||||
|
// Returns true if connected, false otherwise
|
||||||
|
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
return c.TestConnection(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MonitorCallback is the function type for connection status change callbacks
|
||||||
|
type MonitorCallback func(status ConnectionStatus)
|
||||||
|
|
||||||
|
// StartMonitor begins monitoring the connection and calls the callback
|
||||||
|
// when the connection status changes
|
||||||
|
func (c *Client) StartMonitor(callback MonitorCallback) error {
|
||||||
|
c.monitorLock.Lock()
|
||||||
|
defer c.monitorLock.Unlock()
|
||||||
|
|
||||||
|
if c.monitorRunning {
|
||||||
|
logger.Info("Monitor already running")
|
||||||
|
return nil // Already running
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ensureConnection(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.monitorRunning = true
|
||||||
|
c.shutdownCh = make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var lastConnected bool
|
||||||
|
firstRun := true
|
||||||
|
|
||||||
|
ticker := time.NewTicker(c.packetInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.shutdownCh:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||||
|
connected, rtt := c.TestConnection(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Callback if status changed or it's the first check
|
||||||
|
if connected != lastConnected || firstRun {
|
||||||
|
callback(ConnectionStatus{
|
||||||
|
Connected: connected,
|
||||||
|
RTT: rtt,
|
||||||
|
})
|
||||||
|
lastConnected = connected
|
||||||
|
firstRun = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopMonitor stops the connection monitoring
|
||||||
|
func (c *Client) StopMonitor() {
|
||||||
|
c.monitorLock.Lock()
|
||||||
|
defer c.monitorLock.Unlock()
|
||||||
|
|
||||||
|
if !c.monitorRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(c.shutdownCh)
|
||||||
|
c.monitorRunning = false
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user