mirror of
https://github.com/fosrl/olm.git
synced 2026-02-09 06:26:44 +00:00
Compare commits
214 Commits
1.0.0-beta
...
bind
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d891cfa97 | ||
|
|
78e3bb374a | ||
|
|
a61c7ca1ee | ||
|
|
7696ba2e36 | ||
|
|
235877c379 | ||
|
|
befab0f8d1 | ||
|
|
914d080a57 | ||
|
|
a274b4b38f | ||
|
|
ce3c585514 | ||
|
|
963d8abad5 | ||
|
|
38eb56381f | ||
|
|
43b3822090 | ||
|
|
b0fb370c4d | ||
|
|
99328ee76f | ||
|
|
36fc3ea253 | ||
|
|
a7979259f3 | ||
|
|
ea6fa72bc0 | ||
|
|
f9adde6b1d | ||
|
|
ba25586646 | ||
|
|
952ab63e8d | ||
|
|
5e84f802ed | ||
|
|
f40b0ff820 | ||
|
|
95a4840374 | ||
|
|
27424170e4 | ||
|
|
a8ace6f64a | ||
|
|
3fa1073f49 | ||
|
|
76d86c10ff | ||
|
|
2d34c6c8b2 | ||
|
|
a7f3477bdd | ||
|
|
af0a72d296 | ||
|
|
d1e836e760 | ||
|
|
8dd45c4ca2 | ||
|
|
9db009058b | ||
|
|
29c01deb05 | ||
|
|
7224d9824d | ||
|
|
8afc28fdff | ||
|
|
4ba2fb7b53 | ||
|
|
2e6076923d | ||
|
|
4c001dc751 | ||
|
|
2b8e240752 | ||
|
|
bee490713d | ||
|
|
1cb7fd94ab | ||
|
|
dc9a547950 | ||
|
|
2be0933246 | ||
|
|
c0b1cd6bde | ||
|
|
dd00289f8e | ||
|
|
b23a02ee97 | ||
|
|
80f726cfea | ||
|
|
aa8828186f | ||
|
|
0a990d196d | ||
|
|
00e8050949 | ||
|
|
18ee4c93fb | ||
|
|
8fa2da00b6 | ||
|
|
b851cd73c9 | ||
|
|
4c19d7ef6d | ||
|
|
cbecb9a0ce | ||
|
|
4fc8db08ba | ||
|
|
7ca46e0a75 | ||
|
|
a4ea5143af | ||
|
|
e9257b6423 | ||
|
|
3c9d3a1d2c | ||
|
|
b426f14190 | ||
|
|
d48acfba39 | ||
|
|
35b48cd8e5 | ||
|
|
15bca53309 | ||
|
|
898b599db5 | ||
|
|
c07bba18bb | ||
|
|
4c24d3b808 | ||
|
|
ad4ab3d04f | ||
|
|
e21153fae1 | ||
|
|
41c3360e23 | ||
|
|
1960d32443 | ||
|
|
74b83b3303 | ||
|
|
c2c3470868 | ||
|
|
2bda3dc3cc | ||
|
|
52573c8664 | ||
|
|
0d3c34e23f | ||
|
|
891df5c74b | ||
|
|
6f3f162d2b | ||
|
|
f6fa5fd02c | ||
|
|
8f4e0ba29e | ||
|
|
32b7dc7c43 | ||
|
|
78d2ebe1de | ||
|
|
014f8eb4e5 | ||
|
|
cd42803291 | ||
|
|
5c5b303994 | ||
|
|
968873da22 | ||
|
|
e2772f918b | ||
|
|
cdf6a31b67 | ||
|
|
2ce72065a7 | ||
|
|
b3e7aafb58 | ||
|
|
dd610ad850 | ||
|
|
b4b0a832e7 | ||
|
|
79963c1f66 | ||
|
|
5f95282161 | ||
|
|
1cca54f9d5 | ||
|
|
219df22919 | ||
|
|
337d9934fd | ||
|
|
bba4d72a78 | ||
|
|
fb5c793126 | ||
|
|
1821dbb672 | ||
|
|
4fda6fe031 | ||
|
|
f286f0faf6 | ||
|
|
cba3d607bf | ||
|
|
5ca12834a1 | ||
|
|
9d41154daa | ||
|
|
63933b57fc | ||
|
|
c25d77597d | ||
|
|
ad080046a1 | ||
|
|
c1f7cf93a5 | ||
|
|
f3f112fc42 | ||
|
|
612a9ddb15 | ||
|
|
ad1fa2e59a | ||
|
|
29235f6100 | ||
|
|
848ac6b0c4 | ||
|
|
6ab66e6c36 | ||
|
|
d7f29d4709 | ||
|
|
6fb2b68e21 | ||
|
|
8d72e77d57 | ||
|
|
25a9b83496 | ||
|
|
4d33016389 | ||
|
|
0f717aec01 | ||
|
|
4c58cd6eff | ||
|
|
3ad36f95e1 | ||
|
|
b58e7c9fad | ||
|
|
85a8a737e8 | ||
|
|
8e83a83294 | ||
|
|
c1ef56001f | ||
|
|
c04e727bd3 | ||
|
|
becc214078 | ||
|
|
13e7f55b30 | ||
|
|
0be3ee7eee | ||
|
|
0b1724a3f3 | ||
|
|
e606264deb | ||
|
|
5497eb8a4e | ||
|
|
2159371371 | ||
|
|
ad8a94fdc8 | ||
|
|
61b7feef80 | ||
|
|
4cb31df3c8 | ||
|
|
3b0eef6d60 | ||
|
|
5d305f1d03 | ||
|
|
31e5d4e3bd | ||
|
|
8fb9468d08 | ||
|
|
5bbd5016aa | ||
|
|
8c40b8c578 | ||
|
|
e35f7c2d36 | ||
|
|
aeb8f203a4 | ||
|
|
73bd036e58 | ||
|
|
eb6b310304 | ||
|
|
3d70ff190f | ||
|
|
7e2d7b93a1 | ||
|
|
f50ff67057 | ||
|
|
76d5e95fbf | ||
|
|
b6db70e285 | ||
|
|
3819823d95 | ||
|
|
b2830e8473 | ||
|
|
43a43b429d | ||
|
|
1593f22691 | ||
|
|
4883402393 | ||
|
|
02eab1ff52 | ||
|
|
e0ca38bb35 | ||
|
|
8d46ae3aa2 | ||
|
|
7424caca8a | ||
|
|
9d9f10a799 | ||
|
|
a42d2b75dd | ||
|
|
8b09545cf6 | ||
|
|
eb77be09e2 | ||
|
|
ad01296c41 | ||
|
|
b553209712 | ||
|
|
c5098f0cd0 | ||
|
|
5ec1aac0d1 | ||
|
|
313ef42883 | ||
|
|
085c98668d | ||
|
|
6107d20e26 | ||
|
|
66edae4288 | ||
|
|
f69a7f647d | ||
|
|
e8bd55bed9 | ||
|
|
b23eda9c06 | ||
|
|
76503f3f2c | ||
|
|
9c3112f9bd | ||
|
|
462af30d16 | ||
|
|
fa6038eb38 | ||
|
|
f346b6cc5d | ||
|
|
f20b9ebb14 | ||
|
|
39bfe5b230 | ||
|
|
a1a3dd9ba2 | ||
|
|
7b1492f327 | ||
|
|
4e50819785 | ||
|
|
f8dccbec80 | ||
|
|
0c5c59cf00 | ||
|
|
868bb55f87 | ||
|
|
5b4245402a | ||
|
|
f7a705e6f8 | ||
|
|
3a63657822 | ||
|
|
759780508a | ||
|
|
533886f2e4 | ||
|
|
79f8745909 | ||
|
|
7b663027ac | ||
|
|
e90e55d982 | ||
|
|
a46fb23cdd | ||
|
|
10982b47a5 | ||
|
|
ab12098c9c | ||
|
|
446eb4d6f1 | ||
|
|
313afdb4c5 | ||
|
|
235a3b9426 | ||
|
|
c298ff52f3 | ||
|
|
75518b2e04 | ||
|
|
739f708ff7 | ||
|
|
2897b92f72 | ||
|
|
2c612d4018 | ||
|
|
41f0973308 | ||
|
|
4a791bdb6e | ||
|
|
9497f9c96f | ||
|
|
e17276b0c4 |
@@ -1,6 +1,6 @@
|
|||||||
.gitignore
|
.gitignore
|
||||||
.dockerignore
|
.dockerignore
|
||||||
newt
|
olm
|
||||||
*.json
|
*.json
|
||||||
README.md
|
README.md
|
||||||
Makefile
|
Makefile
|
||||||
|
|||||||
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Summary
|
||||||
|
description: A clear and concise summary of the requested feature.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Motivation
|
||||||
|
description: |
|
||||||
|
Why is this feature important?
|
||||||
|
Explain the problem this feature would solve or what use case it would enable.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Proposed Solution
|
||||||
|
description: |
|
||||||
|
How would you like to see this feature implemented?
|
||||||
|
Provide as much detail as possible about the desired behavior, configuration, or changes.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Alternatives Considered
|
||||||
|
description: Describe any alternative solutions or workarounds you've thought about.
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Additional Context
|
||||||
|
description: Add any other context, mockups, or screenshots about the feature request here.
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Before submitting, please:
|
||||||
|
- Check if there is an existing issue for this feature.
|
||||||
|
- Clearly explain the benefit and use case.
|
||||||
|
- Be as specific as possible to help contributors evaluate and implement.
|
||||||
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Create a bug report
|
||||||
|
labels: []
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Describe the Bug
|
||||||
|
description: A clear and concise description of what the bug is.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Environment
|
||||||
|
description: Please fill out the relevant details below for your environment.
|
||||||
|
value: |
|
||||||
|
- OS Type & Version: (e.g., Ubuntu 22.04)
|
||||||
|
- Pangolin Version:
|
||||||
|
- Gerbil Version:
|
||||||
|
- Traefik Version:
|
||||||
|
- Newt Version:
|
||||||
|
- Olm Version: (if applicable)
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: To Reproduce
|
||||||
|
description: |
|
||||||
|
Steps to reproduce the behavior, please provide a clear description of how to reproduce the issue, based on the linked minimal reproduction. Screenshots can be provided in the issue body below.
|
||||||
|
|
||||||
|
If using code blocks, make sure syntax highlighting is correct and double-check that the rendered preview is not broken.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Expected Behavior
|
||||||
|
description: A clear and concise description of what you expected to happen.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Before posting the issue go through the steps you've written down to make sure the steps provided are detailed and clear.
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Contributors should be able to follow the steps provided in order to reproduce the bug.
|
||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: false
|
||||||
|
contact_links:
|
||||||
|
- name: Need help or have questions?
|
||||||
|
url: https://github.com/orgs/fosrl/discussions
|
||||||
|
about: Ask questions, get help, and discuss with other community members
|
||||||
|
- name: Request a Feature
|
||||||
|
url: https://github.com/orgs/fosrl/discussions/new?category=feature-requests
|
||||||
|
about: Feature requests should be opened as discussions so others can upvote and comment
|
||||||
40
.github/dependabot.yml
vendored
Normal file
40
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: "gomod"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "daily"
|
||||||
|
groups:
|
||||||
|
dev-patch-updates:
|
||||||
|
dependency-type: "development"
|
||||||
|
update-types:
|
||||||
|
- "patch"
|
||||||
|
dev-minor-updates:
|
||||||
|
dependency-type: "development"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
prod-patch-updates:
|
||||||
|
dependency-type: "production"
|
||||||
|
update-types:
|
||||||
|
- "patch"
|
||||||
|
prod-minor-updates:
|
||||||
|
dependency-type: "production"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
|
||||||
|
- package-ecosystem: "docker"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "daily"
|
||||||
|
groups:
|
||||||
|
patch-updates:
|
||||||
|
update-types:
|
||||||
|
- "patch"
|
||||||
|
minor-updates:
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
|
||||||
|
- package-ecosystem: "github-actions"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
60
.github/workflows/cicd.yml
vendored
Normal file
60
.github/workflows/cicd.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
name: CI/CD Pipeline
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- "*"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release:
|
||||||
|
name: Build and Release
|
||||||
|
runs-on: amd64-runner
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Log in to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||||
|
|
||||||
|
- name: Extract tag name
|
||||||
|
id: get-tag
|
||||||
|
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version: 1.25
|
||||||
|
|
||||||
|
- name: Update version in main.go
|
||||||
|
run: |
|
||||||
|
TAG=${{ env.TAG }}
|
||||||
|
if [ -f main.go ]; then
|
||||||
|
sed -i 's/version_replaceme/'"$TAG"'/' main.go
|
||||||
|
echo "Updated main.go with version $TAG"
|
||||||
|
else
|
||||||
|
echo "main.go not found"
|
||||||
|
fi
|
||||||
|
- name: Build and push Docker images
|
||||||
|
run: |
|
||||||
|
TAG=${{ env.TAG }}
|
||||||
|
make docker-build-release tag=$TAG
|
||||||
|
|
||||||
|
- name: Build binaries
|
||||||
|
run: |
|
||||||
|
make go-build-release
|
||||||
|
|
||||||
|
- name: Upload artifacts from /bin
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: binaries
|
||||||
|
path: bin/
|
||||||
132
.github/workflows/mirror.yaml
vendored
Normal file
132
.github/workflows/mirror.yaml
vendored
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
name: Mirror & Sign (Docker Hub to GHCR)
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: {}
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
id-token: write # for keyless OIDC
|
||||||
|
|
||||||
|
env:
|
||||||
|
SOURCE_IMAGE: docker.io/fosrl/olm
|
||||||
|
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
mirror-and-dual-sign:
|
||||||
|
runs-on: amd64-runner
|
||||||
|
steps:
|
||||||
|
- name: Install skopeo + jq
|
||||||
|
run: |
|
||||||
|
sudo apt-get update -y
|
||||||
|
sudo apt-get install -y skopeo jq
|
||||||
|
skopeo --version
|
||||||
|
|
||||||
|
- name: Install cosign
|
||||||
|
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||||
|
|
||||||
|
- name: Input check
|
||||||
|
run: |
|
||||||
|
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
|
||||||
|
echo "Source : ${SOURCE_IMAGE}"
|
||||||
|
echo "Target : ${DEST_IMAGE}"
|
||||||
|
|
||||||
|
# Auth for skopeo (containers-auth)
|
||||||
|
- name: Skopeo login to GHCR
|
||||||
|
run: |
|
||||||
|
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
|
||||||
|
|
||||||
|
# Auth for cosign (docker-config)
|
||||||
|
- name: Docker login to GHCR (for cosign)
|
||||||
|
run: |
|
||||||
|
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
|
||||||
|
|
||||||
|
- name: List source tags
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
|
||||||
|
| jq -r '.Tags[]' | sort -u > src-tags.txt
|
||||||
|
echo "Found source tags: $(wc -l < src-tags.txt)"
|
||||||
|
head -n 20 src-tags.txt || true
|
||||||
|
|
||||||
|
- name: List destination tags (skip existing)
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
|
||||||
|
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
|
||||||
|
else
|
||||||
|
: > dst-tags.txt
|
||||||
|
fi
|
||||||
|
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
|
||||||
|
|
||||||
|
- name: Mirror, dual-sign, and verify
|
||||||
|
env:
|
||||||
|
# keyless
|
||||||
|
COSIGN_YES: "true"
|
||||||
|
# key-based
|
||||||
|
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||||
|
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||||
|
# verify
|
||||||
|
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
copied=0; skipped=0; v_ok=0; errs=0
|
||||||
|
|
||||||
|
issuer="https://token.actions.githubusercontent.com"
|
||||||
|
id_regex="^https://github.com/${{ github.repository }}/.+"
|
||||||
|
|
||||||
|
while read -r tag; do
|
||||||
|
[ -z "$tag" ] && continue
|
||||||
|
|
||||||
|
if grep -Fxq "$tag" dst-tags.txt; then
|
||||||
|
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
|
||||||
|
skipped=$((skipped+1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
|
||||||
|
if ! skopeo copy --all --retry-times 3 \
|
||||||
|
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
|
||||||
|
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
|
||||||
|
errs=$((errs+1)); continue
|
||||||
|
fi
|
||||||
|
copied=$((copied+1))
|
||||||
|
|
||||||
|
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
|
||||||
|
ref="${DEST_IMAGE}@${digest}"
|
||||||
|
|
||||||
|
echo "==> cosign sign (keyless) --recursive ${ref}"
|
||||||
|
if ! cosign sign --recursive "${ref}"; then
|
||||||
|
echo "::warning title=Keyless sign failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> cosign sign (key) --recursive ${ref}"
|
||||||
|
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
|
||||||
|
echo "::warning title=Key sign failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> cosign verify (public key) ${ref}"
|
||||||
|
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
|
||||||
|
echo "::warning title=Verify(pubkey) failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> cosign verify (keyless policy) ${ref}"
|
||||||
|
if ! cosign verify \
|
||||||
|
--certificate-oidc-issuer "${issuer}" \
|
||||||
|
--certificate-identity-regexp "${id_regex}" \
|
||||||
|
"${ref}" -o text; then
|
||||||
|
echo "::warning title=Verify(keyless) failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
else
|
||||||
|
v_ok=$((v_ok+1))
|
||||||
|
fi
|
||||||
|
done < src-tags.txt
|
||||||
|
|
||||||
|
echo "---- Summary ----"
|
||||||
|
echo "Copied : $copied"
|
||||||
|
echo "Skipped : $skipped"
|
||||||
|
echo "Verified OK : $v_ok"
|
||||||
|
echo "Errors : $errs"
|
||||||
28
.github/workflows/test.yml
vendored
Normal file
28
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
name: Run Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- dev
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version: 1.25
|
||||||
|
|
||||||
|
- name: Build go
|
||||||
|
run: go build
|
||||||
|
|
||||||
|
- name: Build Docker image
|
||||||
|
run: make build
|
||||||
|
|
||||||
|
- name: Build binaries
|
||||||
|
run: make go-build-release
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
newt
|
.DS_Store
|
||||||
|
bin/
|
||||||
1
.go-version
Normal file
1
.go-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
1.25
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
Contributions are welcome! Please see the following page in our documentation with future plans and feature ideas if you are looking for a place to start.
|
Contributions are welcome!
|
||||||
|
|
||||||
https://docs.fossorial.io/roadmap
|
Please see the contribution and local development guide on the docs page before getting started:
|
||||||
|
|
||||||
|
https://docs.pangolin.net/development/contributing
|
||||||
|
|
||||||
### Licensing Considerations
|
### Licensing Considerations
|
||||||
|
|
||||||
@@ -15,4 +17,4 @@ By creating this pull request, I grant the project maintainers an unlimited,
|
|||||||
perpetual license to use, modify, and redistribute these contributions under any terms they
|
perpetual license to use, modify, and redistribute these contributions under any terms they
|
||||||
choose, including both the AGPLv3 and the Fossorial Commercial license terms. I
|
choose, including both the AGPLv3 and the Fossorial Commercial license terms. I
|
||||||
represent that I have the right to grant this license for all contributed content.
|
represent that I have the right to grant this license for all contributed content.
|
||||||
```
|
```
|
||||||
|
|||||||
12
Dockerfile
12
Dockerfile
@@ -1,4 +1,4 @@
|
|||||||
FROM golang:1.23.1-alpine AS builder
|
FROM golang:1.25-alpine AS builder
|
||||||
|
|
||||||
# Set the working directory inside the container
|
# Set the working directory inside the container
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
@@ -13,15 +13,15 @@ RUN go mod download
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Build the application
|
# Build the application
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /newt
|
RUN CGO_ENABLED=0 GOOS=linux go build -o /olm
|
||||||
|
|
||||||
# Start a new stage from scratch
|
# Start a new stage from scratch
|
||||||
FROM ubuntu:22.04 AS runner
|
FROM alpine:3.22 AS runner
|
||||||
|
|
||||||
RUN apt-get update && apt-get install ca-certificates -y && rm -rf /var/lib/apt/lists/*
|
RUN apk --no-cache add ca-certificates
|
||||||
|
|
||||||
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
||||||
COPY --from=builder /newt /usr/local/bin/
|
COPY --from=builder /olm /usr/local/bin/
|
||||||
COPY entrypoint.sh /
|
COPY entrypoint.sh /
|
||||||
|
|
||||||
RUN chmod +x /entrypoint.sh
|
RUN chmod +x /entrypoint.sh
|
||||||
@@ -30,4 +30,4 @@ RUN chmod +x /entrypoint.sh
|
|||||||
ENTRYPOINT ["/entrypoint.sh"]
|
ENTRYPOINT ["/entrypoint.sh"]
|
||||||
|
|
||||||
# Command to run the executable
|
# Command to run the executable
|
||||||
CMD ["newt"]
|
CMD ["olm"]
|
||||||
31
Makefile
31
Makefile
@@ -1,17 +1,26 @@
|
|||||||
|
|
||||||
all: build push
|
all: go-build-release
|
||||||
|
|
||||||
build:
|
docker-build-release:
|
||||||
docker build -t fosrl/newt:latest .
|
@if [ -z "$(tag)" ]; then \
|
||||||
|
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
|
||||||
push:
|
exit 1; \
|
||||||
docker push fosrl/newt:latest
|
fi
|
||||||
|
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:latest -f Dockerfile --push .
|
||||||
test:
|
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push .
|
||||||
docker run fosrl/newt:latest
|
|
||||||
|
|
||||||
local:
|
local:
|
||||||
CGO_ENABLED=0 go build -o newt
|
CGO_ENABLED=0 go build -o bin/olm
|
||||||
|
|
||||||
|
build:
|
||||||
|
docker build -t fosrl/olm:latest .
|
||||||
|
|
||||||
|
go-build-release:
|
||||||
|
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/olm_linux_arm64
|
||||||
|
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/olm_linux_amd64
|
||||||
|
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/olm_darwin_arm64
|
||||||
|
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/olm_darwin_amd64
|
||||||
|
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/olm_windows_amd64.exe
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm newt
|
rm olm
|
||||||
303
README.md
303
README.md
@@ -1,73 +1,292 @@
|
|||||||
# Newt
|
# Olm
|
||||||
|
|
||||||
Newt is a fully user space [WireGuard](https://www.wireguard.com/) tunnel client and TCP/UDP proxy, designed to securely expose private resources controlled by Pangolin. By using Newt, you don't need to manage complex WireGuard tunnels and NATing.
|
Olm is a [WireGuard](https://www.wireguard.com/) tunnel client designed to securely connect your computer to Newt sites running on remote networks.
|
||||||
|
|
||||||
### Installation and Documentation
|
### Installation and Documentation
|
||||||
|
|
||||||
Newt is used with Pangolin and Gerbil as part of the larger system. See documentation below:
|
Olm is used with Pangolin and Newt as part of the larger system. See documentation below:
|
||||||
|
|
||||||
- [Installation Instructions](https://docs.fossorial.io)
|
- [Full Documentation](https://docs.pangolin.net)
|
||||||
- [Full Documentation](https://docs.fossorial.io)
|
|
||||||
|
|
||||||
## Preview
|
|
||||||
|
|
||||||
<img src="public/screenshots/preview.png" alt="Preview"/>
|
|
||||||
|
|
||||||
_Sample output of a Newt container connected to Pangolin and hosting various resource target proxies._
|
|
||||||
|
|
||||||
## Key Functions
|
## Key Functions
|
||||||
|
|
||||||
### Registers with Pangolin
|
### Registers with Pangolin
|
||||||
|
|
||||||
Using the Newt ID and a secret the client will make HTTP requests to Pangolin to receive a session token. Using that token it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket.
|
Using the Olm ID and a secret, the olm will make HTTP requests to Pangolin to receive a session token. Using that token, it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket.
|
||||||
|
|
||||||
### Receives WireGuard Control Messages
|
### Receives WireGuard Control Messages
|
||||||
|
|
||||||
When Newt receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel using [netstack](https://github.com/WireGuard/wireguard-go/blob/master/tun/netstack/examples/http_server.go) fully in user space. It will ping over the tunnel to ensure the peer on the Gerbil side is brought up.
|
When Olm receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel on your computer to a remote Newt. It will ping over the tunnel to ensure the peer is brought up.
|
||||||
|
|
||||||
### Receives Proxy Control Messages
|
|
||||||
|
|
||||||
When Newt receives WireGuard control messages, it will use the information encoded to crate local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets.
|
|
||||||
|
|
||||||
## CLI Args
|
## CLI Args
|
||||||
|
|
||||||
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
|
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
|
||||||
- `id`: Newt ID generated by Pangolin to identify the client.
|
- `id`: Olm ID generated by Pangolin to identify the olm.
|
||||||
- `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands.
|
- `secret`: A unique secret (not shared and kept private) used to authenticate the olm ID with the websocket in order to receive commands.
|
||||||
- `dns`: DNS server to use to resolve the endpoint
|
- `mtu` (optional): MTU for the internal WG interface. Default: 1280
|
||||||
- `log-level` (optional): The log level to use. Default: INFO
|
- `dns` (optional): DNS server to use to resolve the endpoint. Default: 8.8.8.8
|
||||||
|
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO
|
||||||
|
- `ping-interval` (optional): Interval for pinging the server. Default: 3s
|
||||||
|
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
|
||||||
|
- `interface` (optional): Name of the WireGuard interface. Default: olm
|
||||||
|
- `enable-http` (optional): Enable HTTP server for receiving connection requests. Default: false
|
||||||
|
- `http-addr` (optional): HTTP server address (e.g., ':9452'). Default: :9452
|
||||||
|
- `holepunch` (optional): Enable hole punching. Default: false
|
||||||
|
|
||||||
Example:
|
## Environment Variables
|
||||||
|
|
||||||
|
All CLI arguments can also be set via environment variables:
|
||||||
|
|
||||||
|
- `PANGOLIN_ENDPOINT`: Equivalent to `--endpoint`
|
||||||
|
- `OLM_ID`: Equivalent to `--id`
|
||||||
|
- `OLM_SECRET`: Equivalent to `--secret`
|
||||||
|
- `MTU`: Equivalent to `--mtu`
|
||||||
|
- `DNS`: Equivalent to `--dns`
|
||||||
|
- `LOG_LEVEL`: Equivalent to `--log-level`
|
||||||
|
- `INTERFACE`: Equivalent to `--interface`
|
||||||
|
- `HTTP_ADDR`: Equivalent to `--http-addr`
|
||||||
|
- `PING_INTERVAL`: Equivalent to `--ping-interval`
|
||||||
|
- `PING_TIMEOUT`: Equivalent to `--ping-timeout`
|
||||||
|
- `HOLEPUNCH`: Set to "true" to enable hole punching (equivalent to `--holepunch`)
|
||||||
|
- `CONFIG_FILE`: Set to the location of a JSON file to load secret values
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./newt \
|
olm \
|
||||||
--id 31frd0uzbjvp721 \
|
--id 31frd0uzbjvp721 \
|
||||||
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||||
--endpoint https://example.com
|
--endpoint https://example.com
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can also run it with Docker compose. For example, a service in your `docker-compose.yml` might look like this using environment vars (recommended):
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
services:
|
services:
|
||||||
newt:
|
olm:
|
||||||
image: fosrl/newt
|
image: fosrl/olm
|
||||||
container_name: newt
|
container_name: olm
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
command:
|
network_mode: host
|
||||||
- --id 31frd0uzbjvp721 \
|
devices:
|
||||||
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
- /dev/net/tun:/dev/net/tun
|
||||||
- --endpoint https://example.com
|
environment:
|
||||||
|
- PANGOLIN_ENDPOINT=https://example.com
|
||||||
|
- OLM_ID=31frd0uzbjvp721
|
||||||
|
- OLM_SECRET=h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also pass the CLI args to the container:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
services:
|
||||||
|
olm:
|
||||||
|
image: fosrl/olm
|
||||||
|
container_name: olm
|
||||||
|
restart: unless-stopped
|
||||||
|
network_mode: host
|
||||||
|
devices:
|
||||||
|
- /dev/net/tun:/dev/net/tun
|
||||||
|
command:
|
||||||
|
- --id 31frd0uzbjvp721
|
||||||
|
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
|
||||||
|
- --endpoint https://example.com
|
||||||
|
```
|
||||||
|
|
||||||
|
**Docker Configuration Notes:**
|
||||||
|
|
||||||
|
- `network_mode: host` brings the olm network interface to the host system, allowing the WireGuard tunnel to function properly
|
||||||
|
- `devices: - /dev/net/tun:/dev/net/tun` is required to give the container access to the TUN device for creating WireGuard interfaces
|
||||||
|
|
||||||
|
## Loading secrets from files
|
||||||
|
|
||||||
|
You can use `CONFIG_FILE` to define a location of a config file to store the credentials between runs.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cat ~/.config/olm-client/config.json
|
||||||
|
{
|
||||||
|
"id": "spmzu8rbpzj1qq6",
|
||||||
|
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
|
||||||
|
"endpoint": "https://app.pangolin.net",
|
||||||
|
"tlsClientCert": ""
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This file is also written to when newt first starts up. So you do not need to run every time with --id and secret if you have run it once!
|
||||||
|
|
||||||
|
Default locations:
|
||||||
|
|
||||||
|
- **macOS**: `~/Library/Application Support/olm-client/config.json`
|
||||||
|
- **Windows**: `%PROGRAMDATA%\olm\olm-client\config.json`
|
||||||
|
- **Linux/Others**: `~/.config/olm-client/config.json`
|
||||||
|
|
||||||
|
## Hole Punching
|
||||||
|
|
||||||
|
In the default mode, olm "relays" traffic through Gerbil in the cloud to get down to newt. This is a little more reliable. Support for NAT hole punching is also EXPERIMENTAL right now using the `--holepunch` flag. This will attempt to orchestrate a NAT hole punch between the two sites so that traffic flows directly. This will save data costs and speed. If it fails it should fall back to relaying.
|
||||||
|
|
||||||
|
Right now, basic NAT hole punching is supported. We plan to add:
|
||||||
|
|
||||||
|
- [ ] Birthday paradox
|
||||||
|
- [ ] UPnP
|
||||||
|
- [ ] LAN detection
|
||||||
|
|
||||||
|
## Windows Service
|
||||||
|
|
||||||
|
On Windows, olm has to be installed and run as a Windows service. When running it with the cli args live above it will attempt to install and run the service to function like a cli tool. You can also run the following:
|
||||||
|
|
||||||
|
### Service Management Commands
|
||||||
|
|
||||||
|
```
|
||||||
|
# Install the service
|
||||||
|
olm.exe install
|
||||||
|
|
||||||
|
# Start the service
|
||||||
|
olm.exe start
|
||||||
|
|
||||||
|
# Stop the service
|
||||||
|
olm.exe stop
|
||||||
|
|
||||||
|
# Check service status
|
||||||
|
olm.exe status
|
||||||
|
|
||||||
|
# Remove the service
|
||||||
|
olm.exe remove
|
||||||
|
|
||||||
|
# Run in debug mode (console output) with our without id & secret
|
||||||
|
olm.exe debug
|
||||||
|
|
||||||
|
# Show help
|
||||||
|
olm.exe help
|
||||||
|
```
|
||||||
|
|
||||||
|
Note running the service requires credentials in `%PROGRAMDATA%\olm\olm-client\config.json`.
|
||||||
|
|
||||||
|
### Service Configuration
|
||||||
|
|
||||||
|
When running as a service, Olm will read configuration from environment variables or you can modify the service to include command-line arguments:
|
||||||
|
|
||||||
|
1. Install the service: `olm.exe install`
|
||||||
|
2. Set the credentials in `%PROGRAMDATA%\olm\olm-client\config.json`. Hint: if you run olm once with --id and --secret this file will be populated!
|
||||||
|
3. Start the service: `olm.exe start`
|
||||||
|
|
||||||
|
### Service Logs
|
||||||
|
|
||||||
|
When running as a service, logs are written to:
|
||||||
|
|
||||||
|
- Windows Event Log (Application log, source: "OlmWireguardService")
|
||||||
|
- Log files in: `%PROGRAMDATA%\olm\logs\olm.log`
|
||||||
|
|
||||||
|
You can view the Windows Event Log using Event Viewer or PowerShell:
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
Get-EventLog -LogName Application -Source "OlmWireguardService" -Newest 10
|
||||||
|
```
|
||||||
|
|
||||||
|
## HTTP Endpoints
|
||||||
|
|
||||||
|
Olm can be controlled with an embedded http server when using `--enable-http`. This allows you to start it as a daemon and trigger it with the following endpoints:
|
||||||
|
|
||||||
|
### POST /connect
|
||||||
|
Initiates a new connection request.
|
||||||
|
|
||||||
|
**Request Body:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "string",
|
||||||
|
"secret": "string",
|
||||||
|
"endpoint": "string"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Required Fields:**
|
||||||
|
- `id`: Connection identifier
|
||||||
|
- `secret`: Authentication secret
|
||||||
|
- `endpoint`: Target endpoint URL
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
- **Status Code:** `202 Accepted`
|
||||||
|
- **Content-Type:** `application/json`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "connection request accepted"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Error Responses:**
|
||||||
|
- `405 Method Not Allowed` - Non-POST requests
|
||||||
|
- `400 Bad Request` - Invalid JSON or missing required fields
|
||||||
|
|
||||||
|
### GET /status
|
||||||
|
Returns the current connection status and peer information.
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
- **Status Code:** `200 OK`
|
||||||
|
- **Content-Type:** `application/json`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "connected",
|
||||||
|
"connected": true,
|
||||||
|
"tunnelIP": "100.89.128.3/20",
|
||||||
|
"version": "version_replaceme",
|
||||||
|
"peers": {
|
||||||
|
"10": {
|
||||||
|
"siteId": 10,
|
||||||
|
"connected": true,
|
||||||
|
"rtt": 145338339,
|
||||||
|
"lastSeen": "2025-08-13T14:39:17.208334428-07:00",
|
||||||
|
"endpoint": "p.fosrl.io:21820",
|
||||||
|
"isRelay": true
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"siteId": 8,
|
||||||
|
"connected": false,
|
||||||
|
"rtt": 0,
|
||||||
|
"lastSeen": "2025-08-13T14:39:19.663823645-07:00",
|
||||||
|
"endpoint": "p.fosrl.io:21820",
|
||||||
|
"isRelay": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Fields:**
|
||||||
|
- `status`: Overall connection status ("connected" or "disconnected")
|
||||||
|
- `connected`: Boolean connection state
|
||||||
|
- `tunnelIP`: IP address and subnet of the tunnel (when connected)
|
||||||
|
- `version`: Olm version string
|
||||||
|
- `peers`: Map of peer statuses by site ID
|
||||||
|
- `siteId`: Peer site identifier
|
||||||
|
- `connected`: Boolean peer connection state
|
||||||
|
- `rtt`: Peer round-trip time (integer, nanoseconds)
|
||||||
|
- `lastSeen`: Last time peer was seen (RFC3339 timestamp)
|
||||||
|
- `endpoint`: Peer endpoint address
|
||||||
|
- `isRelay`: Whether the peer is relayed (true) or direct (false)
|
||||||
|
|
||||||
|
**Error Responses:**
|
||||||
|
- `405 Method Not Allowed` - Non-GET requests
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### Connect to a peer
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/connect \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"id": "31frd0uzbjvp721",
|
||||||
|
"secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
|
||||||
|
"endpoint": "https://example.com"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Check connection status
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/status
|
||||||
```
|
```
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
### Container
|
|
||||||
|
|
||||||
Ensure Docker is installed.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make
|
|
||||||
```
|
|
||||||
|
|
||||||
### Binary
|
### Binary
|
||||||
|
|
||||||
Make sure to have Go 1.23.1 installed.
|
Make sure to have Go 1.23.1 installed.
|
||||||
@@ -78,8 +297,8 @@ make local
|
|||||||
|
|
||||||
## Licensing
|
## Licensing
|
||||||
|
|
||||||
Newt is dual licensed under the AGPLv3 and the Fossorial Commercial license. For inquiries about commercial licensing, please contact us.
|
Olm is dual licensed under the AGPLv3 and the Fossorial Commercial license. For inquiries about commercial licensing, please contact us.
|
||||||
|
|
||||||
## Contributions
|
## Contributions
|
||||||
|
|
||||||
Please see [CONTRIBUTIONS](./CONTRIBUTING.md) in the repository for guidelines and best practices.
|
Please see [CONTRIBUTIONS](./CONTRIBUTING.md) in the repository for guidelines and best practices.
|
||||||
|
|||||||
14
SECURITY.md
Normal file
14
SECURITY.md
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Security Policy
|
||||||
|
|
||||||
|
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||||
|
|
||||||
|
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||||
|
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||||
|
|
||||||
|
- Description and location of the vulnerability.
|
||||||
|
- Potential impact of the vulnerability.
|
||||||
|
- Steps to reproduce the vulnerability.
|
||||||
|
- Potential solutions to fix the vulnerability.
|
||||||
|
- Your name/handle and a link for recognition (optional).
|
||||||
|
|
||||||
|
We aim to address the issue as soon as possible.
|
||||||
411
api/api.go
Normal file
411
api/api.go
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionRequest defines the structure for an incoming connection request
|
||||||
|
type ConnectionRequest struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
UserToken string `json:"userToken,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwitchOrgRequest defines the structure for switching organizations
|
||||||
|
type SwitchOrgRequest struct {
|
||||||
|
OrgID string `json:"orgId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerStatus represents the status of a peer connection
|
||||||
|
type PeerStatus struct {
|
||||||
|
SiteID int `json:"siteId"`
|
||||||
|
Connected bool `json:"connected"`
|
||||||
|
RTT time.Duration `json:"rtt"`
|
||||||
|
LastSeen time.Time `json:"lastSeen"`
|
||||||
|
Endpoint string `json:"endpoint,omitempty"`
|
||||||
|
IsRelay bool `json:"isRelay"`
|
||||||
|
PeerIP string `json:"peerAddress,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusResponse is returned by the status endpoint
|
||||||
|
type StatusResponse struct {
|
||||||
|
Connected bool `json:"connected"`
|
||||||
|
Registered bool `json:"registered"`
|
||||||
|
TunnelIP string `json:"tunnelIP,omitempty"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
OrgID string `json:"orgId,omitempty"`
|
||||||
|
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// API represents the HTTP server and its state
|
||||||
|
type API struct {
|
||||||
|
addr string
|
||||||
|
socketPath string
|
||||||
|
listener net.Listener
|
||||||
|
server *http.Server
|
||||||
|
connectionChan chan ConnectionRequest
|
||||||
|
switchOrgChan chan SwitchOrgRequest
|
||||||
|
shutdownChan chan struct{}
|
||||||
|
disconnectChan chan struct{}
|
||||||
|
statusMu sync.RWMutex
|
||||||
|
peerStatuses map[int]*PeerStatus
|
||||||
|
connectedAt time.Time
|
||||||
|
isConnected bool
|
||||||
|
isRegistered bool
|
||||||
|
tunnelIP string
|
||||||
|
version string
|
||||||
|
orgID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||||
|
func NewAPI(addr string) *API {
|
||||||
|
s := &API{
|
||||||
|
addr: addr,
|
||||||
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
|
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
disconnectChan: make(chan struct{}, 1),
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
|
||||||
|
func NewAPISocket(socketPath string) *API {
|
||||||
|
s := &API{
|
||||||
|
socketPath: socketPath,
|
||||||
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
|
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
disconnectChan: make(chan struct{}, 1),
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the HTTP server
|
||||||
|
func (s *API) Start() error {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/connect", s.handleConnect)
|
||||||
|
mux.HandleFunc("/status", s.handleStatus)
|
||||||
|
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||||
|
mux.HandleFunc("/disconnect", s.handleDisconnect)
|
||||||
|
mux.HandleFunc("/exit", s.handleExit)
|
||||||
|
|
||||||
|
s.server = &http.Server{
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if s.socketPath != "" {
|
||||||
|
// Use platform-specific socket listener
|
||||||
|
s.listener, err = createSocketListener(s.socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create socket listener: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("Starting HTTP server on socket %s", s.socketPath)
|
||||||
|
} else {
|
||||||
|
// Use TCP listener
|
||||||
|
s.listener, err = net.Listen("tcp", s.addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create TCP listener: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("Starting HTTP server on %s", s.addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
|
||||||
|
logger.Error("HTTP server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the HTTP server
|
||||||
|
func (s *API) Stop() error {
|
||||||
|
logger.Info("Stopping api server")
|
||||||
|
|
||||||
|
// Close the server first, which will also close the listener gracefully
|
||||||
|
if s.server != nil {
|
||||||
|
s.server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up socket file if using Unix socket
|
||||||
|
if s.socketPath != "" {
|
||||||
|
cleanupSocket(s.socketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnectionChannel returns the channel for receiving connection requests
|
||||||
|
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
|
||||||
|
return s.connectionChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
||||||
|
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
||||||
|
return s.switchOrgChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetShutdownChannel returns the channel for receiving shutdown requests
|
||||||
|
func (s *API) GetShutdownChannel() <-chan struct{} {
|
||||||
|
return s.shutdownChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisconnectChannel returns the channel for receiving disconnect requests
|
||||||
|
func (s *API) GetDisconnectChannel() <-chan struct{} {
|
||||||
|
return s.disconnectChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||||
|
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
|
||||||
|
status, exists := s.peerStatuses[siteID]
|
||||||
|
if !exists {
|
||||||
|
status = &PeerStatus{
|
||||||
|
SiteID: siteID,
|
||||||
|
}
|
||||||
|
s.peerStatuses[siteID] = status
|
||||||
|
}
|
||||||
|
|
||||||
|
status.Connected = connected
|
||||||
|
status.RTT = rtt
|
||||||
|
status.LastSeen = time.Now()
|
||||||
|
status.Endpoint = endpoint
|
||||||
|
status.IsRelay = isRelay
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConnectionStatus sets the overall connection status
|
||||||
|
func (s *API) SetConnectionStatus(isConnected bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
|
||||||
|
s.isConnected = isConnected
|
||||||
|
|
||||||
|
if isConnected {
|
||||||
|
s.connectedAt = time.Now()
|
||||||
|
} else {
|
||||||
|
// Clear peer statuses when disconnected
|
||||||
|
s.peerStatuses = make(map[int]*PeerStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *API) SetRegistered(registered bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.isRegistered = registered
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTunnelIP sets the tunnel IP address
|
||||||
|
func (s *API) SetTunnelIP(tunnelIP string) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.tunnelIP = tunnelIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVersion sets the olm version
|
||||||
|
func (s *API) SetVersion(version string) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.version = version
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOrgID sets the organization ID
|
||||||
|
func (s *API) SetOrgID(orgID string) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.orgID = orgID
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||||
|
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
|
||||||
|
status, exists := s.peerStatuses[siteID]
|
||||||
|
if !exists {
|
||||||
|
status = &PeerStatus{
|
||||||
|
SiteID: siteID,
|
||||||
|
}
|
||||||
|
s.peerStatuses[siteID] = status
|
||||||
|
}
|
||||||
|
|
||||||
|
status.Endpoint = endpoint
|
||||||
|
status.IsRelay = isRelay
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConnect handles the /connect endpoint
|
||||||
|
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we are already connected, reject new connection requests
|
||||||
|
s.statusMu.RLock()
|
||||||
|
alreadyConnected := s.isConnected
|
||||||
|
s.statusMu.RUnlock()
|
||||||
|
if alreadyConnected {
|
||||||
|
http.Error(w, "Already connected to a server. Disconnect first before connecting again.", http.StatusConflict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ConnectionRequest
|
||||||
|
decoder := json.NewDecoder(r.Body)
|
||||||
|
if err := decoder.Decode(&req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required fields
|
||||||
|
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
|
||||||
|
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the request to the main goroutine
|
||||||
|
s.connectionChan <- req
|
||||||
|
|
||||||
|
// Return a success response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "connection request accepted",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStatus handles the /status endpoint
|
||||||
|
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.statusMu.RLock()
|
||||||
|
defer s.statusMu.RUnlock()
|
||||||
|
|
||||||
|
resp := StatusResponse{
|
||||||
|
Connected: s.isConnected,
|
||||||
|
Registered: s.isRegistered,
|
||||||
|
TunnelIP: s.tunnelIP,
|
||||||
|
Version: s.version,
|
||||||
|
OrgID: s.orgID,
|
||||||
|
PeerStatuses: s.peerStatuses,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleExit handles the /exit endpoint
|
||||||
|
func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Received exit request via API")
|
||||||
|
|
||||||
|
// Send shutdown signal
|
||||||
|
select {
|
||||||
|
case s.shutdownChan <- struct{}{}:
|
||||||
|
// Signal sent successfully
|
||||||
|
default:
|
||||||
|
// Channel already has a signal, don't block
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a success response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "shutdown initiated",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSwitchOrg handles the /switch-org endpoint
|
||||||
|
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req SwitchOrgRequest
|
||||||
|
decoder := json.NewDecoder(r.Body)
|
||||||
|
if err := decoder.Decode(&req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required fields
|
||||||
|
if req.OrgID == "" {
|
||||||
|
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||||
|
|
||||||
|
// Send the request to the main goroutine
|
||||||
|
select {
|
||||||
|
case s.switchOrgChan <- req:
|
||||||
|
// Signal sent successfully
|
||||||
|
default:
|
||||||
|
// Channel already has a pending request
|
||||||
|
http.Error(w, "Org switch already in progress", http.StatusConflict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a success response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "org switch request accepted",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDisconnect handles the /disconnect endpoint
|
||||||
|
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we are already disconnected, reject new disconnect requests
|
||||||
|
s.statusMu.RLock()
|
||||||
|
alreadyDisconnected := !s.isConnected
|
||||||
|
s.statusMu.RUnlock()
|
||||||
|
if alreadyDisconnected {
|
||||||
|
http.Error(w, "Not currently connected to a server.", http.StatusConflict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Received disconnect request via API")
|
||||||
|
|
||||||
|
// Send disconnect signal
|
||||||
|
select {
|
||||||
|
case s.disconnectChan <- struct{}{}:
|
||||||
|
// Signal sent successfully
|
||||||
|
default:
|
||||||
|
// Channel already has a signal, don't block
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a success response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "disconnect initiated",
|
||||||
|
})
|
||||||
|
}
|
||||||
50
api/api_unix.go
Normal file
50
api/api_unix.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSocketListener creates a Unix domain socket listener
|
||||||
|
func createSocketListener(socketPath string) (net.Listener, error) {
|
||||||
|
// Ensure the directory exists
|
||||||
|
dir := filepath.Dir(socketPath)
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create socket directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove existing socket file if it exists
|
||||||
|
if err := os.RemoveAll(socketPath); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set socket permissions to allow access
|
||||||
|
if err := os.Chmod(socketPath, 0666); err != nil {
|
||||||
|
listener.Close()
|
||||||
|
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Created Unix socket at %s", socketPath)
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupSocket removes the Unix socket file
|
||||||
|
func cleanupSocket(socketPath string) {
|
||||||
|
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Removed Unix socket at %s", socketPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
41
api/api_windows.go
Normal file
41
api/api_windows.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio"
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSocketListener creates a Windows named pipe listener
|
||||||
|
func createSocketListener(pipePath string) (net.Listener, error) {
|
||||||
|
// Ensure the pipe path has the correct format
|
||||||
|
if pipePath[0] != '\\' {
|
||||||
|
pipePath = `\\.\pipe\` + pipePath
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a pipe configuration that allows everyone to write
|
||||||
|
config := &winio.PipeConfig{
|
||||||
|
// Set security descriptor to allow everyone full access
|
||||||
|
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
|
||||||
|
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a named pipe listener using go-winio with the configuration
|
||||||
|
listener, err := winio.ListenPipe(pipePath, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
|
||||||
|
func cleanupSocket(pipePath string) {
|
||||||
|
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
|
||||||
|
}
|
||||||
378
bind/shared_bind.go
Normal file
378
bind/shared_bind.go
Normal file
@@ -0,0 +1,378 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Endpoint represents a network endpoint for the SharedBind
|
||||||
|
type Endpoint struct {
|
||||||
|
AddrPort netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSrc implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) ClearSrc() {}
|
||||||
|
|
||||||
|
// DstIP implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) DstIP() netip.Addr {
|
||||||
|
return e.AddrPort.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SrcIP implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) SrcIP() netip.Addr {
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DstToBytes implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) DstToBytes() []byte {
|
||||||
|
b, _ := e.AddrPort.MarshalBinary()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// DstToString implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) DstToString() string {
|
||||||
|
return e.AddrPort.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SrcToString implements the wgConn.Endpoint interface
|
||||||
|
func (e *Endpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
|
||||||
|
// and hole punch senders. It wraps a single UDP connection and implements
|
||||||
|
// reference counting to prevent premature closure.
|
||||||
|
type SharedBind struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// The underlying UDP connection
|
||||||
|
udpConn *net.UDPConn
|
||||||
|
|
||||||
|
// IPv4 and IPv6 packet connections for advanced features
|
||||||
|
ipv4PC *ipv4.PacketConn
|
||||||
|
ipv6PC *ipv6.PacketConn
|
||||||
|
|
||||||
|
// Reference counting to prevent closing while in use
|
||||||
|
refCount atomic.Int32
|
||||||
|
closed atomic.Bool
|
||||||
|
|
||||||
|
// Channels for receiving data
|
||||||
|
recvFuncs []wgConn.ReceiveFunc
|
||||||
|
|
||||||
|
// Port binding information
|
||||||
|
port uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new SharedBind from an existing UDP connection.
|
||||||
|
// The SharedBind takes ownership of the connection and will close it
|
||||||
|
// when all references are released.
|
||||||
|
func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
||||||
|
if udpConn == nil {
|
||||||
|
return nil, fmt.Errorf("udpConn cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
bind := &SharedBind{
|
||||||
|
udpConn: udpConn,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize reference count to 1 (the creator holds the first reference)
|
||||||
|
bind.refCount.Store(1)
|
||||||
|
|
||||||
|
// Get the local port
|
||||||
|
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
|
||||||
|
bind.port = uint16(addr.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bind, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRef increments the reference count. Call this when sharing
|
||||||
|
// the bind with another component.
|
||||||
|
func (b *SharedBind) AddRef() {
|
||||||
|
newCount := b.refCount.Add(1)
|
||||||
|
// Optional: Add logging for debugging
|
||||||
|
_ = newCount // Placeholder for potential logging
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release decrements the reference count. When it reaches zero,
|
||||||
|
// the underlying UDP connection is closed.
|
||||||
|
func (b *SharedBind) Release() error {
|
||||||
|
newCount := b.refCount.Add(-1)
|
||||||
|
// Optional: Add logging for debugging
|
||||||
|
_ = newCount // Placeholder for potential logging
|
||||||
|
|
||||||
|
if newCount < 0 {
|
||||||
|
// This should never happen with proper usage
|
||||||
|
b.refCount.Store(0)
|
||||||
|
return fmt.Errorf("SharedBind reference count went negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
if newCount == 0 {
|
||||||
|
return b.closeConnection()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeConnection actually closes the UDP connection
|
||||||
|
func (b *SharedBind) closeConnection() error {
|
||||||
|
if !b.closed.CompareAndSwap(false, true) {
|
||||||
|
// Already closed
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if b.udpConn != nil {
|
||||||
|
err = b.udpConn.Close()
|
||||||
|
b.udpConn = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ipv4PC = nil
|
||||||
|
b.ipv6PC = nil
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUDPConn returns the underlying UDP connection.
|
||||||
|
// The caller must not close this connection directly.
|
||||||
|
func (b *SharedBind) GetUDPConn() *net.UDPConn {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
return b.udpConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRefCount returns the current reference count (for debugging)
|
||||||
|
func (b *SharedBind) GetRefCount() int32 {
|
||||||
|
return b.refCount.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed returns whether the bind is closed
|
||||||
|
func (b *SharedBind) IsClosed() bool {
|
||||||
|
return b.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteToUDP writes data to a specific UDP address.
|
||||||
|
// This is thread-safe and can be used by hole punch senders.
|
||||||
|
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.WriteToUDP(data, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the WireGuard Bind interface.
|
||||||
|
// It decrements the reference count and closes the connection if no references remain.
|
||||||
|
func (b *SharedBind) Close() error {
|
||||||
|
return b.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open implements the WireGuard Bind interface.
|
||||||
|
// Since the connection is already open, this just sets up the receive functions.
|
||||||
|
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return nil, 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
if b.udpConn == nil {
|
||||||
|
return nil, 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up IPv4 and IPv6 packet connections for advanced features
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
|
||||||
|
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create receive functions
|
||||||
|
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
|
||||||
|
|
||||||
|
// Add IPv4 receive function
|
||||||
|
if b.ipv4PC != nil || runtime.GOOS != "linux" {
|
||||||
|
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IPv6 receive function if needed
|
||||||
|
// For now, we focus on IPv4 for hole punching use case
|
||||||
|
|
||||||
|
b.recvFuncs = recvFuncs
|
||||||
|
return recvFuncs, b.port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeReceiveIPv4 creates a receive function for IPv4 packets
|
||||||
|
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
pc := b.ipv4PC
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use batch reading on Linux for performance
|
||||||
|
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||||
|
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to simple read for other platforms
|
||||||
|
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveIPv4Batch uses batch reading for better performance on Linux
|
||||||
|
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||||
|
// Create messages for batch reading
|
||||||
|
msgs := make([]ipv4.Message, len(bufs))
|
||||||
|
for i := range bufs {
|
||||||
|
msgs[i].Buffers = [][]byte{bufs[i]}
|
||||||
|
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
|
||||||
|
}
|
||||||
|
|
||||||
|
numMsgs, err := pc.ReadBatch(msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
sizes[i] = msgs[i].N
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgs[i].Addr != nil {
|
||||||
|
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
|
||||||
|
addrPort := udpAddr.AddrPort()
|
||||||
|
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return numMsgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
|
||||||
|
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||||
|
n, addr, err := conn.ReadFromUDP(bufs[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sizes[0] = n
|
||||||
|
if addr != nil {
|
||||||
|
addrPort := addr.AddrPort()
|
||||||
|
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send implements the WireGuard Bind interface.
|
||||||
|
// It sends packets to the specified endpoint.
|
||||||
|
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||||
|
if b.closed.Load() {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the destination address from the endpoint
|
||||||
|
var destAddr *net.UDPAddr
|
||||||
|
|
||||||
|
// Try to cast to StdNetEndpoint first
|
||||||
|
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
|
||||||
|
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
|
||||||
|
} else {
|
||||||
|
// Fallback: construct from DstIP and DstToBytes
|
||||||
|
dstBytes := ep.DstToBytes()
|
||||||
|
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
|
||||||
|
var addr netip.Addr
|
||||||
|
var port uint16
|
||||||
|
|
||||||
|
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
|
||||||
|
addr, _ = netip.AddrFromSlice(dstBytes[:16])
|
||||||
|
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
|
||||||
|
} else { // IPv4
|
||||||
|
addr, _ = netip.AddrFromSlice(dstBytes[:4])
|
||||||
|
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.IsValid() {
|
||||||
|
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if destAddr == nil {
|
||||||
|
return fmt.Errorf("could not extract destination address from endpoint")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send all buffers to the destination
|
||||||
|
for _, buf := range bufs {
|
||||||
|
_, err := conn.WriteToUDP(buf, destAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMark implements the WireGuard Bind interface.
|
||||||
|
// It's a no-op for this implementation.
|
||||||
|
func (b *SharedBind) SetMark(mark uint32) error {
|
||||||
|
// Not implemented for this use case
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the preferred batch size for sending packets.
|
||||||
|
func (b *SharedBind) BatchSize() int {
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
return wgConn.IdealBatchSize
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseEndpoint creates a new endpoint from a string address.
|
||||||
|
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
|
||||||
|
addrPort, err := netip.ParseAddrPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
|
||||||
|
}
|
||||||
424
bind/shared_bind_test.go
Normal file
424
bind/shared_bind_test.go
Normal file
@@ -0,0 +1,424 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSharedBindCreation tests basic creation and initialization
|
||||||
|
func TestSharedBindCreation(t *testing.T) {
|
||||||
|
// Create a UDP connection
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer udpConn.Close()
|
||||||
|
|
||||||
|
// Create SharedBind
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bind == nil {
|
||||||
|
t.Fatal("SharedBind is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify initial reference count
|
||||||
|
if bind.refCount.Load() != 1 {
|
||||||
|
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
if err := bind.Close(); err != nil {
|
||||||
|
t.Errorf("Failed to close SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindReferenceCount tests reference counting
|
||||||
|
func TestSharedBindReferenceCount(t *testing.T) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add references
|
||||||
|
bind.AddRef()
|
||||||
|
if bind.refCount.Load() != 2 {
|
||||||
|
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.AddRef()
|
||||||
|
if bind.refCount.Load() != 3 {
|
||||||
|
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release references
|
||||||
|
bind.Release()
|
||||||
|
if bind.refCount.Load() != 2 {
|
||||||
|
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.Release()
|
||||||
|
bind.Release() // This should close the connection
|
||||||
|
|
||||||
|
if !bind.closed.Load() {
|
||||||
|
t.Error("Expected bind to be closed after all references released")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
|
||||||
|
func TestSharedBindWriteToUDP(t *testing.T) {
|
||||||
|
// Create sender
|
||||||
|
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderBind, err := New(senderConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer senderBind.Close()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
// Send data
|
||||||
|
testData := []byte("Hello, SharedBind!")
|
||||||
|
n, err := senderBind.WriteToUDP(testData, receiverAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteToUDP failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != len(testData) {
|
||||||
|
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive data
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, _, err = receiverConn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to receive data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(testData) {
|
||||||
|
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindConcurrentWrites tests thread-safety
|
||||||
|
func TestSharedBindConcurrentWrites(t *testing.T) {
|
||||||
|
// Create sender
|
||||||
|
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderBind, err := New(senderConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer senderBind.Close()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
// Launch concurrent writes
|
||||||
|
numGoroutines := 100
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
data := []byte{byte(id)}
|
||||||
|
_, err := senderBind.WriteToUDP(data, receiverAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
|
||||||
|
func TestSharedBindWireGuardInterface(t *testing.T) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer bind.Close()
|
||||||
|
|
||||||
|
// Test Open
|
||||||
|
recvFuncs, port, err := bind.Open(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(recvFuncs) == 0 {
|
||||||
|
t.Error("Expected at least one receive function")
|
||||||
|
}
|
||||||
|
|
||||||
|
if port == 0 {
|
||||||
|
t.Error("Expected non-zero port")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test SetMark (should be a no-op)
|
||||||
|
if err := bind.SetMark(0); err != nil {
|
||||||
|
t.Errorf("SetMark failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test BatchSize
|
||||||
|
batchSize := bind.BatchSize()
|
||||||
|
if batchSize <= 0 {
|
||||||
|
t.Error("Expected positive batch size")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindSend tests the Send method with WireGuard endpoints
|
||||||
|
func TestSharedBindSend(t *testing.T) {
|
||||||
|
// Create sender
|
||||||
|
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderBind, err := New(senderConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer senderBind.Close()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
// Create an endpoint
|
||||||
|
addrPort := receiverAddr.AddrPort()
|
||||||
|
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
|
||||||
|
// Send data
|
||||||
|
testData := []byte("WireGuard packet")
|
||||||
|
bufs := [][]byte{testData}
|
||||||
|
err = senderBind.Send(bufs, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Send failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive data
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, _, err := receiverConn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to receive data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(testData) {
|
||||||
|
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
|
||||||
|
func TestSharedBindMultipleUsers(t *testing.T) {
|
||||||
|
// Create shared bind
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedBind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add reference for hole punch sender
|
||||||
|
sharedBind.AddRef()
|
||||||
|
|
||||||
|
// Create receiver
|
||||||
|
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
defer receiverConn.Close()
|
||||||
|
|
||||||
|
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Simulate WireGuard using the bind
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
addrPort := receiverAddr.AddrPort()
|
||||||
|
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data := []byte("WireGuard packet")
|
||||||
|
bufs := [][]byte{data}
|
||||||
|
if err := sharedBind.Send(bufs, endpoint); err != nil {
|
||||||
|
t.Errorf("WireGuard Send failed: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Simulate hole punch sender using the bind
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data := []byte("Hole punch packet")
|
||||||
|
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
|
||||||
|
t.Errorf("Hole punch WriteToUDP failed: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Release the hole punch reference
|
||||||
|
sharedBind.Release()
|
||||||
|
|
||||||
|
// Close WireGuard's reference (should close the connection)
|
||||||
|
sharedBind.Close()
|
||||||
|
|
||||||
|
if !sharedBind.closed.Load() {
|
||||||
|
t.Error("Expected bind to be closed after all users released it")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEndpoint tests the Endpoint implementation
|
||||||
|
func TestEndpoint(t *testing.T) {
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
addrPort := netip.AddrPortFrom(addr, 51820)
|
||||||
|
|
||||||
|
ep := &Endpoint{AddrPort: addrPort}
|
||||||
|
|
||||||
|
// Test DstIP
|
||||||
|
if ep.DstIP() != addr {
|
||||||
|
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DstToString
|
||||||
|
expected := "192.168.1.1:51820"
|
||||||
|
if ep.DstToString() != expected {
|
||||||
|
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DstToBytes
|
||||||
|
bytes := ep.DstToBytes()
|
||||||
|
if len(bytes) == 0 {
|
||||||
|
t.Error("Expected DstToBytes to return non-empty slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test SrcIP (should be zero)
|
||||||
|
if ep.SrcIP().IsValid() {
|
||||||
|
t.Error("Expected SrcIP to be invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ClearSrc (should not panic)
|
||||||
|
ep.ClearSrc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseEndpoint tests the ParseEndpoint method
|
||||||
|
func TestParseEndpoint(t *testing.T) {
|
||||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bind, err := New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||||
|
}
|
||||||
|
defer bind.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr bool
|
||||||
|
checkAddr func(*testing.T, wgConn.Endpoint)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid IPv4",
|
||||||
|
input: "192.168.1.1:51820",
|
||||||
|
wantErr: false,
|
||||||
|
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||||
|
if ep.DstToString() != "192.168.1.1:51820" {
|
||||||
|
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid IPv6",
|
||||||
|
input: "[::1]:51820",
|
||||||
|
wantErr: false,
|
||||||
|
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||||
|
if ep.DstToString() != "[::1]:51820" {
|
||||||
|
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid - missing port",
|
||||||
|
input: "192.168.1.1",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid - bad format",
|
||||||
|
input: "not-an-address",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ep, err := bind.ParseEndpoint(tt.input)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !tt.wantErr && tt.checkAddr != nil {
|
||||||
|
tt.checkAddr(t, ep)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
562
config.go
Normal file
562
config.go
Normal file
@@ -0,0 +1,562 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OlmConfig holds all configuration options for the Olm client
|
||||||
|
type OlmConfig struct {
|
||||||
|
// Connection settings
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
OrgID string `json:"org"`
|
||||||
|
UserToken string `json:"userToken"`
|
||||||
|
|
||||||
|
// Network settings
|
||||||
|
MTU int `json:"mtu"`
|
||||||
|
DNS string `json:"dns"`
|
||||||
|
InterfaceName string `json:"interface"`
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
LogLevel string `json:"logLevel"`
|
||||||
|
|
||||||
|
// HTTP server
|
||||||
|
EnableAPI bool `json:"enableApi"`
|
||||||
|
HTTPAddr string `json:"httpAddr"`
|
||||||
|
SocketPath string `json:"socketPath"`
|
||||||
|
|
||||||
|
// Ping settings
|
||||||
|
PingInterval string `json:"pingInterval"`
|
||||||
|
PingTimeout string `json:"pingTimeout"`
|
||||||
|
|
||||||
|
// Advanced
|
||||||
|
Holepunch bool `json:"holepunch"`
|
||||||
|
TlsClientCert string `json:"tlsClientCert"`
|
||||||
|
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
|
||||||
|
|
||||||
|
// Parsed values (not in JSON)
|
||||||
|
PingIntervalDuration time.Duration `json:"-"`
|
||||||
|
PingTimeoutDuration time.Duration `json:"-"`
|
||||||
|
|
||||||
|
// Source tracking (not in JSON)
|
||||||
|
sources map[string]string `json:"-"`
|
||||||
|
|
||||||
|
Version string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigSource tracks where each config value came from
|
||||||
|
type ConfigSource string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SourceDefault ConfigSource = "default"
|
||||||
|
SourceFile ConfigSource = "file"
|
||||||
|
SourceEnv ConfigSource = "environment"
|
||||||
|
SourceCLI ConfigSource = "cli"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultConfig returns a config with default values
|
||||||
|
func DefaultConfig() *OlmConfig {
|
||||||
|
// Set OS-specific socket path
|
||||||
|
var socketPath string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
socketPath = "olm"
|
||||||
|
default: // darwin, linux, and others
|
||||||
|
socketPath = "/var/run/olm.sock"
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &OlmConfig{
|
||||||
|
MTU: 1280,
|
||||||
|
DNS: "8.8.8.8",
|
||||||
|
LogLevel: "INFO",
|
||||||
|
InterfaceName: "olm",
|
||||||
|
EnableAPI: false,
|
||||||
|
SocketPath: socketPath,
|
||||||
|
PingInterval: "3s",
|
||||||
|
PingTimeout: "5s",
|
||||||
|
Holepunch: false,
|
||||||
|
// DoNotCreateNewClient: false,
|
||||||
|
sources: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track default sources
|
||||||
|
config.sources["mtu"] = string(SourceDefault)
|
||||||
|
config.sources["dns"] = string(SourceDefault)
|
||||||
|
config.sources["logLevel"] = string(SourceDefault)
|
||||||
|
config.sources["interface"] = string(SourceDefault)
|
||||||
|
config.sources["enableApi"] = string(SourceDefault)
|
||||||
|
config.sources["httpAddr"] = string(SourceDefault)
|
||||||
|
config.sources["socketPath"] = string(SourceDefault)
|
||||||
|
config.sources["pingInterval"] = string(SourceDefault)
|
||||||
|
config.sources["pingTimeout"] = string(SourceDefault)
|
||||||
|
config.sources["holepunch"] = string(SourceDefault)
|
||||||
|
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOlmConfigPath returns the path to the olm config file
|
||||||
|
func getOlmConfigPath() string {
|
||||||
|
configFile := os.Getenv("CONFIG_FILE")
|
||||||
|
if configFile != "" {
|
||||||
|
return configFile
|
||||||
|
}
|
||||||
|
|
||||||
|
var configDir string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
|
||||||
|
case "windows":
|
||||||
|
configDir = filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "olm-client")
|
||||||
|
default: // linux and others
|
||||||
|
configDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||||
|
fmt.Printf("Warning: Failed to create config directory: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(configDir, "config.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig loads configuration from file, env vars, and CLI args
|
||||||
|
// Priority: CLI args > Env vars > Config file > Defaults
|
||||||
|
// Returns: (config, showVersion, showConfig, error)
|
||||||
|
func LoadConfig(args []string) (*OlmConfig, bool, bool, error) {
|
||||||
|
// Start with defaults
|
||||||
|
config := DefaultConfig()
|
||||||
|
|
||||||
|
// Load from config file (if exists)
|
||||||
|
fileConfig, err := loadConfigFromFile()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, false, fmt.Errorf("failed to load config file: %w", err)
|
||||||
|
}
|
||||||
|
if fileConfig != nil {
|
||||||
|
mergeConfigs(config, fileConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override with environment variables
|
||||||
|
loadConfigFromEnv(config)
|
||||||
|
|
||||||
|
// Override with CLI arguments
|
||||||
|
showVersion, showConfig, err := loadConfigFromCLI(config, args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse duration strings
|
||||||
|
if err := config.parseDurations(); err != nil {
|
||||||
|
return nil, false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, showVersion, showConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadConfigFromFile loads configuration from the JSON config file
|
||||||
|
func loadConfigFromFile() (*OlmConfig, error) {
|
||||||
|
configPath := getOlmConfigPath()
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil // File doesn't exist, not an error
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var config OlmConfig
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadConfigFromEnv loads configuration from environment variables
|
||||||
|
func loadConfigFromEnv(config *OlmConfig) {
|
||||||
|
if val := os.Getenv("PANGOLIN_ENDPOINT"); val != "" {
|
||||||
|
config.Endpoint = val
|
||||||
|
config.sources["endpoint"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("OLM_ID"); val != "" {
|
||||||
|
config.ID = val
|
||||||
|
config.sources["id"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("OLM_SECRET"); val != "" {
|
||||||
|
config.Secret = val
|
||||||
|
config.sources["secret"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("ORG"); val != "" {
|
||||||
|
config.OrgID = val
|
||||||
|
config.sources["org"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("USER_TOKEN"); val != "" {
|
||||||
|
config.UserToken = val
|
||||||
|
config.sources["userToken"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("MTU"); val != "" {
|
||||||
|
if mtu, err := strconv.Atoi(val); err == nil {
|
||||||
|
config.MTU = mtu
|
||||||
|
config.sources["mtu"] = string(SourceEnv)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Invalid MTU value: %s, keeping current value\n", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv("DNS"); val != "" {
|
||||||
|
config.DNS = val
|
||||||
|
config.sources["dns"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("LOG_LEVEL"); val != "" {
|
||||||
|
config.LogLevel = val
|
||||||
|
config.sources["logLevel"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("INTERFACE"); val != "" {
|
||||||
|
config.InterfaceName = val
|
||||||
|
config.sources["interface"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("HTTP_ADDR"); val != "" {
|
||||||
|
config.HTTPAddr = val
|
||||||
|
config.sources["httpAddr"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("PING_INTERVAL"); val != "" {
|
||||||
|
config.PingInterval = val
|
||||||
|
config.sources["pingInterval"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("PING_TIMEOUT"); val != "" {
|
||||||
|
config.PingTimeout = val
|
||||||
|
config.sources["pingTimeout"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("ENABLE_API"); val == "true" {
|
||||||
|
config.EnableAPI = true
|
||||||
|
config.sources["enableApi"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("SOCKET_PATH"); val != "" {
|
||||||
|
config.SocketPath = val
|
||||||
|
config.sources["socketPath"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("HOLEPUNCH"); val == "true" {
|
||||||
|
config.Holepunch = true
|
||||||
|
config.sources["holepunch"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
|
||||||
|
// config.DoNotCreateNewClient = true
|
||||||
|
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadConfigFromCLI loads configuration from command-line arguments
|
||||||
|
func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||||
|
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
||||||
|
|
||||||
|
// Store original values to detect changes
|
||||||
|
origValues := map[string]interface{}{
|
||||||
|
"endpoint": config.Endpoint,
|
||||||
|
"id": config.ID,
|
||||||
|
"secret": config.Secret,
|
||||||
|
"org": config.OrgID,
|
||||||
|
"userToken": config.UserToken,
|
||||||
|
"mtu": config.MTU,
|
||||||
|
"dns": config.DNS,
|
||||||
|
"logLevel": config.LogLevel,
|
||||||
|
"interface": config.InterfaceName,
|
||||||
|
"httpAddr": config.HTTPAddr,
|
||||||
|
"socketPath": config.SocketPath,
|
||||||
|
"pingInterval": config.PingInterval,
|
||||||
|
"pingTimeout": config.PingTimeout,
|
||||||
|
"enableApi": config.EnableAPI,
|
||||||
|
"holepunch": config.Holepunch,
|
||||||
|
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define flags
|
||||||
|
serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server")
|
||||||
|
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
|
||||||
|
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
|
||||||
|
serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID")
|
||||||
|
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
|
||||||
|
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
|
||||||
|
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
|
||||||
|
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||||
|
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
|
||||||
|
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
|
||||||
|
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
|
||||||
|
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
|
||||||
|
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
||||||
|
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
|
||||||
|
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
|
||||||
|
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
|
||||||
|
|
||||||
|
version := serviceFlags.Bool("version", false, "Print the version")
|
||||||
|
showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit")
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
if err := serviceFlags.Parse(args); err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track which values were changed by CLI args
|
||||||
|
if config.Endpoint != origValues["endpoint"].(string) {
|
||||||
|
config.sources["endpoint"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.ID != origValues["id"].(string) {
|
||||||
|
config.sources["id"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.Secret != origValues["secret"].(string) {
|
||||||
|
config.sources["secret"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.OrgID != origValues["org"].(string) {
|
||||||
|
config.sources["org"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.UserToken != origValues["userToken"].(string) {
|
||||||
|
config.sources["userToken"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.MTU != origValues["mtu"].(int) {
|
||||||
|
config.sources["mtu"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.DNS != origValues["dns"].(string) {
|
||||||
|
config.sources["dns"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.LogLevel != origValues["logLevel"].(string) {
|
||||||
|
config.sources["logLevel"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.InterfaceName != origValues["interface"].(string) {
|
||||||
|
config.sources["interface"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.HTTPAddr != origValues["httpAddr"].(string) {
|
||||||
|
config.sources["httpAddr"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.SocketPath != origValues["socketPath"].(string) {
|
||||||
|
config.sources["socketPath"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.PingInterval != origValues["pingInterval"].(string) {
|
||||||
|
config.sources["pingInterval"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.PingTimeout != origValues["pingTimeout"].(string) {
|
||||||
|
config.sources["pingTimeout"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.EnableAPI != origValues["enableApi"].(bool) {
|
||||||
|
config.sources["enableApi"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.Holepunch != origValues["holepunch"].(bool) {
|
||||||
|
config.sources["holepunch"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
|
||||||
|
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
|
||||||
|
// }
|
||||||
|
|
||||||
|
return *version, *showConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDurations parses the duration strings into time.Duration
|
||||||
|
func (c *OlmConfig) parseDurations() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Parse ping interval
|
||||||
|
if c.PingInterval != "" {
|
||||||
|
c.PingIntervalDuration, err = time.ParseDuration(c.PingInterval)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", c.PingInterval)
|
||||||
|
c.PingIntervalDuration = 3 * time.Second
|
||||||
|
c.PingInterval = "3s"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.PingIntervalDuration = 3 * time.Second
|
||||||
|
c.PingInterval = "3s"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse ping timeout
|
||||||
|
if c.PingTimeout != "" {
|
||||||
|
c.PingTimeoutDuration, err = time.ParseDuration(c.PingTimeout)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", c.PingTimeout)
|
||||||
|
c.PingTimeoutDuration = 5 * time.Second
|
||||||
|
c.PingTimeout = "5s"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.PingTimeoutDuration = 5 * time.Second
|
||||||
|
c.PingTimeout = "5s"
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeConfigs merges source config into destination (only non-empty values)
|
||||||
|
// Also tracks that these values came from a file
|
||||||
|
func mergeConfigs(dest, src *OlmConfig) {
|
||||||
|
if src.Endpoint != "" {
|
||||||
|
dest.Endpoint = src.Endpoint
|
||||||
|
dest.sources["endpoint"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.ID != "" {
|
||||||
|
dest.ID = src.ID
|
||||||
|
dest.sources["id"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.Secret != "" {
|
||||||
|
dest.Secret = src.Secret
|
||||||
|
dest.sources["secret"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.OrgID != "" {
|
||||||
|
dest.OrgID = src.OrgID
|
||||||
|
dest.sources["org"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.UserToken != "" {
|
||||||
|
dest.UserToken = src.UserToken
|
||||||
|
dest.sources["userToken"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.MTU != 0 && src.MTU != 1280 {
|
||||||
|
dest.MTU = src.MTU
|
||||||
|
dest.sources["mtu"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.DNS != "" && src.DNS != "8.8.8.8" {
|
||||||
|
dest.DNS = src.DNS
|
||||||
|
dest.sources["dns"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.LogLevel != "" && src.LogLevel != "INFO" {
|
||||||
|
dest.LogLevel = src.LogLevel
|
||||||
|
dest.sources["logLevel"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.InterfaceName != "" && src.InterfaceName != "olm" {
|
||||||
|
dest.InterfaceName = src.InterfaceName
|
||||||
|
dest.sources["interface"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.HTTPAddr != "" && src.HTTPAddr != ":9452" {
|
||||||
|
dest.HTTPAddr = src.HTTPAddr
|
||||||
|
dest.sources["httpAddr"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.SocketPath != "" {
|
||||||
|
// Check if it's not the default for any OS
|
||||||
|
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
|
||||||
|
if !isDefault {
|
||||||
|
dest.SocketPath = src.SocketPath
|
||||||
|
dest.sources["socketPath"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if src.PingInterval != "" && src.PingInterval != "3s" {
|
||||||
|
dest.PingInterval = src.PingInterval
|
||||||
|
dest.sources["pingInterval"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.PingTimeout != "" && src.PingTimeout != "5s" {
|
||||||
|
dest.PingTimeout = src.PingTimeout
|
||||||
|
dest.sources["pingTimeout"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.TlsClientCert != "" {
|
||||||
|
dest.TlsClientCert = src.TlsClientCert
|
||||||
|
dest.sources["tlsClientCert"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
// For booleans, we always take the source value if explicitly set
|
||||||
|
if src.EnableAPI {
|
||||||
|
dest.EnableAPI = src.EnableAPI
|
||||||
|
dest.sources["enableApi"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.Holepunch {
|
||||||
|
dest.Holepunch = src.Holepunch
|
||||||
|
dest.sources["holepunch"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
// if src.DoNotCreateNewClient {
|
||||||
|
// dest.DoNotCreateNewClient = src.DoNotCreateNewClient
|
||||||
|
// dest.sources["doNotCreateNewClient"] = string(SourceFile)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveConfig saves the current configuration to the config file
|
||||||
|
func SaveConfig(config *OlmConfig) error {
|
||||||
|
configPath := getOlmConfigPath()
|
||||||
|
data, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal config: %w", err)
|
||||||
|
}
|
||||||
|
return os.WriteFile(configPath, data, 0644)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShowConfig prints the configuration and the source of each value
|
||||||
|
func (c *OlmConfig) ShowConfig() {
|
||||||
|
configPath := getOlmConfigPath()
|
||||||
|
|
||||||
|
fmt.Println("\n=== Olm Configuration ===\n")
|
||||||
|
fmt.Printf("Config File: %s\n", configPath)
|
||||||
|
|
||||||
|
// Check if config file exists
|
||||||
|
if _, err := os.Stat(configPath); err == nil {
|
||||||
|
fmt.Printf("Config File Status: ✓ exists\n")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Config File Status: ✗ not found\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n--- Configuration Values ---")
|
||||||
|
fmt.Println("(Format: Setting = Value [source])\n")
|
||||||
|
|
||||||
|
// Helper to get source or default
|
||||||
|
getSource := func(key string) string {
|
||||||
|
if source, ok := c.sources[key]; ok {
|
||||||
|
return source
|
||||||
|
}
|
||||||
|
return string(SourceDefault)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to format value (mask secrets)
|
||||||
|
formatValue := func(key, value string) string {
|
||||||
|
if key == "secret" && value != "" {
|
||||||
|
if len(value) > 8 {
|
||||||
|
return value[:4] + "****" + value[len(value)-4:]
|
||||||
|
}
|
||||||
|
return "****"
|
||||||
|
}
|
||||||
|
if value == "" {
|
||||||
|
return "(not set)"
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection settings
|
||||||
|
fmt.Println("Connection:")
|
||||||
|
fmt.Printf(" endpoint = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint"))
|
||||||
|
fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
|
||||||
|
fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
|
||||||
|
fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org"))
|
||||||
|
fmt.Printf(" user-token = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken"))
|
||||||
|
|
||||||
|
// Network settings
|
||||||
|
fmt.Println("\nNetwork:")
|
||||||
|
fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu"))
|
||||||
|
fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns"))
|
||||||
|
fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface"))
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
fmt.Println("\nLogging:")
|
||||||
|
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
|
||||||
|
|
||||||
|
// API server
|
||||||
|
fmt.Println("\nAPI Server:")
|
||||||
|
fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
|
||||||
|
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
|
||||||
|
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
|
||||||
|
|
||||||
|
// Timing
|
||||||
|
fmt.Println("\nTiming:")
|
||||||
|
fmt.Printf(" ping-interval = %s [%s]\n", c.PingInterval, getSource("pingInterval"))
|
||||||
|
fmt.Printf(" ping-timeout = %s [%s]\n", c.PingTimeout, getSource("pingTimeout"))
|
||||||
|
|
||||||
|
// Advanced
|
||||||
|
fmt.Println("\nAdvanced:")
|
||||||
|
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
|
||||||
|
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
|
||||||
|
if c.TlsClientCert != "" {
|
||||||
|
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Source legend
|
||||||
|
fmt.Println("\n--- Source Legend ---")
|
||||||
|
fmt.Println(" default = Built-in default value")
|
||||||
|
fmt.Println(" file = Loaded from config file")
|
||||||
|
fmt.Println(" environment = Set via environment variable")
|
||||||
|
fmt.Println(" cli = Provided as command-line argument")
|
||||||
|
fmt.Println("\nPriority: cli > environment > file > default")
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
523
diff
Normal file
523
diff
Normal file
@@ -0,0 +1,523 @@
|
|||||||
|
diff --git a/api/api.go b/api/api.go
|
||||||
|
index dd07751..0d2e4ef 100644
|
||||||
|
--- a/api/api.go
|
||||||
|
+++ b/api/api.go
|
||||||
|
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
+// SwitchOrgRequest defines the structure for switching organizations
|
||||||
|
+type SwitchOrgRequest struct {
|
||||||
|
+ OrgID string `json:"orgId"`
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
// PeerStatus represents the status of a peer connection
|
||||||
|
type PeerStatus struct {
|
||||||
|
SiteID int `json:"siteId"`
|
||||||
|
@@ -35,6 +40,7 @@ type StatusResponse struct {
|
||||||
|
Registered bool `json:"registered"`
|
||||||
|
TunnelIP string `json:"tunnelIP,omitempty"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
+ OrgID string `json:"orgId,omitempty"`
|
||||||
|
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -46,6 +52,7 @@ type API struct {
|
||||||
|
server *http.Server
|
||||||
|
connectionChan chan ConnectionRequest
|
||||||
|
shutdownChan chan struct{}
|
||||||
|
+ switchOrgChan chan SwitchOrgRequest
|
||||||
|
statusMu sync.RWMutex
|
||||||
|
peerStatuses map[int]*PeerStatus
|
||||||
|
connectedAt time.Time
|
||||||
|
@@ -53,6 +60,7 @@ type API struct {
|
||||||
|
isRegistered bool
|
||||||
|
tunnelIP string
|
||||||
|
version string
|
||||||
|
+ orgID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||||
|
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
|
||||||
|
addr: addr,
|
||||||
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
|
||||||
|
socketPath: socketPath,
|
||||||
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -85,6 +95,7 @@ func (s *API) Start() error {
|
||||||
|
mux.HandleFunc("/connect", s.handleConnect)
|
||||||
|
mux.HandleFunc("/status", s.handleStatus)
|
||||||
|
mux.HandleFunc("/exit", s.handleExit)
|
||||||
|
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||||
|
|
||||||
|
s.server = &http.Server{
|
||||||
|
Handler: mux,
|
||||||
|
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
|
||||||
|
return s.shutdownChan
|
||||||
|
}
|
||||||
|
|
||||||
|
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
||||||
|
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
||||||
|
+ return s.switchOrgChan
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||||
|
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
|
||||||
|
s.version = version
|
||||||
|
}
|
||||||
|
|
||||||
|
+// SetOrgID sets the org ID
|
||||||
|
+func (s *API) SetOrgID(orgID string) {
|
||||||
|
+ s.statusMu.Lock()
|
||||||
|
+ defer s.statusMu.Unlock()
|
||||||
|
+ s.orgID = orgID
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||||
|
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
Registered: s.isRegistered,
|
||||||
|
TunnelIP: s.tunnelIP,
|
||||||
|
Version: s.version,
|
||||||
|
+ OrgID: s.orgID,
|
||||||
|
PeerStatuses: s.peerStatuses,
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||||
|
"status": "shutdown initiated",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
+
|
||||||
|
+// handleSwitchOrg handles the /switch-org endpoint
|
||||||
|
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||||
|
+ if r.Method != http.MethodPost {
|
||||||
|
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ var req SwitchOrgRequest
|
||||||
|
+ decoder := json.NewDecoder(r.Body)
|
||||||
|
+ if err := decoder.Decode(&req); err != nil {
|
||||||
|
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Validate required fields
|
||||||
|
+ if req.OrgID == "" {
|
||||||
|
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||||
|
+
|
||||||
|
+ // Send the request to the main goroutine
|
||||||
|
+ select {
|
||||||
|
+ case s.switchOrgChan <- req:
|
||||||
|
+ // Signal sent successfully
|
||||||
|
+ default:
|
||||||
|
+ // Channel already has a signal, don't block
|
||||||
|
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Return a success response
|
||||||
|
+ w.Header().Set("Content-Type", "application/json")
|
||||||
|
+ w.WriteHeader(http.StatusAccepted)
|
||||||
|
+ json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
+ "status": "org switch initiated",
|
||||||
|
+ "orgId": req.OrgID,
|
||||||
|
+ })
|
||||||
|
+}
|
||||||
|
diff --git a/olm/olm.go b/olm/olm.go
|
||||||
|
index 78080c4..5e292d6 100644
|
||||||
|
--- a/olm/olm.go
|
||||||
|
+++ b/olm/olm.go
|
||||||
|
@@ -58,6 +58,58 @@ type Config struct {
|
||||||
|
OrgID string
|
||||||
|
}
|
||||||
|
|
||||||
|
+// tunnelState holds all the active tunnel resources that need cleanup
|
||||||
|
+type tunnelState struct {
|
||||||
|
+ dev *device.Device
|
||||||
|
+ tdev tun.Device
|
||||||
|
+ uapiListener net.Listener
|
||||||
|
+ peerMonitor *peermonitor.PeerMonitor
|
||||||
|
+ stopRegister func()
|
||||||
|
+ connected bool
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
+// teardownTunnel cleans up all tunnel resources
|
||||||
|
+func teardownTunnel(state *tunnelState) {
|
||||||
|
+ if state == nil {
|
||||||
|
+ return
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ logger.Info("Tearing down tunnel...")
|
||||||
|
+
|
||||||
|
+ // Stop registration messages
|
||||||
|
+ if state.stopRegister != nil {
|
||||||
|
+ state.stopRegister()
|
||||||
|
+ state.stopRegister = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Stop peer monitor
|
||||||
|
+ if state.peerMonitor != nil {
|
||||||
|
+ state.peerMonitor.Stop()
|
||||||
|
+ state.peerMonitor = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Close UAPI listener
|
||||||
|
+ if state.uapiListener != nil {
|
||||||
|
+ state.uapiListener.Close()
|
||||||
|
+ state.uapiListener = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Close WireGuard device
|
||||||
|
+ if state.dev != nil {
|
||||||
|
+ state.dev.Close()
|
||||||
|
+ state.dev = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Close TUN device
|
||||||
|
+ if state.tdev != nil {
|
||||||
|
+ state.tdev.Close()
|
||||||
|
+ state.tdev = nil
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ state.connected = false
|
||||||
|
+ logger.Info("Tunnel teardown complete")
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
func Run(ctx context.Context, config Config) {
|
||||||
|
// Create a cancellable context for internal shutdown control
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
pingTimeout = config.PingTimeoutDuration
|
||||||
|
doHolepunch = config.Holepunch
|
||||||
|
privateKey wgtypes.Key
|
||||||
|
- connected bool
|
||||||
|
- dev *device.Device
|
||||||
|
wgData WgData
|
||||||
|
holePunchData HolePunchData
|
||||||
|
- uapiListener net.Listener
|
||||||
|
- tdev tun.Device
|
||||||
|
+ orgID = config.OrgID
|
||||||
|
)
|
||||||
|
|
||||||
|
+ // Tunnel state that can be torn down and recreated
|
||||||
|
+ tunnel := &tunnelState{}
|
||||||
|
+
|
||||||
|
stopHolepunch = make(chan struct{})
|
||||||
|
stopPing = make(chan struct{})
|
||||||
|
|
||||||
|
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
|
||||||
|
apiServer.SetVersion(config.Version)
|
||||||
|
+ apiServer.SetOrgID(orgID)
|
||||||
|
if err := apiServer.Start(); err != nil {
|
||||||
|
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||||
|
}
|
||||||
|
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
- if connected {
|
||||||
|
+ if tunnel.connected {
|
||||||
|
logger.Info("Already connected. Ignoring new connection request.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
- if stopRegister != nil {
|
||||||
|
- stopRegister()
|
||||||
|
- stopRegister = nil
|
||||||
|
+ if tunnel.stopRegister != nil {
|
||||||
|
+ tunnel.stopRegister()
|
||||||
|
+ tunnel.stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
close(stopHolepunch)
|
||||||
|
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
// if there is an existing tunnel then close it
|
||||||
|
- if dev != nil {
|
||||||
|
+ if tunnel.dev != nil {
|
||||||
|
logger.Info("Got new message. Closing existing tunnel!")
|
||||||
|
- dev.Close()
|
||||||
|
+ tunnel.dev.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
- tdev, err = func() (tun.Device, error) {
|
||||||
|
+ tunnel.tdev, err = func() (tun.Device, error) {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
interfaceName, err := findUnusedUTUN()
|
||||||
|
if err != nil {
|
||||||
|
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||||
|
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
|
||||||
|
interfaceName = realInterfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
|
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
|
|
||||||
|
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
|
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
- conn, err := uapiListener.Accept()
|
||||||
|
+ conn, err := tunnel.uapiListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
- go dev.IpcHandle(conn)
|
||||||
|
+ go tunnel.dev.IpcHandle(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
logger.Info("UAPI listener started")
|
||||||
|
|
||||||
|
- if err = dev.Up(); err != nil {
|
||||||
|
+ if err = tunnel.dev.Up(); err != nil {
|
||||||
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||||
|
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
apiServer.SetTunnelIP(wgData.TunnelIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
- peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
|
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
|
if apiServer != nil {
|
||||||
|
// Find the site config to get endpoint information
|
||||||
|
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
},
|
||||||
|
fixKey(privateKey.String()),
|
||||||
|
olm,
|
||||||
|
- dev,
|
||||||
|
+ tunnel.dev,
|
||||||
|
doHolepunch,
|
||||||
|
)
|
||||||
|
|
||||||
|
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
// Format the endpoint before configuring the peer.
|
||||||
|
site.Endpoint = formatEndpoint(site.Endpoint)
|
||||||
|
|
||||||
|
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
||||||
|
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to configure peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
- peerMonitor.Start()
|
||||||
|
+ tunnel.peerMonitor.Start()
|
||||||
|
|
||||||
|
if apiServer != nil {
|
||||||
|
apiServer.SetRegistered(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
- connected = true
|
||||||
|
+ tunnel.connected = true
|
||||||
|
|
||||||
|
logger.Info("WireGuard device created.")
|
||||||
|
})
|
||||||
|
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the peer in WireGuard
|
||||||
|
- if dev != nil {
|
||||||
|
+ if tunnel.dev != nil {
|
||||||
|
// Find the existing peer to get old data
|
||||||
|
var oldRemoteSubnets string
|
||||||
|
var oldPublicKey string
|
||||||
|
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
// If the public key has changed, remove the old peer first
|
||||||
|
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
||||||
|
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
||||||
|
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||||
|
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||||
|
logger.Error("Failed to remove old peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
// Format the endpoint before updating the peer.
|
||||||
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
|
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to update peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the peer to WireGuard
|
||||||
|
- if dev != nil {
|
||||||
|
+ if tunnel.dev != nil {
|
||||||
|
// Format the endpoint before adding the new peer.
|
||||||
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
|
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to add peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the peer from WireGuard
|
||||||
|
- if dev != nil {
|
||||||
|
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||||
|
+ if tunnel.dev != nil {
|
||||||
|
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||||
|
logger.Error("Failed to remove peer: %v", err)
|
||||||
|
// Send error response if needed
|
||||||
|
return
|
||||||
|
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||||
|
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
||||||
|
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
apiServer.SetConnectionStatus(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
- if connected {
|
||||||
|
+ if tunnel.connected {
|
||||||
|
logger.Debug("Already connected, skipping registration")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
|
||||||
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
||||||
|
|
||||||
|
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": publicKey.String(),
|
||||||
|
"relay": !doHolepunch,
|
||||||
|
"olmVersion": config.Version,
|
||||||
|
- "orgId": config.OrgID,
|
||||||
|
+ "orgId": orgID,
|
||||||
|
}, 1*time.Second)
|
||||||
|
|
||||||
|
go keepSendingPing(olm)
|
||||||
|
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
}
|
||||||
|
defer olm.Close()
|
||||||
|
|
||||||
|
+ // Listen for org switch requests from the API (after olm is created)
|
||||||
|
+ if apiServer != nil {
|
||||||
|
+ go func() {
|
||||||
|
+ for req := range apiServer.GetSwitchOrgChannel() {
|
||||||
|
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
|
||||||
|
+
|
||||||
|
+ // Update the orgId
|
||||||
|
+ orgID = req.OrgID
|
||||||
|
+
|
||||||
|
+ // Teardown existing tunnel
|
||||||
|
+ teardownTunnel(tunnel)
|
||||||
|
+
|
||||||
|
+ // Reset tunnel state
|
||||||
|
+ tunnel = &tunnelState{}
|
||||||
|
+
|
||||||
|
+ // Stop holepunch
|
||||||
|
+ select {
|
||||||
|
+ case <-stopHolepunch:
|
||||||
|
+ // Channel already closed
|
||||||
|
+ default:
|
||||||
|
+ close(stopHolepunch)
|
||||||
|
+ }
|
||||||
|
+ stopHolepunch = make(chan struct{})
|
||||||
|
+
|
||||||
|
+ // Clear API server state
|
||||||
|
+ apiServer.SetRegistered(false)
|
||||||
|
+ apiServer.SetTunnelIP("")
|
||||||
|
+ apiServer.SetOrgID(orgID)
|
||||||
|
+
|
||||||
|
+ // Send new registration message with updated orgId
|
||||||
|
+ publicKey := privateKey.PublicKey()
|
||||||
|
+ logger.Info("Sending registration message with new orgId: %s", orgID)
|
||||||
|
+
|
||||||
|
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
+ "publicKey": publicKey.String(),
|
||||||
|
+ "relay": !doHolepunch,
|
||||||
|
+ "olmVersion": config.Version,
|
||||||
|
+ "orgId": orgID,
|
||||||
|
+ }, 1*time.Second)
|
||||||
|
+ }
|
||||||
|
+ }()
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Info("Context cancelled")
|
||||||
|
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
close(stopHolepunch)
|
||||||
|
}
|
||||||
|
|
||||||
|
- if stopRegister != nil {
|
||||||
|
- stopRegister()
|
||||||
|
- stopRegister = nil
|
||||||
|
+ if tunnel.stopRegister != nil {
|
||||||
|
+ tunnel.stopRegister()
|
||||||
|
+ tunnel.stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
|
||||||
|
close(stopPing)
|
||||||
|
}
|
||||||
|
|
||||||
|
- if peerMonitor != nil {
|
||||||
|
- peerMonitor.Stop()
|
||||||
|
- }
|
||||||
|
-
|
||||||
|
- if uapiListener != nil {
|
||||||
|
- uapiListener.Close()
|
||||||
|
- }
|
||||||
|
- if dev != nil {
|
||||||
|
- dev.Close()
|
||||||
|
- }
|
||||||
|
+ // Use teardownTunnel to clean up all tunnel resources
|
||||||
|
+ teardownTunnel(tunnel)
|
||||||
|
|
||||||
|
if apiServer != nil {
|
||||||
|
apiServer.Stop()
|
||||||
15
docker-compose.yml
Normal file
15
docker-compose.yml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
services:
|
||||||
|
olm:
|
||||||
|
image: fosrl/olm:latest
|
||||||
|
container_name: olm
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
- PANGOLIN_ENDPOINT=https://example.com
|
||||||
|
- OLM_ID=vdqnz8rwgb95cnp
|
||||||
|
- OLM_SECRET=1sw05qv1tkfdb1k81zpw05nahnnjvmhxjvf746umwagddmdg
|
||||||
|
cap_add:
|
||||||
|
- NET_ADMIN
|
||||||
|
- SYS_MODULE
|
||||||
|
devices:
|
||||||
|
- /dev/net/tun:/dev/net/tun
|
||||||
|
network_mode: host
|
||||||
@@ -1,21 +1,10 @@
|
|||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
|
|
||||||
# Sample from https://github.com/traefik/traefik-library-image/blob/5070edb25b03cca6802d75d5037576c840f73fdd/v3.1/alpine/entrypoint.sh
|
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
# first arg is `-f` or `--some-option`
|
# first arg is `-f` or `--some-option`
|
||||||
if [ "${1#-}" != "$1" ]; then
|
if [ "${1#-}" != "$1" ]; then
|
||||||
set -- newt "$@"
|
set -- olm "$@"
|
||||||
fi
|
|
||||||
|
|
||||||
# if our command is a valid newt subcommand, let's invoke it through newt instead
|
|
||||||
# (this allows for "docker run newt version", etc)
|
|
||||||
if newt "$1" --help >/dev/null 2>&1
|
|
||||||
then
|
|
||||||
set -- newt "$@"
|
|
||||||
else
|
|
||||||
echo "= '$1' is not a newt command: assuming shell execution." 1>&2
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
exec "$@"
|
exec "$@"
|
||||||
279
get-olm.sh
Normal file
279
get-olm.sh
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Get Olm - Cross-platform installation script
|
||||||
|
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/olm/refs/heads/main/get-olm.sh | bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# GitHub repository info
|
||||||
|
REPO="fosrl/olm"
|
||||||
|
GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
|
||||||
|
|
||||||
|
# Function to print colored output
|
||||||
|
print_status() {
|
||||||
|
echo -e "${GREEN}[INFO]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
print_warning() {
|
||||||
|
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
print_error() {
|
||||||
|
echo -e "${RED}[ERROR]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Function to get latest version from GitHub API
|
||||||
|
get_latest_version() {
|
||||||
|
local latest_info
|
||||||
|
|
||||||
|
if command -v curl >/dev/null 2>&1; then
|
||||||
|
latest_info=$(curl -fsSL "$GITHUB_API_URL" 2>/dev/null)
|
||||||
|
elif command -v wget >/dev/null 2>&1; then
|
||||||
|
latest_info=$(wget -qO- "$GITHUB_API_URL" 2>/dev/null)
|
||||||
|
else
|
||||||
|
print_error "Neither curl nor wget is available. Please install one of them." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "$latest_info" ]; then
|
||||||
|
print_error "Failed to fetch latest version information" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract version from JSON response (works without jq)
|
||||||
|
local version=$(echo "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/')
|
||||||
|
|
||||||
|
if [ -z "$version" ]; then
|
||||||
|
print_error "Could not parse version from GitHub API response" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Remove 'v' prefix if present
|
||||||
|
version=$(echo "$version" | sed 's/^v//')
|
||||||
|
|
||||||
|
echo "$version"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Detect OS and architecture
|
||||||
|
detect_platform() {
|
||||||
|
local os arch
|
||||||
|
|
||||||
|
# Detect OS
|
||||||
|
case "$(uname -s)" in
|
||||||
|
Linux*) os="linux" ;;
|
||||||
|
Darwin*) os="darwin" ;;
|
||||||
|
MINGW*|MSYS*|CYGWIN*) os="windows" ;;
|
||||||
|
FreeBSD*) os="freebsd" ;;
|
||||||
|
*)
|
||||||
|
print_error "Unsupported operating system: $(uname -s)"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
# Detect architecture
|
||||||
|
case "$(uname -m)" in
|
||||||
|
x86_64|amd64) arch="amd64" ;;
|
||||||
|
arm64|aarch64) arch="arm64" ;;
|
||||||
|
armv7l|armv6l)
|
||||||
|
if [ "$os" = "linux" ]; then
|
||||||
|
if [ "$(uname -m)" = "armv6l" ]; then
|
||||||
|
arch="arm32v6"
|
||||||
|
else
|
||||||
|
arch="arm32"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
arch="arm64" # Default for non-Linux ARM
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
riscv64)
|
||||||
|
if [ "$os" = "linux" ]; then
|
||||||
|
arch="riscv64"
|
||||||
|
else
|
||||||
|
print_error "RISC-V architecture only supported on Linux"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
print_error "Unsupported architecture: $(uname -m)"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
echo "${os}_${arch}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get installation directory
|
||||||
|
get_install_dir() {
|
||||||
|
local platform="$1"
|
||||||
|
|
||||||
|
if [[ "$platform" == *"windows"* ]]; then
|
||||||
|
echo "$HOME/bin"
|
||||||
|
else
|
||||||
|
# For Unix-like systems, prioritize system-wide directories for sudo access
|
||||||
|
# Check in order of preference: /usr/local/bin, /usr/bin, ~/.local/bin
|
||||||
|
if [ -d "/usr/local/bin" ]; then
|
||||||
|
echo "/usr/local/bin"
|
||||||
|
elif [ -d "/usr/bin" ]; then
|
||||||
|
echo "/usr/bin"
|
||||||
|
else
|
||||||
|
# Fallback to user directory if system directories don't exist
|
||||||
|
echo "$HOME/.local/bin"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if we need sudo for installation
|
||||||
|
need_sudo() {
|
||||||
|
local install_dir="$1"
|
||||||
|
|
||||||
|
# If installing to system directory and we don't have write permission, need sudo
|
||||||
|
if [[ "$install_dir" == "/usr/local/bin" || "$install_dir" == "/usr/bin" ]]; then
|
||||||
|
if [ ! -w "$install_dir" ] 2>/dev/null; then
|
||||||
|
return 0 # Need sudo
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
return 1 # Don't need sudo
|
||||||
|
}
|
||||||
|
|
||||||
|
# Download and install olm
|
||||||
|
install_olm() {
|
||||||
|
local platform="$1"
|
||||||
|
local install_dir="$2"
|
||||||
|
local binary_name="olm_${platform}"
|
||||||
|
local exe_suffix=""
|
||||||
|
|
||||||
|
# Add .exe suffix for Windows
|
||||||
|
if [[ "$platform" == *"windows"* ]]; then
|
||||||
|
binary_name="${binary_name}.exe"
|
||||||
|
exe_suffix=".exe"
|
||||||
|
fi
|
||||||
|
|
||||||
|
local download_url="${BASE_URL}/${binary_name}"
|
||||||
|
local temp_file="/tmp/olm${exe_suffix}"
|
||||||
|
local final_path="${install_dir}/olm${exe_suffix}"
|
||||||
|
|
||||||
|
print_status "Downloading olm from ${download_url}"
|
||||||
|
|
||||||
|
# Download the binary
|
||||||
|
if command -v curl >/dev/null 2>&1; then
|
||||||
|
curl -fsSL "$download_url" -o "$temp_file"
|
||||||
|
elif command -v wget >/dev/null 2>&1; then
|
||||||
|
wget -q "$download_url" -O "$temp_file"
|
||||||
|
else
|
||||||
|
print_error "Neither curl nor wget is available. Please install one of them."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if we need sudo for installation
|
||||||
|
local use_sudo=""
|
||||||
|
if need_sudo "$install_dir"; then
|
||||||
|
print_status "Administrator privileges required for system-wide installation"
|
||||||
|
if command -v sudo >/dev/null 2>&1; then
|
||||||
|
use_sudo="sudo"
|
||||||
|
else
|
||||||
|
print_error "sudo is required for system-wide installation but not available"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create install directory if it doesn't exist
|
||||||
|
if [ -n "$use_sudo" ]; then
|
||||||
|
$use_sudo mkdir -p "$install_dir"
|
||||||
|
else
|
||||||
|
mkdir -p "$install_dir"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Move binary to install directory
|
||||||
|
if [ -n "$use_sudo" ]; then
|
||||||
|
$use_sudo mv "$temp_file" "$final_path"
|
||||||
|
$use_sudo chmod +x "$final_path"
|
||||||
|
else
|
||||||
|
mv "$temp_file" "$final_path"
|
||||||
|
chmod +x "$final_path"
|
||||||
|
fi
|
||||||
|
|
||||||
|
print_status "olm installed to ${final_path}"
|
||||||
|
|
||||||
|
# Check if install directory is in PATH (only warn for non-system directories)
|
||||||
|
if [[ "$install_dir" != "/usr/local/bin" && "$install_dir" != "/usr/bin" ]]; then
|
||||||
|
if ! echo "$PATH" | grep -q "$install_dir"; then
|
||||||
|
print_warning "Install directory ${install_dir} is not in your PATH."
|
||||||
|
print_warning "Add it to your PATH by adding this line to your shell profile:"
|
||||||
|
print_warning " export PATH=\"${install_dir}:\$PATH\""
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Verify installation
|
||||||
|
verify_installation() {
|
||||||
|
local install_dir="$1"
|
||||||
|
local exe_suffix=""
|
||||||
|
|
||||||
|
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||||
|
exe_suffix=".exe"
|
||||||
|
fi
|
||||||
|
|
||||||
|
local olm_path="${install_dir}/olm${exe_suffix}"
|
||||||
|
|
||||||
|
if [ -f "$olm_path" ] && [ -x "$olm_path" ]; then
|
||||||
|
print_status "Installation successful!"
|
||||||
|
print_status "olm version: $("$olm_path" --version 2>/dev/null || echo "unknown")"
|
||||||
|
return 0
|
||||||
|
else
|
||||||
|
print_error "Installation failed. Binary not found or not executable."
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Main installation process
|
||||||
|
main() {
|
||||||
|
print_status "Installing latest version of olm..."
|
||||||
|
|
||||||
|
# Get latest version
|
||||||
|
print_status "Fetching latest version from GitHub..."
|
||||||
|
VERSION=$(get_latest_version)
|
||||||
|
print_status "Latest version: v${VERSION}"
|
||||||
|
|
||||||
|
# Set base URL with the fetched version
|
||||||
|
BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"
|
||||||
|
|
||||||
|
# Detect platform
|
||||||
|
PLATFORM=$(detect_platform)
|
||||||
|
print_status "Detected platform: ${PLATFORM}"
|
||||||
|
|
||||||
|
# Get install directory
|
||||||
|
INSTALL_DIR=$(get_install_dir "$PLATFORM")
|
||||||
|
print_status "Install directory: ${INSTALL_DIR}"
|
||||||
|
|
||||||
|
# Inform user about system-wide installation
|
||||||
|
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
|
||||||
|
print_status "Installing system-wide for sudo access"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install olm
|
||||||
|
install_olm "$PLATFORM" "$INSTALL_DIR"
|
||||||
|
|
||||||
|
# Verify installation
|
||||||
|
if verify_installation "$INSTALL_DIR"; then
|
||||||
|
print_status "olm is ready to use!"
|
||||||
|
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
|
||||||
|
print_status "olm is installed system-wide and accessible via sudo"
|
||||||
|
fi
|
||||||
|
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||||
|
print_status "Run 'olm --help' to get started"
|
||||||
|
else
|
||||||
|
print_status "Run 'olm --help' or 'sudo olm --help' to get started"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run main function
|
||||||
|
main "$@"
|
||||||
33
go.mod
33
go.mod
@@ -1,19 +1,22 @@
|
|||||||
module github.com/fosrl/newt
|
module github.com/fosrl/olm
|
||||||
|
|
||||||
go 1.23.1
|
go 1.25
|
||||||
|
|
||||||
toolchain go1.23.2
|
|
||||||
|
|
||||||
require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/btree v1.1.2 // indirect
|
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7
|
||||||
github.com/gorilla/websocket v1.5.3 // indirect
|
github.com/vishvananda/netlink v1.3.1
|
||||||
golang.org/x/crypto v0.28.0 // indirect
|
golang.org/x/crypto v0.43.0
|
||||||
golang.org/x/net v0.30.0 // indirect
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
||||||
golang.org/x/sys v0.26.0 // indirect
|
golang.org/x/sys v0.37.0
|
||||||
golang.org/x/time v0.7.0 // indirect
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
|
)
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
|
||||||
|
require (
|
||||||
|
github.com/gorilla/websocket v1.5.3 // indirect
|
||||||
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
|
golang.org/x/net v0.45.0 // indirect
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
|
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
46
go.sum
46
go.sum
@@ -1,20 +1,34 @@
|
|||||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o=
|
||||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM=
|
||||||
|
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||||
|
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
||||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||||
|
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||||
|
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||||
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||||
|
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs=
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||||
|
|||||||
351
holepunch/holepunch.go
Normal file
351
holepunch/holepunch.go
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
package holepunch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/bind"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DomainResolver is a function type for resolving domains to IP addresses
|
||||||
|
type DomainResolver func(string) (string, error)
|
||||||
|
|
||||||
|
// ExitNode represents a WireGuard exit node for hole punching
|
||||||
|
type ExitNode struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager handles UDP hole punching operations
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
running bool
|
||||||
|
stopChan chan struct{}
|
||||||
|
sharedBind *bind.SharedBind
|
||||||
|
olmID string
|
||||||
|
token string
|
||||||
|
domainResolver DomainResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new hole punch manager
|
||||||
|
func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
sharedBind: sharedBind,
|
||||||
|
olmID: olmID,
|
||||||
|
domainResolver: domainResolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetToken updates the authentication token used for hole punching
|
||||||
|
func (m *Manager) SetToken(token string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.token = token
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning returns whether hole punching is currently active
|
||||||
|
func (m *Manager) IsRunning() bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.running
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops any ongoing hole punch operations
|
||||||
|
func (m *Manager) Stop() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if !m.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.stopChan != nil {
|
||||||
|
close(m.stopChan)
|
||||||
|
m.stopChan = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.running = false
|
||||||
|
logger.Info("Hole punch manager stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartMultipleExitNodes starts hole punching to multiple exit nodes
|
||||||
|
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
|
||||||
|
if m.running {
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Debug("UDP hole punch already running, skipping new request")
|
||||||
|
return fmt.Errorf("hole punch already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(exitNodes) == 0 {
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Warn("No exit nodes provided for hole punching")
|
||||||
|
return fmt.Errorf("no exit nodes provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.running = true
|
||||||
|
m.stopChan = make(chan struct{})
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||||
|
|
||||||
|
go m.runMultipleExitNodes(exitNodes)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
|
||||||
|
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
|
||||||
|
if m.running {
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Debug("UDP hole punch already running, skipping new request")
|
||||||
|
return fmt.Errorf("hole punch already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.running = true
|
||||||
|
m.stopChan = make(chan struct{})
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
||||||
|
|
||||||
|
go m.runSingleEndpoint(endpoint, serverPubKey)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runMultipleExitNodes performs hole punching to multiple exit nodes
|
||||||
|
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
|
||||||
|
defer func() {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.running = false
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Resolve all endpoints upfront
|
||||||
|
type resolvedExitNode struct {
|
||||||
|
remoteAddr *net.UDPAddr
|
||||||
|
publicKey string
|
||||||
|
endpointName string
|
||||||
|
}
|
||||||
|
|
||||||
|
var resolvedNodes []resolvedExitNode
|
||||||
|
for _, exitNode := range exitNodes {
|
||||||
|
host, err := m.domainResolver(exitNode.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := net.JoinHostPort(host, "21820")
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||||
|
remoteAddr: remoteAddr,
|
||||||
|
publicKey: exitNode.PublicKey,
|
||||||
|
endpointName: exitNode.Endpoint,
|
||||||
|
})
|
||||||
|
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resolvedNodes) == 0 {
|
||||||
|
logger.Error("No exit nodes could be resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send initial hole punch to all exit nodes
|
||||||
|
for _, node := range resolvedNodes {
|
||||||
|
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||||
|
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.stopChan:
|
||||||
|
logger.Debug("Hole punch stopped by signal")
|
||||||
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Debug("Hole punch timeout reached")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
// Send hole punch to all exit nodes
|
||||||
|
for _, node := range resolvedNodes {
|
||||||
|
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||||
|
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSingleEndpoint performs hole punching to a single endpoint
|
||||||
|
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
|
||||||
|
defer func() {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.running = false
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||||
|
}()
|
||||||
|
|
||||||
|
host, err := m.domainResolver(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := net.JoinHostPort(host, "21820")
|
||||||
|
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute once immediately before starting the loop
|
||||||
|
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||||
|
logger.Warn("Failed to send initial hole punch: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.stopChan:
|
||||||
|
logger.Debug("Hole punch stopped by signal")
|
||||||
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Debug("Hole punch timeout reached")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||||
|
logger.Debug("Failed to send hole punch: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendHolePunch sends an encrypted hole punch packet using the shared bind
|
||||||
|
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
token := m.token
|
||||||
|
olmID := m.olmID
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if serverPubKey == "" || token == "" {
|
||||||
|
return fmt.Errorf("server public key or OLM token is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := struct {
|
||||||
|
OlmID string `json:"olmId"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}{
|
||||||
|
OlmID: olmID,
|
||||||
|
Token: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert payload to JSON
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the payload using the server's WireGuard public key
|
||||||
|
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(encryptedPayload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
|
||||||
|
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
||||||
|
// Generate an ephemeral keypair for this message
|
||||||
|
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||||
|
}
|
||||||
|
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||||
|
|
||||||
|
// Parse the server's public key
|
||||||
|
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use X25519 for key exchange
|
||||||
|
var ephPrivKeyFixed [32]byte
|
||||||
|
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||||
|
|
||||||
|
// Perform X25519 key exchange
|
||||||
|
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an AEAD cipher using the shared secret
|
||||||
|
aead, err := chacha20poly1305.New(sharedSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random nonce
|
||||||
|
nonce := make([]byte, aead.NonceSize())
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the payload
|
||||||
|
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||||
|
|
||||||
|
// Prepare the final encrypted message
|
||||||
|
encryptedMsg := struct {
|
||||||
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||||
|
Nonce []byte `json:"nonce"`
|
||||||
|
Ciphertext []byte `json:"ciphertext"`
|
||||||
|
}{
|
||||||
|
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||||
|
Nonce: nonce,
|
||||||
|
Ciphertext: ciphertext,
|
||||||
|
}
|
||||||
|
|
||||||
|
return encryptedMsg, nil
|
||||||
|
}
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
package logger
|
|
||||||
|
|
||||||
type LogLevel int
|
|
||||||
|
|
||||||
const (
|
|
||||||
DEBUG LogLevel = iota
|
|
||||||
INFO
|
|
||||||
WARN
|
|
||||||
ERROR
|
|
||||||
FATAL
|
|
||||||
)
|
|
||||||
|
|
||||||
var levelStrings = map[LogLevel]string{
|
|
||||||
DEBUG: "DEBUG",
|
|
||||||
INFO: "INFO",
|
|
||||||
WARN: "WARN",
|
|
||||||
ERROR: "ERROR",
|
|
||||||
FATAL: "FATAL",
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns the string representation of the log level
|
|
||||||
func (l LogLevel) String() string {
|
|
||||||
if s, ok := levelStrings[l]; ok {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return "UNKNOWN"
|
|
||||||
}
|
|
||||||
106
logger/logger.go
106
logger/logger.go
@@ -1,106 +0,0 @@
|
|||||||
package logger
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Logger struct holds the logger instance
|
|
||||||
type Logger struct {
|
|
||||||
logger *log.Logger
|
|
||||||
level LogLevel
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
defaultLogger *Logger
|
|
||||||
once sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewLogger creates a new logger instance
|
|
||||||
func NewLogger() *Logger {
|
|
||||||
return &Logger{
|
|
||||||
logger: log.New(os.Stdout, "", 0),
|
|
||||||
level: DEBUG,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Init initializes the default logger
|
|
||||||
func Init() *Logger {
|
|
||||||
once.Do(func() {
|
|
||||||
defaultLogger = NewLogger()
|
|
||||||
})
|
|
||||||
return defaultLogger
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLogger returns the default logger instance
|
|
||||||
func GetLogger() *Logger {
|
|
||||||
if defaultLogger == nil {
|
|
||||||
Init()
|
|
||||||
}
|
|
||||||
return defaultLogger
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLevel sets the minimum logging level
|
|
||||||
func (l *Logger) SetLevel(level LogLevel) {
|
|
||||||
l.level = level
|
|
||||||
}
|
|
||||||
|
|
||||||
// log handles the actual logging
|
|
||||||
func (l *Logger) log(level LogLevel, format string, args ...interface{}) {
|
|
||||||
if level < l.level {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
|
||||||
message := fmt.Sprintf(format, args...)
|
|
||||||
l.logger.Printf("%s: %s %s", level.String(), timestamp, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug logs debug level messages
|
|
||||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
|
||||||
l.log(DEBUG, format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Info logs info level messages
|
|
||||||
func (l *Logger) Info(format string, args ...interface{}) {
|
|
||||||
l.log(INFO, format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warn logs warning level messages
|
|
||||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
|
||||||
l.log(WARN, format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error logs error level messages
|
|
||||||
func (l *Logger) Error(format string, args ...interface{}) {
|
|
||||||
l.log(ERROR, format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fatal logs fatal level messages and exits
|
|
||||||
func (l *Logger) Fatal(format string, args ...interface{}) {
|
|
||||||
l.log(FATAL, format, args...)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Global helper functions
|
|
||||||
func Debug(format string, args ...interface{}) {
|
|
||||||
GetLogger().Debug(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Info(format string, args ...interface{}) {
|
|
||||||
GetLogger().Info(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Warn(format string, args ...interface{}) {
|
|
||||||
GetLogger().Warn(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Error(format string, args ...interface{}) {
|
|
||||||
GetLogger().Error(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Fatal(format string, args ...interface{}) {
|
|
||||||
GetLogger().Fatal(format, args...)
|
|
||||||
}
|
|
||||||
728
main.go
728
main.go
@@ -1,566 +1,220 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strings"
|
"runtime"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/proxy"
|
"github.com/fosrl/olm/olm"
|
||||||
"github.com/fosrl/newt/websocket"
|
|
||||||
|
|
||||||
"golang.org/x/net/icmp"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgData struct {
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
ServerIP string `json:"serverIP"`
|
|
||||||
TunnelIP string `json:"tunnelIP"`
|
|
||||||
Targets TargetsByType `json:"targets"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TargetsByType struct {
|
|
||||||
UDP []string `json:"udp"`
|
|
||||||
TCP []string `json:"tcp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TargetData struct {
|
|
||||||
Targets []string `json:"targets"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func fixKey(key string) string {
|
|
||||||
// Remove any whitespace
|
|
||||||
key = strings.TrimSpace(key)
|
|
||||||
|
|
||||||
// Decode from base64
|
|
||||||
decoded, err := base64.StdEncoding.DecodeString(key)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Error decoding base64:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to hex
|
|
||||||
return hex.EncodeToString(decoded)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ping(tnet *netstack.Net, dst string) error {
|
|
||||||
logger.Info("Pinging %s", dst)
|
|
||||||
socket, err := tnet.Dial("ping4", dst)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create ICMP socket: %w", err)
|
|
||||||
}
|
|
||||||
defer socket.Close()
|
|
||||||
|
|
||||||
requestPing := icmp.Echo{
|
|
||||||
Seq: rand.Intn(1 << 16),
|
|
||||||
Data: []byte("gopher burrow"),
|
|
||||||
}
|
|
||||||
|
|
||||||
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal ICMP message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
|
||||||
return fmt.Errorf("failed to set read deadline: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
_, err = socket.Write(icmpBytes)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write ICMP packet: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := socket.Read(icmpBytes[:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read ICMP packet: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse ICMP packet: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
replyPing, ok := replyPacket.Body.(*icmp.Echo)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
|
|
||||||
return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
|
|
||||||
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Ping latency: %v", time.Since(start))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
|
||||||
const (
|
|
||||||
maxAttempts = 5
|
|
||||||
retryDelay = 2 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
var lastErr error
|
|
||||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
|
||||||
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
|
|
||||||
|
|
||||||
if err := ping(tnet, dst); err != nil {
|
|
||||||
lastErr = err
|
|
||||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
|
||||||
|
|
||||||
if attempt < maxAttempts {
|
|
||||||
time.Sleep(retryDelay)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return fmt.Errorf("all ping attempts failed after %d tries, last error: %w",
|
|
||||||
maxAttempts, lastErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Successful ping
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// This shouldn't be reached due to the return in the loop, but added for completeness
|
|
||||||
return fmt.Errorf("unexpected error: all ping attempts failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseLogLevel(level string) logger.LogLevel {
|
|
||||||
switch strings.ToUpper(level) {
|
|
||||||
case "DEBUG":
|
|
||||||
return logger.DEBUG
|
|
||||||
case "INFO":
|
|
||||||
return logger.INFO
|
|
||||||
case "WARN":
|
|
||||||
return logger.WARN
|
|
||||||
case "ERROR":
|
|
||||||
return logger.ERROR
|
|
||||||
case "FATAL":
|
|
||||||
return logger.FATAL
|
|
||||||
default:
|
|
||||||
return logger.INFO // default to INFO if invalid level provided
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
|
||||||
switch level {
|
|
||||||
case logger.DEBUG:
|
|
||||||
return device.LogLevelVerbose
|
|
||||||
// case logger.INFO:
|
|
||||||
// return device.LogLevel
|
|
||||||
case logger.WARN:
|
|
||||||
return device.LogLevelError
|
|
||||||
case logger.ERROR, logger.FATAL:
|
|
||||||
return device.LogLevelSilent
|
|
||||||
default:
|
|
||||||
return device.LogLevelSilent
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveDomain(domain string) (string, error) {
|
|
||||||
// Check if there's a port in the domain
|
|
||||||
host, port, err := net.SplitHostPort(domain)
|
|
||||||
if err != nil {
|
|
||||||
// No port found, use the domain as is
|
|
||||||
host = domain
|
|
||||||
port = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove any protocol prefix if present
|
|
||||||
if strings.HasPrefix(host, "http://") {
|
|
||||||
host = strings.TrimPrefix(host, "http://")
|
|
||||||
} else if strings.HasPrefix(host, "https://") {
|
|
||||||
host = strings.TrimPrefix(host, "https://")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lookup IP addresses
|
|
||||||
ips, err := net.LookupIP(host)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ips) == 0 {
|
|
||||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the first IPv4 address if available
|
|
||||||
var ipAddr string
|
|
||||||
for _, ip := range ips {
|
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
|
||||||
ipAddr = ipv4.String()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no IPv4 found, use the first IP (might be IPv6)
|
|
||||||
if ipAddr == "" {
|
|
||||||
ipAddr = ips[0].String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add port back if it existed
|
|
||||||
if port != "" {
|
|
||||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getEnvWithDefault(key, defaultValue string) string {
|
|
||||||
if value := os.Getenv(key); value != "" {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
// Check if we're running as a Windows service
|
||||||
endpoint string
|
if isWindowsService() {
|
||||||
id string
|
runService("OlmWireguardService", false, os.Args[1:])
|
||||||
secret string
|
fmt.Println("Running as Windows service")
|
||||||
dns string
|
return
|
||||||
privateKey wgtypes.Key
|
|
||||||
err error
|
|
||||||
logLevel string
|
|
||||||
)
|
|
||||||
|
|
||||||
// Define CLI flags with default values from environment variables
|
|
||||||
flag.StringVar(&endpoint, "endpoint", os.Getenv("PANGOLIN_ENDPOINT"), "Endpoint of your pangolin server")
|
|
||||||
flag.StringVar(&id, "id", os.Getenv("NEWT_ID"), "Newt ID")
|
|
||||||
flag.StringVar(&secret, "secret", os.Getenv("NEWT_SECRET"), "Newt secret")
|
|
||||||
flag.StringVar(&dns, "dns", getEnvWithDefault("DEFAULT_DNS", "8.8.8.8"), "DNS server to use")
|
|
||||||
flag.StringVar(&logLevel, "log-level", getEnvWithDefault("LOG_LEVEL", "INFO"), "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
logger.Init()
|
|
||||||
loggerLevel := parseLogLevel(logLevel)
|
|
||||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
|
||||||
|
|
||||||
// Validate required fields
|
|
||||||
if endpoint == "" || id == "" || secret == "" {
|
|
||||||
logger.Fatal("endpoint, id, and secret are required either via CLI flags or environment variables")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
// Handle service management commands on Windows
|
||||||
if err != nil {
|
if runtime.GOOS == "windows" {
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
var command string
|
||||||
}
|
if len(os.Args) > 1 {
|
||||||
|
command = os.Args[1]
|
||||||
// Create a new client
|
} else {
|
||||||
client, err := websocket.NewClient(
|
command = "default"
|
||||||
id, // CLI arg takes precedence
|
|
||||||
secret, // CLI arg takes precedence
|
|
||||||
endpoint,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to create client: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create TUN device and network stack
|
|
||||||
var tun tun.Device
|
|
||||||
var tnet *netstack.Net
|
|
||||||
var dev *device.Device
|
|
||||||
var pm *proxy.ProxyManager
|
|
||||||
var connected bool
|
|
||||||
var wgData WgData
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received terminate message")
|
|
||||||
if pm != nil {
|
|
||||||
pm.Stop()
|
|
||||||
}
|
}
|
||||||
if dev != nil {
|
|
||||||
dev.Close()
|
|
||||||
}
|
|
||||||
client.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Register handlers for different message types
|
switch command {
|
||||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
case "install":
|
||||||
logger.Info("Received registration message")
|
err := installService()
|
||||||
|
|
||||||
if connected {
|
|
||||||
logger.Info("Already connected! Put I will send a ping anyway...")
|
|
||||||
// ping(tnet, wgData.ServerIP)
|
|
||||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Handle complete failure after all retries
|
fmt.Printf("Failed to install service: %v\n", err)
|
||||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
fmt.Println("Service installed successfully")
|
||||||
return
|
return
|
||||||
}
|
case "remove", "uninstall":
|
||||||
|
err := removeService()
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
tun, tnet, err = netstack.CreateNetTUN(
|
|
||||||
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
|
||||||
[]netip.Addr{netip.MustParseAddr(dns)},
|
|
||||||
1420)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to create TUN device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create WireGuard device
|
|
||||||
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
|
|
||||||
mapToWireGuardLogLevel(loggerLevel),
|
|
||||||
"wireguard: ",
|
|
||||||
))
|
|
||||||
|
|
||||||
endpoint, err := resolveDomain(wgData.Endpoint)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve endpoint: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Configure WireGuard
|
|
||||||
config := fmt.Sprintf(`private_key=%s
|
|
||||||
public_key=%s
|
|
||||||
allowed_ip=%s/32
|
|
||||||
endpoint=%s
|
|
||||||
persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
|
||||||
|
|
||||||
err = dev.IpcSet(config)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to configure WireGuard device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bring up the device
|
|
||||||
err = dev.Up()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
|
||||||
// Ping to bring the tunnel up on the server side quickly
|
|
||||||
// ping(tnet, wgData.ServerIP)
|
|
||||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
|
||||||
if err != nil {
|
|
||||||
// Handle complete failure after all retries
|
|
||||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create proxy manager
|
|
||||||
pm = proxy.NewProxyManager(tnet)
|
|
||||||
|
|
||||||
connected = true
|
|
||||||
|
|
||||||
// add the targets if there are any
|
|
||||||
if len(wgData.Targets.TCP) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wgData.Targets.UDP) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
|
||||||
}
|
|
||||||
|
|
||||||
err = pm.Start()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to start proxy manager: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = pm.Start()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to start proxy manager: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = pm.Start()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to start proxy manager: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.OnConnect(func() error {
|
|
||||||
publicKey := privateKey.PublicKey()
|
|
||||||
logger.Debug("Public key: %s", publicKey)
|
|
||||||
|
|
||||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
|
||||||
"publicKey": fmt.Sprintf("%s", publicKey),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Sent registration message")
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// Connect to the WebSocket server
|
|
||||||
if err := client.Connect(); err != nil {
|
|
||||||
logger.Fatal("Failed to connect to server: %v", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
// Wait for interrupt signal
|
|
||||||
sigCh := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
<-sigCh
|
|
||||||
|
|
||||||
// Cleanup
|
|
||||||
dev.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTargetData(data interface{}) (TargetData, error) {
|
|
||||||
var targetData TargetData
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
return targetData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
|
||||||
for _, t := range targetData.Targets {
|
|
||||||
// Split the first number off of the target with : separator and use as the port
|
|
||||||
parts := strings.Split(t, ":")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
logger.Info("Invalid target format: %s", t)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the port as an int
|
|
||||||
port := 0
|
|
||||||
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Invalid port: %s", parts[0])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if action == "add" {
|
|
||||||
target := parts[1] + ":" + parts[2]
|
|
||||||
// Only remove the specific target if it exists
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Ignore "target not found" errors as this is expected for new targets
|
fmt.Printf("Failed to remove service: %v\n", err)
|
||||||
if !strings.Contains(err.Error(), "target not found") {
|
os.Exit(1)
|
||||||
logger.Error("Failed to remove existing target: %v", err)
|
}
|
||||||
|
fmt.Println("Service removed successfully")
|
||||||
|
return
|
||||||
|
case "start":
|
||||||
|
// Pass the remaining arguments (after "start") to the service
|
||||||
|
serviceArgs := os.Args[2:]
|
||||||
|
err := startService(serviceArgs)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to start service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Println("Service started successfully")
|
||||||
|
return
|
||||||
|
case "stop":
|
||||||
|
err := stopService()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to stop service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Println("Service stopped successfully")
|
||||||
|
return
|
||||||
|
case "status":
|
||||||
|
status, err := getServiceStatus()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to get service status: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("Service status: %s\n", status)
|
||||||
|
return
|
||||||
|
case "debug":
|
||||||
|
// get the status and if it is Not Installed then install it first
|
||||||
|
status, err := getServiceStatus()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to get service status: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if status == "Not Installed" {
|
||||||
|
err := installService()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to install service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
fmt.Println("Service installed successfully, now running in debug mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the new target
|
// Pass the remaining arguments (after "debug") to the service
|
||||||
pm.AddTarget(proto, tunnelIP, port, target)
|
serviceArgs := os.Args[2:]
|
||||||
|
err = debugService(serviceArgs)
|
||||||
} else if action == "remove" {
|
|
||||||
logger.Info("Removing target with port %d", port)
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to remove target: %v", err)
|
fmt.Printf("Failed to debug service: %v\n", err)
|
||||||
return err
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
case "logs":
|
||||||
|
err := watchLogFile(false)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to watch log file: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case "config":
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
showServiceConfig()
|
||||||
|
} else {
|
||||||
|
fmt.Println("Service configuration is only available on Windows")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case "help", "--help", "-h":
|
||||||
|
fmt.Println("Olm WireGuard VPN Client")
|
||||||
|
fmt.Println("\nWindows Service Management:")
|
||||||
|
fmt.Println(" install Install the service")
|
||||||
|
fmt.Println(" remove Remove the service")
|
||||||
|
fmt.Println(" start [args] Start the service with optional arguments")
|
||||||
|
fmt.Println(" stop Stop the service")
|
||||||
|
fmt.Println(" status Show service status")
|
||||||
|
fmt.Println(" debug [args] Run service in debug mode with optional arguments")
|
||||||
|
fmt.Println(" logs Tail the service log file")
|
||||||
|
fmt.Println(" config Show current service configuration")
|
||||||
|
fmt.Println("\nExamples:")
|
||||||
|
fmt.Println(" olm start --enable-http --http-addr :9452")
|
||||||
|
fmt.Println(" olm debug --endpoint https://example.com --id myid --secret mysecret")
|
||||||
|
fmt.Println("\nFor console mode, run without arguments or with standard flags.")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// get the status and if it is Not Installed then install it first
|
||||||
|
status, err := getServiceStatus()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to get service status: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if status == "Not Installed" {
|
||||||
|
err := installService()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to install service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Println("Service installed successfully, now running")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass the remaining arguments (after "debug") to the service
|
||||||
|
serviceArgs := os.Args[1:]
|
||||||
|
err = debugService(serviceArgs)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to debug service: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Setup Windows event logging if on Windows
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
setupWindowsEventLog()
|
||||||
|
} else {
|
||||||
|
// Initialize logger for non-Windows platforms
|
||||||
|
logger.Init()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load configuration from file, env vars, and CLI args
|
||||||
|
// Priority: CLI args > Env vars > Config file > Defaults
|
||||||
|
config, showVersion, showConfig, err := LoadConfig(os.Args[1:])
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to load configuration: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle --show-config flag
|
||||||
|
if showConfig {
|
||||||
|
config.ShowConfig()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
olmVersion := "version_replaceme"
|
||||||
|
if showVersion {
|
||||||
|
fmt.Println("Olm version " + olmVersion)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
logger.Info("Olm version " + olmVersion)
|
||||||
|
|
||||||
|
config.Version = olmVersion
|
||||||
|
|
||||||
|
if err := SaveConfig(config); err != nil {
|
||||||
|
logger.Error("Failed to save full olm config: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Saved full olm config with all options")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new olm.Config struct and copy values from the main config
|
||||||
|
olmConfig := olm.Config{
|
||||||
|
Endpoint: config.Endpoint,
|
||||||
|
ID: config.ID,
|
||||||
|
Secret: config.Secret,
|
||||||
|
UserToken: config.UserToken,
|
||||||
|
MTU: config.MTU,
|
||||||
|
DNS: config.DNS,
|
||||||
|
InterfaceName: config.InterfaceName,
|
||||||
|
LogLevel: config.LogLevel,
|
||||||
|
EnableAPI: config.EnableAPI,
|
||||||
|
HTTPAddr: config.HTTPAddr,
|
||||||
|
SocketPath: config.SocketPath,
|
||||||
|
Holepunch: config.Holepunch,
|
||||||
|
TlsClientCert: config.TlsClientCert,
|
||||||
|
PingIntervalDuration: config.PingIntervalDuration,
|
||||||
|
PingTimeoutDuration: config.PingTimeoutDuration,
|
||||||
|
Version: config.Version,
|
||||||
|
OrgID: config.OrgID,
|
||||||
|
// DoNotCreateNewClient: config.DoNotCreateNewClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a context that will be cancelled on interrupt signals
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
olm.Run(ctx, olmConfig)
|
||||||
}
|
}
|
||||||
|
|||||||
1
olm-binary.REMOVED.git-id
Normal file
1
olm-binary.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
|||||||
|
573df1772c00fcb34ec68e575e973c460dc27ba8
|
||||||
1
olm-test.REMOVED.git-id
Normal file
1
olm-test.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
|||||||
|
ba2c118fd96937229ef54dcd0b82fe5d53d94a87
|
||||||
88
olm.iss
Normal file
88
olm.iss
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
; Script generated by the Inno Setup Script Wizard.
|
||||||
|
; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES!
|
||||||
|
|
||||||
|
#define MyAppName "olm"
|
||||||
|
#define MyAppVersion "1.0.0"
|
||||||
|
#define MyAppPublisher "Fossorial Inc."
|
||||||
|
#define MyAppURL "https://pangolin.net"
|
||||||
|
#define MyAppExeName "olm.exe"
|
||||||
|
|
||||||
|
[Setup]
|
||||||
|
; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications.
|
||||||
|
; (To generate a new GUID, click Tools | Generate GUID inside the IDE.)
|
||||||
|
AppId={{44A24E4C-B616-476F-ADE7-8D56B930959E}
|
||||||
|
AppName={#MyAppName}
|
||||||
|
AppVersion={#MyAppVersion}
|
||||||
|
;AppVerName={#MyAppName} {#MyAppVersion}
|
||||||
|
AppPublisher={#MyAppPublisher}
|
||||||
|
AppPublisherURL={#MyAppURL}
|
||||||
|
AppSupportURL={#MyAppURL}
|
||||||
|
AppUpdatesURL={#MyAppURL}
|
||||||
|
DefaultDirName={autopf}\{#MyAppName}
|
||||||
|
UninstallDisplayIcon={app}\{#MyAppExeName}
|
||||||
|
; "ArchitecturesAllowed=x64compatible" specifies that Setup cannot run
|
||||||
|
; on anything but x64 and Windows 11 on Arm.
|
||||||
|
ArchitecturesAllowed=x64compatible
|
||||||
|
; "ArchitecturesInstallIn64BitMode=x64compatible" requests that the
|
||||||
|
; install be done in "64-bit mode" on x64 or Windows 11 on Arm,
|
||||||
|
; meaning it should use the native 64-bit Program Files directory and
|
||||||
|
; the 64-bit view of the registry.
|
||||||
|
ArchitecturesInstallIn64BitMode=x64compatible
|
||||||
|
DefaultGroupName={#MyAppName}
|
||||||
|
DisableProgramGroupPage=yes
|
||||||
|
; Uncomment the following line to run in non administrative install mode (install for current user only).
|
||||||
|
;PrivilegesRequired=lowest
|
||||||
|
OutputBaseFilename=mysetup
|
||||||
|
SolidCompression=yes
|
||||||
|
WizardStyle=modern
|
||||||
|
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
|
||||||
|
RestartIfNeededByRun=no
|
||||||
|
ChangesEnvironment=true
|
||||||
|
|
||||||
|
[Languages]
|
||||||
|
Name: "english"; MessagesFile: "compiler:Default.isl"
|
||||||
|
|
||||||
|
[Files]
|
||||||
|
; The 'DestName' flag ensures that 'olm_windows_amd64.exe' is installed as 'olm.exe'
|
||||||
|
Source: "C:\Users\Administrator\Downloads\olm_windows_amd64.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}"; Flags: ignoreversion
|
||||||
|
Source: "C:\Users\Administrator\Downloads\wintun.dll"; DestDir: "{app}"; Flags: ignoreversion
|
||||||
|
; NOTE: Don't use "Flags: ignoreversion" on any shared system files
|
||||||
|
|
||||||
|
[Icons]
|
||||||
|
Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"
|
||||||
|
|
||||||
|
[Registry]
|
||||||
|
; Add the application's installation directory to the system PATH environment variable.
|
||||||
|
; HKLM (HKEY_LOCAL_MACHINE) is used for system-wide changes.
|
||||||
|
; The 'Path' variable is located under 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'.
|
||||||
|
; ValueType: expandsz allows for environment variables (like %ProgramFiles%) in the path.
|
||||||
|
; ValueData: "{olddata};{app}" appends the current application directory to the existing PATH.
|
||||||
|
; Flags: uninsdeletevalue ensures the entry is removed upon uninstallation.
|
||||||
|
; Check: IsWin64 ensures this is applied on 64-bit systems, which matches ArchitecturesAllowed.
|
||||||
|
[Registry]
|
||||||
|
; Add the application's installation directory to the system PATH.
|
||||||
|
Root: HKLM; Subkey: "SYSTEM\CurrentControlSet\Control\Session Manager\Environment"; \
|
||||||
|
ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \
|
||||||
|
Flags: uninsdeletevalue; Check: NeedsAddPath(ExpandConstant('{app}'))
|
||||||
|
|
||||||
|
[Code]
|
||||||
|
function NeedsAddPath(Path: string): boolean;
|
||||||
|
var
|
||||||
|
OrigPath: string;
|
||||||
|
begin
|
||||||
|
if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
|
||||||
|
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||||
|
'Path', OrigPath)
|
||||||
|
then begin
|
||||||
|
// Path variable doesn't exist at all, so we definitely need to add it.
|
||||||
|
Result := True;
|
||||||
|
exit;
|
||||||
|
end;
|
||||||
|
|
||||||
|
// Perform a case-insensitive check to see if the path is already present.
|
||||||
|
// We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
|
||||||
|
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
|
||||||
|
Result := False
|
||||||
|
else
|
||||||
|
Result := True;
|
||||||
|
end;
|
||||||
885
olm/common.go
Normal file
885
olm/common.go
Normal file
@@ -0,0 +1,885 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/peermonitor"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WgData struct {
|
||||||
|
Sites []SiteConfig `json:"sites"`
|
||||||
|
TunnelIP string `json:"tunnelIP"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SiteConfig struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
ServerIP string `json:"serverIP"`
|
||||||
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
|
}
|
||||||
|
|
||||||
|
type TargetsByType struct {
|
||||||
|
UDP []string `json:"udp"`
|
||||||
|
TCP []string `json:"tcp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TargetData struct {
|
||||||
|
Targets []string `json:"targets"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HolePunchMessage struct {
|
||||||
|
NewtID string `json:"newtId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExitNode struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HolePunchData struct {
|
||||||
|
ExitNodes []ExitNode `json:"exitNodes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncryptedHolePunchMessage struct {
|
||||||
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||||
|
Nonce []byte `json:"nonce"`
|
||||||
|
Ciphertext []byte `json:"ciphertext"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
peerMonitor *peermonitor.PeerMonitor
|
||||||
|
stopHolepunch chan struct{}
|
||||||
|
stopRegister func()
|
||||||
|
stopPing chan struct{}
|
||||||
|
olmToken string
|
||||||
|
holePunchRunning bool
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ENV_WG_TUN_FD = "WG_TUN_FD"
|
||||||
|
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
||||||
|
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PeerAction represents a request to add, update, or remove a peer
|
||||||
|
type PeerAction struct {
|
||||||
|
Action string `json:"action"` // "add", "update", or "remove"
|
||||||
|
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerData represents the data needed to update a peer
|
||||||
|
type UpdatePeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
ServerIP string `json:"serverIP"`
|
||||||
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeerData represents the data needed to add a peer
|
||||||
|
type AddPeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
ServerIP string `json:"serverIP"`
|
||||||
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeerData represents the data needed to remove a peer
|
||||||
|
type RemovePeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayPeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to format endpoints correctly
|
||||||
|
func formatEndpoint(endpoint string) string {
|
||||||
|
if endpoint == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
|
||||||
|
_, _, err := net.SplitHostPort(endpoint)
|
||||||
|
if err == nil {
|
||||||
|
return endpoint // Already valid, no change needed
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
|
||||||
|
lastColon := strings.LastIndex(endpoint, ":")
|
||||||
|
if lastColon > 0 { // Ensure there is a colon and it's not the first character
|
||||||
|
hostPart := endpoint[:lastColon]
|
||||||
|
// Check if the host part is a literal IPv6 address
|
||||||
|
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
|
||||||
|
// It is! Reformat it with brackets.
|
||||||
|
portPart := endpoint[lastColon+1:]
|
||||||
|
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's not the specific malformed case, return it as is.
|
||||||
|
return endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func fixKey(key string) string {
|
||||||
|
// Remove any whitespace
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
|
||||||
|
// Decode from base64
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(key)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Error decoding base64")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to hex
|
||||||
|
return hex.EncodeToString(decoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLogLevel(level string) logger.LogLevel {
|
||||||
|
switch strings.ToUpper(level) {
|
||||||
|
case "DEBUG":
|
||||||
|
return logger.DEBUG
|
||||||
|
case "INFO":
|
||||||
|
return logger.INFO
|
||||||
|
case "WARN":
|
||||||
|
return logger.WARN
|
||||||
|
case "ERROR":
|
||||||
|
return logger.ERROR
|
||||||
|
case "FATAL":
|
||||||
|
return logger.FATAL
|
||||||
|
default:
|
||||||
|
return logger.INFO // default to INFO if invalid level provided
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
||||||
|
switch level {
|
||||||
|
case logger.DEBUG:
|
||||||
|
return device.LogLevelVerbose
|
||||||
|
// case logger.INFO:
|
||||||
|
// return device.LogLevel
|
||||||
|
case logger.WARN:
|
||||||
|
return device.LogLevelError
|
||||||
|
case logger.ERROR, logger.FATAL:
|
||||||
|
return device.LogLevelSilent
|
||||||
|
default:
|
||||||
|
return device.LogLevelSilent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveDomain(domain string) (string, error) {
|
||||||
|
// First handle any protocol prefix
|
||||||
|
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
|
||||||
|
|
||||||
|
// if there are any trailing slashes, remove them
|
||||||
|
domain = strings.TrimSuffix(domain, "/")
|
||||||
|
|
||||||
|
// Now split host and port
|
||||||
|
host, port, err := net.SplitHostPort(domain)
|
||||||
|
if err != nil {
|
||||||
|
// No port found, use the domain as is
|
||||||
|
host = domain
|
||||||
|
port = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup IP addresses
|
||||||
|
ips, err := net.LookupIP(host)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the first IPv4 address if available
|
||||||
|
var ipAddr string
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
ipAddr = ipv4.String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no IPv4 found, use the first IP (might be IPv6)
|
||||||
|
if ipAddr == "" {
|
||||||
|
ipAddr = ips[0].String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add port back if it existed
|
||||||
|
if port != "" {
|
||||||
|
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||||
|
if maxPort < minPort {
|
||||||
|
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a slice of all ports in the range
|
||||||
|
portRange := make([]uint16, maxPort-minPort+1)
|
||||||
|
for i := range portRange {
|
||||||
|
portRange[i] = minPort + uint16(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fisher-Yates shuffle to randomize the port order
|
||||||
|
rand.Seed(uint64(time.Now().UnixNano()))
|
||||||
|
for i := len(portRange) - 1; i > 0; i-- {
|
||||||
|
j := rand.Intn(i + 1)
|
||||||
|
portRange[i], portRange[j] = portRange[j], portRange[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try each port in the randomized order
|
||||||
|
for _, port := range portRange {
|
||||||
|
addr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: int(port),
|
||||||
|
}
|
||||||
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
continue // Port is in use or there was an error, try next port
|
||||||
|
}
|
||||||
|
_ = conn.SetDeadline(time.Now())
|
||||||
|
conn.Close()
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendPing(olm *websocket.Client) error {
|
||||||
|
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
||||||
|
"timestamp": time.Now().Unix(),
|
||||||
|
"userToken": olm.GetConfig().UserToken,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send ping message: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Debug("Sent ping message")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func keepSendingPing(olm *websocket.Client) {
|
||||||
|
// Send ping immediately on startup
|
||||||
|
if err := sendPing(olm); err != nil {
|
||||||
|
logger.Error("Failed to send initial ping: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Sent initial ping message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up ticker for one minute intervals
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopPing:
|
||||||
|
logger.Info("Stopping ping messages")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := sendPing(olm); err != nil {
|
||||||
|
logger.Error("Failed to send periodic ping: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||||
|
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
||||||
|
siteHost, err := ResolveDomain(siteConfig.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
|
||||||
|
allowedIp := strings.Split(siteConfig.ServerIP, "/")
|
||||||
|
if len(allowedIp) > 1 {
|
||||||
|
allowedIp[1] = "32"
|
||||||
|
} else {
|
||||||
|
allowedIp = append(allowedIp, "32")
|
||||||
|
}
|
||||||
|
allowedIpStr := strings.Join(allowedIp, "/")
|
||||||
|
|
||||||
|
// Collect all allowed IPs in a slice
|
||||||
|
var allowedIPs []string
|
||||||
|
allowedIPs = append(allowedIPs, allowedIpStr)
|
||||||
|
|
||||||
|
// If we have anything in remoteSubnets, add those as well
|
||||||
|
if siteConfig.RemoteSubnets != "" {
|
||||||
|
// Split remote subnets by comma and add each one
|
||||||
|
remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",")
|
||||||
|
for _, subnet := range remoteSubnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet != "" {
|
||||||
|
allowedIPs = append(allowedIPs, subnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct WireGuard config for this peer
|
||||||
|
var configBuilder strings.Builder
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String())))
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey)))
|
||||||
|
|
||||||
|
// Add each allowed IP separately
|
||||||
|
for _, allowedIP := range allowedIPs {
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||||
|
configBuilder.WriteString("persistent_keepalive_interval=1\n")
|
||||||
|
|
||||||
|
config := configBuilder.String()
|
||||||
|
logger.Debug("Configuring peer with config: %s", config)
|
||||||
|
|
||||||
|
err = dev.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up peer monitoring
|
||||||
|
if peerMonitor != nil {
|
||||||
|
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||||
|
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||||
|
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||||
|
|
||||||
|
primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgConfig := &peermonitor.WireGuardConfig{
|
||||||
|
SiteID: siteConfig.SiteId,
|
||||||
|
PublicKey: fixKey(siteConfig.PublicKey),
|
||||||
|
ServerIP: strings.Split(siteConfig.ServerIP, "/")[0],
|
||||||
|
Endpoint: siteConfig.Endpoint,
|
||||||
|
PrimaryRelay: primaryRelay,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer removes a peer from the WireGuard device
|
||||||
|
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
|
||||||
|
// Construct WireGuard config to remove the peer
|
||||||
|
var configBuilder strings.Builder
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(publicKey)))
|
||||||
|
configBuilder.WriteString("remove=true\n")
|
||||||
|
|
||||||
|
config := configBuilder.String()
|
||||||
|
logger.Debug("Removing peer with config: %s", config)
|
||||||
|
|
||||||
|
err := dev.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop monitoring this peer
|
||||||
|
if peerMonitor != nil {
|
||||||
|
peerMonitor.RemovePeer(siteId)
|
||||||
|
logger.Info("Stopped monitoring for site %d", siteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||||
|
func ConfigureInterface(interfaceName string, wgData WgData) error {
|
||||||
|
var ipAddr string = wgData.TunnelIP
|
||||||
|
|
||||||
|
// Parse the IP address and network
|
||||||
|
ip, ipNet, err := net.ParseCIDR(ipAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return configureLinux(interfaceName, ip, ipNet)
|
||||||
|
case "darwin":
|
||||||
|
return configureDarwin(interfaceName, ip, ipNet)
|
||||||
|
case "windows":
|
||||||
|
return configureWindows(interfaceName, ip, ipNet)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
logger.Info("Configuring Windows interface: %s", interfaceName)
|
||||||
|
|
||||||
|
// Calculate mask string (e.g., 255.255.255.0)
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
mask := net.CIDRMask(maskBits, 32)
|
||||||
|
maskIP := net.IP(mask)
|
||||||
|
|
||||||
|
// Set the IP address using netsh
|
||||||
|
cmd := exec.Command("netsh", "interface", "ipv4", "set", "address",
|
||||||
|
fmt.Sprintf("name=%s", interfaceName),
|
||||||
|
"source=static",
|
||||||
|
fmt.Sprintf("addr=%s", ip.String()),
|
||||||
|
fmt.Sprintf("mask=%s", maskIP.String()))
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("netsh command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface if needed (in Windows, setting the IP usually brings it up)
|
||||||
|
// But we'll explicitly enable it to be sure
|
||||||
|
cmd = exec.Command("netsh", "interface", "set", "interface",
|
||||||
|
interfaceName,
|
||||||
|
"admin=enable")
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err = cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delay 2 seconds
|
||||||
|
time.Sleep(8 * time.Second)
|
||||||
|
|
||||||
|
// Wait for the interface to be up and have the correct IP
|
||||||
|
err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("interface did not come up within timeout: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForInterfaceUp polls the network interface until it's up or times out
|
||||||
|
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
||||||
|
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
pollInterval := 500 * time.Millisecond
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
// Check if interface exists and is up
|
||||||
|
iface, err := net.InterfaceByName(interfaceName)
|
||||||
|
if err == nil {
|
||||||
|
// Check if interface is up
|
||||||
|
if iface.Flags&net.FlagUp != 0 {
|
||||||
|
// Check if it has the expected IP
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err == nil {
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ipNet, ok := addr.(*net.IPNet)
|
||||||
|
if ok && ipNet.IP.Equal(expectedIP) {
|
||||||
|
logger.Info("Interface %s is up with correct IP", interfaceName)
|
||||||
|
return nil // Interface is up with correct IP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("Interface %s exists but is not up yet", interfaceName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("Interface %s not found yet: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait before next check
|
||||||
|
time.Sleep(pollInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
// Parse destination to get the IP and subnet
|
||||||
|
ip, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the subnet mask
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
mask := net.CIDRMask(maskBits, 32)
|
||||||
|
maskIP := net.IP(mask)
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("route", "add",
|
||||||
|
ip.String(),
|
||||||
|
"mask", maskIP.String(),
|
||||||
|
gateway,
|
||||||
|
"metric", "1")
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// First, get the interface index
|
||||||
|
indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces")
|
||||||
|
output, err := indexCmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface index: %v, output: %s", err, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the output to find the interface index
|
||||||
|
lines := strings.Split(string(output), "\n")
|
||||||
|
var ifIndex string
|
||||||
|
for _, line := range lines {
|
||||||
|
if strings.Contains(line, interfaceName) {
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) > 0 {
|
||||||
|
ifIndex = fields[0]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ifIndex == "" {
|
||||||
|
return fmt.Errorf("could not find index for interface %s", interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to integer to validate
|
||||||
|
idx, err := strconv.Atoi(ifIndex)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid interface index: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route via interface using the index
|
||||||
|
cmd = exec.Command("route", "add",
|
||||||
|
ip.String(),
|
||||||
|
"mask", maskIP.String(),
|
||||||
|
"0.0.0.0",
|
||||||
|
"if", strconv.Itoa(idx))
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WindowsRemoveRoute(destination string) error {
|
||||||
|
// Parse destination to get the IP
|
||||||
|
ip, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the subnet mask
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
mask := net.CIDRMask(maskBits, 32)
|
||||||
|
maskIP := net.IP(mask)
|
||||||
|
|
||||||
|
cmd := exec.Command("route", "delete",
|
||||||
|
ip.String(),
|
||||||
|
"mask", maskIP.String())
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findUnusedUTUN() (string, error) {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
||||||
|
}
|
||||||
|
used := make(map[int]bool)
|
||||||
|
re := regexp.MustCompile(`^utun(\d+)$`)
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
|
||||||
|
if num, err := strconv.Atoi(matches[1]); err == nil {
|
||||||
|
used[num] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try utun0 up to utun255.
|
||||||
|
for i := 0; i < 256; i++ {
|
||||||
|
if !used[i] {
|
||||||
|
return fmt.Sprintf("utun%d", i), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no unused utun interface found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
logger.Info("Configuring darwin interface: %s", interfaceName)
|
||||||
|
|
||||||
|
prefix, _ := ipNet.Mask.Size()
|
||||||
|
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
||||||
|
|
||||||
|
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface
|
||||||
|
cmd = exec.Command("ifconfig", interfaceName, "up")
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err = cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
// Get the interface
|
||||||
|
link, err := netlink.LinkByName(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the IP address attributes
|
||||||
|
addr := &netlink.Addr{
|
||||||
|
IPNet: &net.IPNet{
|
||||||
|
IP: ip,
|
||||||
|
Mask: ipNet.Mask,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the IP address to the interface
|
||||||
|
if err := netlink.AddrAdd(link, addr); err != nil {
|
||||||
|
return fmt.Errorf("failed to add IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface
|
||||||
|
if err := netlink.LinkSetUp(link); err != nil {
|
||||||
|
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface
|
||||||
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DarwinRemoveRoute(destination string) error {
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("ip", "route", "add", destination, "via", gateway)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface
|
||||||
|
cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ip route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxRemoveRoute(destination string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("ip", "route", "del", destination)
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||||
|
func addRouteForServerIP(serverIP, interfaceName string) error {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinAddRoute(serverIP, "", interfaceName)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsAddRoute(serverIP, "", interfaceName)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxAddRoute(serverIP, "", interfaceName)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRouteForServerIP removes an OS-specific route for the server IP
|
||||||
|
func removeRouteForServerIP(serverIP string) error {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinRemoveRoute(serverIP)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsRemoveRoute(serverIP)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxRemoveRoute(serverIP)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets
|
||||||
|
func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error {
|
||||||
|
if remoteSubnets == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split remote subnets by comma and add routes for each one
|
||||||
|
subnets := strings.Split(remoteSubnets, ",")
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add route based on operating system
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Added route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets
|
||||||
|
func removeRoutesForRemoteSubnets(remoteSubnets string) error {
|
||||||
|
if remoteSubnets == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split remote subnets by comma and remove routes for each one
|
||||||
|
subnets := strings.Split(remoteSubnets, ",")
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove route based on operating system
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Removed route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
894
olm/olm.go
Normal file
894
olm/olm.go
Normal file
@@ -0,0 +1,894 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/updates"
|
||||||
|
"github.com/fosrl/olm/api"
|
||||||
|
"github.com/fosrl/olm/bind"
|
||||||
|
"github.com/fosrl/olm/holepunch"
|
||||||
|
"github.com/fosrl/olm/peermonitor"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
// Connection settings
|
||||||
|
Endpoint string
|
||||||
|
ID string
|
||||||
|
Secret string
|
||||||
|
UserToken string
|
||||||
|
|
||||||
|
// Network settings
|
||||||
|
MTU int
|
||||||
|
DNS string
|
||||||
|
InterfaceName string
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
LogLevel string
|
||||||
|
|
||||||
|
// HTTP server
|
||||||
|
EnableAPI bool
|
||||||
|
HTTPAddr string
|
||||||
|
SocketPath string
|
||||||
|
|
||||||
|
// Advanced
|
||||||
|
Holepunch bool
|
||||||
|
TlsClientCert string
|
||||||
|
|
||||||
|
// Parsed values (not in JSON)
|
||||||
|
PingIntervalDuration time.Duration
|
||||||
|
PingTimeoutDuration time.Duration
|
||||||
|
|
||||||
|
// Source tracking (not in JSON)
|
||||||
|
sources map[string]string
|
||||||
|
|
||||||
|
Version string
|
||||||
|
OrgID string
|
||||||
|
DoNotCreateNewClient bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
privateKey wgtypes.Key
|
||||||
|
connected bool
|
||||||
|
dev *device.Device
|
||||||
|
wgData WgData
|
||||||
|
holePunchData HolePunchData
|
||||||
|
uapiListener net.Listener
|
||||||
|
tdev tun.Device
|
||||||
|
apiServer *api.API
|
||||||
|
olmClient *websocket.Client
|
||||||
|
tunnelCancel context.CancelFunc
|
||||||
|
tunnelRunning bool
|
||||||
|
sharedBind *bind.SharedBind
|
||||||
|
holePunchManager *holepunch.Manager
|
||||||
|
)
|
||||||
|
|
||||||
|
func Run(ctx context.Context, config Config) {
|
||||||
|
// Create a cancellable context for internal shutdown control
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel))
|
||||||
|
|
||||||
|
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
|
||||||
|
logger.Debug("Failed to check for updates: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Holepunch {
|
||||||
|
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.HTTPAddr != "" {
|
||||||
|
apiServer = api.NewAPI(config.HTTPAddr)
|
||||||
|
} else if config.SocketPath != "" {
|
||||||
|
apiServer = api.NewAPISocket(config.SocketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
apiServer.SetVersion(config.Version)
|
||||||
|
apiServer.SetOrgID(config.OrgID)
|
||||||
|
|
||||||
|
if err := apiServer.Start(); err != nil {
|
||||||
|
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen for shutdown requests from the API
|
||||||
|
go func() {
|
||||||
|
<-apiServer.GetShutdownChannel()
|
||||||
|
logger.Info("Shutdown requested via API")
|
||||||
|
// Cancel the context to trigger graceful shutdown
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var (
|
||||||
|
id = config.ID
|
||||||
|
secret = config.Secret
|
||||||
|
endpoint = config.Endpoint
|
||||||
|
userToken = config.UserToken
|
||||||
|
)
|
||||||
|
|
||||||
|
// Main event loop that handles connect, disconnect, and reconnect
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Info("Context cancelled while waiting for credentials")
|
||||||
|
goto shutdown
|
||||||
|
|
||||||
|
case req := <-apiServer.GetConnectionChannel():
|
||||||
|
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
||||||
|
|
||||||
|
// Stop any existing tunnel before starting a new one
|
||||||
|
if olmClient != nil {
|
||||||
|
logger.Info("Stopping existing tunnel before starting new connection")
|
||||||
|
StopTunnel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the connection parameters
|
||||||
|
id = req.ID
|
||||||
|
secret = req.Secret
|
||||||
|
endpoint = req.Endpoint
|
||||||
|
userToken := req.UserToken
|
||||||
|
|
||||||
|
// Start the tunnel process with the new credentials
|
||||||
|
if id != "" && secret != "" && endpoint != "" {
|
||||||
|
logger.Info("Starting tunnel with new credentials")
|
||||||
|
tunnelRunning = true
|
||||||
|
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-apiServer.GetDisconnectChannel():
|
||||||
|
logger.Info("Received disconnect request via API")
|
||||||
|
StopTunnel()
|
||||||
|
// Clear credentials so we wait for new connect call
|
||||||
|
id = ""
|
||||||
|
secret = ""
|
||||||
|
endpoint = ""
|
||||||
|
userToken = ""
|
||||||
|
|
||||||
|
default:
|
||||||
|
// If we have credentials and no tunnel is running, start it
|
||||||
|
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
|
||||||
|
logger.Info("Starting tunnel process with initial credentials")
|
||||||
|
tunnelRunning = true
|
||||||
|
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
|
||||||
|
} else if id == "" || secret == "" || endpoint == "" {
|
||||||
|
// If we don't have credentials, check if API is enabled
|
||||||
|
if !config.EnableAPI {
|
||||||
|
missing := []string{}
|
||||||
|
if id == "" {
|
||||||
|
missing = append(missing, "id")
|
||||||
|
}
|
||||||
|
if secret == "" {
|
||||||
|
missing = append(missing, "secret")
|
||||||
|
}
|
||||||
|
if endpoint == "" {
|
||||||
|
missing = append(missing, "endpoint")
|
||||||
|
}
|
||||||
|
// exit the application because there is no way to provide the missing parameters
|
||||||
|
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
|
||||||
|
goto shutdown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sleep briefly to prevent tight loop
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdown:
|
||||||
|
Stop()
|
||||||
|
apiServer.Stop()
|
||||||
|
logger.Info("Olm service shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TunnelProcess(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
|
||||||
|
// Create a cancellable context for this tunnel process
|
||||||
|
tunnelCtx, cancel := context.WithCancel(ctx)
|
||||||
|
tunnelCancel = cancel
|
||||||
|
defer func() {
|
||||||
|
tunnelCancel = nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Recreate channels for this tunnel session
|
||||||
|
stopPing = make(chan struct{})
|
||||||
|
|
||||||
|
var (
|
||||||
|
interfaceName = config.InterfaceName
|
||||||
|
loggerLevel = parseLogLevel(config.LogLevel)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create a new olm client using the provided credentials
|
||||||
|
olm, err := websocket.NewClient(
|
||||||
|
id, // Use provided ID
|
||||||
|
secret, // Use provided secret
|
||||||
|
userToken, // Use provided user token OPTIONAL
|
||||||
|
endpoint, // Use provided endpoint
|
||||||
|
config.PingIntervalDuration,
|
||||||
|
config.PingTimeoutDuration,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create olm: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the client reference globally
|
||||||
|
olmClient = olm
|
||||||
|
|
||||||
|
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to generate private key: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create shared UDP socket for both holepunch and WireGuard
|
||||||
|
if sharedBind == nil {
|
||||||
|
sourcePort, err := FindAvailableUDPPort(49152, 65535)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error finding available port: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := &net.UDPAddr{
|
||||||
|
Port: int(sourcePort),
|
||||||
|
IP: net.IPv4zero,
|
||||||
|
}
|
||||||
|
|
||||||
|
udpConn, err := net.ListenUDP("udp", localAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create shared UDP socket: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedBind, err = bind.New(udpConn)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create shared bind: %v", err)
|
||||||
|
udpConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
|
||||||
|
sharedBind.AddRef()
|
||||||
|
|
||||||
|
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the holepunch manager
|
||||||
|
if holePunchManager == nil {
|
||||||
|
holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
|
||||||
|
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
|
||||||
|
for i, node := range holePunchData.ExitNodes {
|
||||||
|
exitNodes[i] = holepunch.ExitNode{
|
||||||
|
Endpoint: node.Endpoint,
|
||||||
|
PublicKey: node.PublicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start hole punching using the manager
|
||||||
|
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
||||||
|
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
|
||||||
|
logger.Warn("Failed to start hole punch: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
||||||
|
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
type LegacyHolePunchData struct {
|
||||||
|
ServerPubKey string `json:"serverPubKey"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var legacyHolePunchData LegacyHolePunchData
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop any existing hole punch operations
|
||||||
|
if holePunchManager != nil {
|
||||||
|
holePunchManager.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start hole punching for the exit node
|
||||||
|
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
||||||
|
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
|
||||||
|
logger.Warn("Failed to start hole punch: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
if connected {
|
||||||
|
logger.Info("Already connected. Ignoring new connection request.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopRegister != nil {
|
||||||
|
stopRegister()
|
||||||
|
stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// close(stopHolepunch)
|
||||||
|
|
||||||
|
// wait 10 milliseconds to ensure the previous connection is closed
|
||||||
|
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
// if there is an existing tunnel then close it
|
||||||
|
if dev != nil {
|
||||||
|
logger.Info("Got new message. Closing existing tunnel!")
|
||||||
|
dev.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tdev, err = func() (tun.Device, error) {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
interfaceName, err := findUnusedUTUN()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tun.CreateTUN(interfaceName, config.MTU)
|
||||||
|
}
|
||||||
|
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
|
||||||
|
return createTUNFromFD(tunFdStr, config.MTU)
|
||||||
|
}
|
||||||
|
return tun.CreateTUN(interfaceName, config.MTU)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create TUN device: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||||
|
interfaceName = realInterfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
|
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
|
||||||
|
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
|
}
|
||||||
|
return uapiOpen(interfaceName)
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("UAPI listen error: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
|
|
||||||
|
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := uapiListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go dev.IpcHandle(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
logger.Info("UAPI listener started")
|
||||||
|
|
||||||
|
if err = dev.Up(); err != nil {
|
||||||
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||||
|
logger.Error("Failed to configure interface: %v", err)
|
||||||
|
}
|
||||||
|
apiServer.SetTunnelIP(wgData.TunnelIP)
|
||||||
|
|
||||||
|
peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
|
// Find the site config to get endpoint information
|
||||||
|
var endpoint string
|
||||||
|
var isRelay bool
|
||||||
|
for _, site := range wgData.Sites {
|
||||||
|
if site.SiteId == siteID {
|
||||||
|
endpoint = site.Endpoint
|
||||||
|
// TODO: We'll need to track relay status separately
|
||||||
|
// For now, assume not using relay unless we get relay data
|
||||||
|
isRelay = !config.Holepunch
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
|
||||||
|
if connected {
|
||||||
|
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Peer %d is disconnected", siteID)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
fixKey(privateKey.String()),
|
||||||
|
olm,
|
||||||
|
dev,
|
||||||
|
config.Holepunch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := range wgData.Sites {
|
||||||
|
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
|
||||||
|
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
|
||||||
|
|
||||||
|
// Format the endpoint before configuring the peer.
|
||||||
|
site.Endpoint = formatEndpoint(site.Endpoint)
|
||||||
|
|
||||||
|
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to configure peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add route for peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerMonitor.Start()
|
||||||
|
|
||||||
|
apiServer.SetRegistered(true)
|
||||||
|
|
||||||
|
connected = true
|
||||||
|
|
||||||
|
logger.Info("WireGuard device created.")
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received update-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var updateData UpdatePeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &updateData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling update data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to SiteConfig
|
||||||
|
siteConfig := SiteConfig{
|
||||||
|
SiteId: updateData.SiteId,
|
||||||
|
Endpoint: updateData.Endpoint,
|
||||||
|
PublicKey: updateData.PublicKey,
|
||||||
|
ServerIP: updateData.ServerIP,
|
||||||
|
ServerPort: updateData.ServerPort,
|
||||||
|
RemoteSubnets: updateData.RemoteSubnets,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the peer in WireGuard
|
||||||
|
if dev != nil {
|
||||||
|
// Find the existing peer to get old data
|
||||||
|
var oldRemoteSubnets string
|
||||||
|
var oldPublicKey string
|
||||||
|
for _, site := range wgData.Sites {
|
||||||
|
if site.SiteId == updateData.SiteId {
|
||||||
|
oldRemoteSubnets = site.RemoteSubnets
|
||||||
|
oldPublicKey = site.PublicKey
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the public key has changed, remove the old peer first
|
||||||
|
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
||||||
|
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
||||||
|
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||||
|
logger.Error("Failed to remove old peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the endpoint before updating the peer.
|
||||||
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to update peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old remote subnet routes if they changed
|
||||||
|
if oldRemoteSubnets != siteConfig.RemoteSubnets {
|
||||||
|
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
|
||||||
|
logger.Error("Failed to remove old remote subnet routes: %v", err)
|
||||||
|
// Continue anyway to add new routes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new remote subnet routes
|
||||||
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add new remote subnet routes: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update successful
|
||||||
|
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||||
|
for i := range wgData.Sites {
|
||||||
|
if wgData.Sites[i].SiteId == updateData.SiteId {
|
||||||
|
wgData.Sites[i] = siteConfig
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Error("WireGuard device not initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handler for adding a new peer
|
||||||
|
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var addData AddPeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &addData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling add data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to SiteConfig
|
||||||
|
siteConfig := SiteConfig{
|
||||||
|
SiteId: addData.SiteId,
|
||||||
|
Endpoint: addData.Endpoint,
|
||||||
|
PublicKey: addData.PublicKey,
|
||||||
|
ServerIP: addData.ServerIP,
|
||||||
|
ServerPort: addData.ServerPort,
|
||||||
|
RemoteSubnets: addData.RemoteSubnets,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the peer to WireGuard
|
||||||
|
if dev != nil {
|
||||||
|
// Format the endpoint before adding the new peer.
|
||||||
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to add peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add route for new peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add successful
|
||||||
|
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
||||||
|
|
||||||
|
// Update WgData with the new peer
|
||||||
|
wgData.Sites = append(wgData.Sites, siteConfig)
|
||||||
|
} else {
|
||||||
|
logger.Error("WireGuard device not initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handler for removing a peer
|
||||||
|
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received remove-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var removeData RemovePeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &removeData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling remove data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the peer to remove
|
||||||
|
var peerToRemove *SiteConfig
|
||||||
|
var newSites []SiteConfig
|
||||||
|
|
||||||
|
for _, site := range wgData.Sites {
|
||||||
|
if site.SiteId == removeData.SiteId {
|
||||||
|
peerToRemove = &site
|
||||||
|
} else {
|
||||||
|
newSites = append(newSites, site)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerToRemove == nil {
|
||||||
|
logger.Error("Peer with site ID %d not found", removeData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the peer from WireGuard
|
||||||
|
if dev != nil {
|
||||||
|
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||||
|
logger.Error("Failed to remove peer: %v", err)
|
||||||
|
// Send error response if needed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove route for the peer
|
||||||
|
err = removeRouteForServerIP(peerToRemove.ServerIP)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to remove route for peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove routes for remote subnets
|
||||||
|
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
|
||||||
|
logger.Error("Failed to remove routes for remote subnets: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove successful
|
||||||
|
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
|
||||||
|
|
||||||
|
// Update WgData to remove the peer
|
||||||
|
wgData.Sites = newSites
|
||||||
|
} else {
|
||||||
|
logger.Error("WireGuard device not initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received relay-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var relayData RelayPeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling relay data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryRelay, err := ResolveDomain(relayData.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update HTTP server to mark this peer as using relay
|
||||||
|
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
||||||
|
|
||||||
|
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received no-sites message - no sites available for connection")
|
||||||
|
|
||||||
|
if stopRegister != nil {
|
||||||
|
stopRegister()
|
||||||
|
stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("No sites available - stopped registration and holepunch processes")
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received terminate message")
|
||||||
|
olm.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.OnConnect(func() error {
|
||||||
|
logger.Info("Websocket Connected")
|
||||||
|
|
||||||
|
apiServer.SetConnectionStatus(true)
|
||||||
|
|
||||||
|
if connected {
|
||||||
|
logger.Debug("Already connected, skipping registration")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := privateKey.PublicKey()
|
||||||
|
|
||||||
|
if stopRegister == nil {
|
||||||
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
||||||
|
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": publicKey.String(),
|
||||||
|
"relay": !config.Holepunch,
|
||||||
|
"olmVersion": config.Version,
|
||||||
|
"orgId": config.OrgID,
|
||||||
|
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||||
|
}, 1*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
go keepSendingPing(olm)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
olm.OnTokenUpdate(func(token string) {
|
||||||
|
if holePunchManager != nil {
|
||||||
|
holePunchManager.SetToken(token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Connect to the WebSocket server
|
||||||
|
if err := olm.Connect(); err != nil {
|
||||||
|
logger.Error("Failed to connect to server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer olm.Close()
|
||||||
|
|
||||||
|
// Listen for org switch requests from the API
|
||||||
|
go func() {
|
||||||
|
for req := range apiServer.GetSwitchOrgChannel() {
|
||||||
|
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
||||||
|
|
||||||
|
// Update the config with the new orgId
|
||||||
|
config.OrgID = req.OrgID
|
||||||
|
|
||||||
|
// Mark as not connected to trigger re-registration
|
||||||
|
connected = false
|
||||||
|
|
||||||
|
Stop()
|
||||||
|
|
||||||
|
// Clear peer statuses in API
|
||||||
|
apiServer.SetRegistered(false)
|
||||||
|
apiServer.SetTunnelIP("")
|
||||||
|
apiServer.SetOrgID(config.OrgID)
|
||||||
|
|
||||||
|
// Trigger re-registration with new orgId
|
||||||
|
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
||||||
|
publicKey := privateKey.PublicKey()
|
||||||
|
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": publicKey.String(),
|
||||||
|
"relay": !config.Holepunch,
|
||||||
|
"olmVersion": config.Version,
|
||||||
|
"orgId": config.OrgID,
|
||||||
|
}, 1*time.Second)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for context cancellation
|
||||||
|
<-tunnelCtx.Done()
|
||||||
|
logger.Info("Tunnel process context cancelled, cleaning up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Stop() {
|
||||||
|
// Stop hole punch manager
|
||||||
|
if holePunchManager != nil {
|
||||||
|
holePunchManager.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopPing != nil {
|
||||||
|
select {
|
||||||
|
case <-stopPing:
|
||||||
|
// Channel already closed
|
||||||
|
default:
|
||||||
|
close(stopPing)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopRegister != nil {
|
||||||
|
stopRegister()
|
||||||
|
stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerMonitor != nil {
|
||||||
|
peerMonitor.Stop()
|
||||||
|
peerMonitor = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if uapiListener != nil {
|
||||||
|
uapiListener.Close()
|
||||||
|
uapiListener = nil
|
||||||
|
}
|
||||||
|
if dev != nil {
|
||||||
|
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
|
||||||
|
dev = nil
|
||||||
|
}
|
||||||
|
// Close TUN device
|
||||||
|
if tdev != nil {
|
||||||
|
tdev.Close()
|
||||||
|
tdev = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release the hole punch reference to the shared bind
|
||||||
|
if sharedBind != nil {
|
||||||
|
// Release hole punch reference (WireGuard already released its reference via dev.Close())
|
||||||
|
logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount())
|
||||||
|
sharedBind.Release()
|
||||||
|
sharedBind = nil
|
||||||
|
logger.Info("Released shared UDP bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Olm service stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopTunnel stops just the tunnel process and websocket connection
|
||||||
|
// without shutting down the entire application
|
||||||
|
func StopTunnel() {
|
||||||
|
logger.Info("Stopping tunnel process")
|
||||||
|
|
||||||
|
// Cancel the tunnel context if it exists
|
||||||
|
if tunnelCancel != nil {
|
||||||
|
tunnelCancel()
|
||||||
|
// Give it a moment to clean up
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the websocket connection
|
||||||
|
if olmClient != nil {
|
||||||
|
olmClient.Close()
|
||||||
|
olmClient = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
Stop()
|
||||||
|
|
||||||
|
// Reset the connected state
|
||||||
|
connected = false
|
||||||
|
tunnelRunning = false
|
||||||
|
|
||||||
|
// Update API server status
|
||||||
|
apiServer.SetConnectionStatus(false)
|
||||||
|
apiServer.SetRegistered(false)
|
||||||
|
apiServer.SetTunnelIP("")
|
||||||
|
|
||||||
|
logger.Info("Tunnel process stopped")
|
||||||
|
}
|
||||||
35
olm/unix.go
Normal file
35
olm/unix.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||||
|
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = unix.SetNonblock(int(fd), true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(fd), "")
|
||||||
|
return tun.CreateTUNFromFile(file, mtuInt)
|
||||||
|
}
|
||||||
|
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||||
|
return ipc.UAPIOpen(interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||||
|
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||||
|
}
|
||||||
25
olm/windows.go
Normal file
25
olm/windows.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||||
|
return nil, errors.New("CreateTUNFromFile not supported on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||||
|
// On Windows, UAPIListen only takes one parameter
|
||||||
|
return ipc.UAPIListen(interfaceName)
|
||||||
|
}
|
||||||
331
peermonitor/peermonitor.go
Normal file
331
peermonitor/peermonitor.go
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
package peermonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
"github.com/fosrl/olm/wgtester"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PeerMonitorCallback is the function type for connection status change callbacks
|
||||||
|
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
|
||||||
|
|
||||||
|
// WireGuardConfig holds the WireGuard configuration for a peer
|
||||||
|
type WireGuardConfig struct {
|
||||||
|
SiteID int
|
||||||
|
PublicKey string
|
||||||
|
ServerIP string
|
||||||
|
Endpoint string
|
||||||
|
PrimaryRelay string // The primary relay endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
|
||||||
|
type PeerMonitor struct {
|
||||||
|
monitors map[int]*wgtester.Client
|
||||||
|
configs map[int]*WireGuardConfig
|
||||||
|
callback PeerMonitorCallback
|
||||||
|
mutex sync.Mutex
|
||||||
|
running bool
|
||||||
|
interval time.Duration
|
||||||
|
timeout time.Duration
|
||||||
|
maxAttempts int
|
||||||
|
privateKey string
|
||||||
|
wsClient *websocket.Client
|
||||||
|
device *device.Device
|
||||||
|
handleRelaySwitch bool // Whether to handle relay switching
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||||
|
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
|
||||||
|
return &PeerMonitor{
|
||||||
|
monitors: make(map[int]*wgtester.Client),
|
||||||
|
configs: make(map[int]*WireGuardConfig),
|
||||||
|
callback: callback,
|
||||||
|
interval: 1 * time.Second, // Default check interval
|
||||||
|
timeout: 2500 * time.Millisecond,
|
||||||
|
maxAttempts: 8,
|
||||||
|
privateKey: privateKey,
|
||||||
|
wsClient: wsClient,
|
||||||
|
device: device,
|
||||||
|
handleRelaySwitch: handleRelaySwitch,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetInterval changes how frequently peers are checked
|
||||||
|
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
pm.interval = interval
|
||||||
|
|
||||||
|
// Update interval for all existing monitors
|
||||||
|
for _, client := range pm.monitors {
|
||||||
|
client.SetPacketInterval(interval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTimeout changes the timeout for waiting for responses
|
||||||
|
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
pm.timeout = timeout
|
||||||
|
|
||||||
|
// Update timeout for all existing monitors
|
||||||
|
for _, client := range pm.monitors {
|
||||||
|
client.SetTimeout(timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||||
|
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
pm.maxAttempts = attempts
|
||||||
|
|
||||||
|
// Update max attempts for all existing monitors
|
||||||
|
for _, client := range pm.monitors {
|
||||||
|
client.SetMaxAttempts(attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeer adds a new peer to monitor
|
||||||
|
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
// Check if we're already monitoring this peer
|
||||||
|
if _, exists := pm.monitors[siteID]; exists {
|
||||||
|
// Update the endpoint instead of creating a new monitor
|
||||||
|
pm.removePeerUnlocked(siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := wgtester.NewClient(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure the client with our settings
|
||||||
|
client.SetPacketInterval(pm.interval)
|
||||||
|
client.SetTimeout(pm.timeout)
|
||||||
|
client.SetMaxAttempts(pm.maxAttempts)
|
||||||
|
|
||||||
|
// Store the client and config
|
||||||
|
pm.monitors[siteID] = client
|
||||||
|
pm.configs[siteID] = wgConfig
|
||||||
|
|
||||||
|
// If monitor is already running, start monitoring this peer
|
||||||
|
if pm.running {
|
||||||
|
siteIDCopy := siteID // Create a copy for the closure
|
||||||
|
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||||
|
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
||||||
|
// This function assumes the mutex is already held by the caller
|
||||||
|
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
|
||||||
|
client, exists := pm.monitors[siteID]
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client.StopMonitor()
|
||||||
|
client.Close()
|
||||||
|
delete(pm.monitors, siteID)
|
||||||
|
delete(pm.configs, siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer stops monitoring a peer and removes it from the monitor
|
||||||
|
func (pm *PeerMonitor) RemovePeer(siteID int) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
pm.removePeerUnlocked(siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins monitoring all peers
|
||||||
|
func (pm *PeerMonitor) Start() {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
if pm.running {
|
||||||
|
return // Already running
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.running = true
|
||||||
|
|
||||||
|
// Start monitoring all peers
|
||||||
|
for siteID, client := range pm.monitors {
|
||||||
|
siteIDCopy := siteID // Create a copy for the closure
|
||||||
|
err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||||
|
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Info("Started monitoring peer %d\n", siteID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConnectionStatusChange is called when a peer's connection status changes
|
||||||
|
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) {
|
||||||
|
// Call the user-provided callback first
|
||||||
|
if pm.callback != nil {
|
||||||
|
pm.callback(siteID, status.Connected, status.RTT)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If disconnected, handle failover
|
||||||
|
if !status.Connected {
|
||||||
|
// Send relay message to the server
|
||||||
|
if pm.wsClient != nil {
|
||||||
|
pm.sendRelay(siteID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFailover handles failover to the relay server when a peer is disconnected
|
||||||
|
func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
config, exists := pm.configs[siteID]
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for IPv6 and format the endpoint correctly
|
||||||
|
formattedEndpoint := relayEndpoint
|
||||||
|
if strings.Contains(relayEndpoint, ":") {
|
||||||
|
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure WireGuard to use the relay
|
||||||
|
wgConfig := fmt.Sprintf(`private_key=%s
|
||||||
|
public_key=%s
|
||||||
|
allowed_ip=%s/32
|
||||||
|
endpoint=%s:21820
|
||||||
|
persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, formattedEndpoint)
|
||||||
|
|
||||||
|
err := pm.device.IpcSet(wgConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to configure WireGuard device: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Adjusted peer %d to point to relay!\n", siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendRelay sends a relay message to the server
|
||||||
|
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||||
|
if !pm.handleRelaySwitch {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if pm.wsClient == nil {
|
||||||
|
return fmt.Errorf("websocket client is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
||||||
|
"siteId": siteID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Info("Sent relay message")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops monitoring all peers
|
||||||
|
func (pm *PeerMonitor) Stop() {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
if !pm.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.running = false
|
||||||
|
|
||||||
|
// Stop all monitors
|
||||||
|
for _, client := range pm.monitors {
|
||||||
|
client.StopMonitor()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops monitoring and cleans up resources
|
||||||
|
func (pm *PeerMonitor) Close() {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
// Stop and close all clients
|
||||||
|
for siteID, client := range pm.monitors {
|
||||||
|
client.StopMonitor()
|
||||||
|
client.Close()
|
||||||
|
delete(pm.monitors, siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.running = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeer tests connectivity to a specific peer
|
||||||
|
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
client, exists := pm.monitors[siteID]
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
connected, rtt := client.TestConnection(ctx)
|
||||||
|
return connected, rtt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAllPeers tests connectivity to all peers
|
||||||
|
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||||
|
Connected bool
|
||||||
|
RTT time.Duration
|
||||||
|
} {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
peers := make(map[int]*wgtester.Client, len(pm.monitors))
|
||||||
|
for siteID, client := range pm.monitors {
|
||||||
|
peers[siteID] = client
|
||||||
|
}
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
results := make(map[int]struct {
|
||||||
|
Connected bool
|
||||||
|
RTT time.Duration
|
||||||
|
})
|
||||||
|
for siteID, client := range peers {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||||
|
connected, rtt := client.TestConnection(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
results[siteID] = struct {
|
||||||
|
Connected bool
|
||||||
|
RTT time.Duration
|
||||||
|
}{
|
||||||
|
Connected: connected,
|
||||||
|
RTT: rtt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
334
proxy/manager.go
334
proxy/manager.go
@@ -1,334 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
|
||||||
return &ProxyManager{
|
|
||||||
tnet: tnet,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target string) {
|
|
||||||
pm.Lock()
|
|
||||||
defer pm.Unlock()
|
|
||||||
|
|
||||||
logger.Info("Adding target: %s://%s:%d -> %s", protocol, listen, port, target)
|
|
||||||
|
|
||||||
newTarget := ProxyTarget{
|
|
||||||
Protocol: protocol,
|
|
||||||
Listen: listen,
|
|
||||||
Port: port,
|
|
||||||
Target: target,
|
|
||||||
cancel: make(chan struct{}),
|
|
||||||
done: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.targets = append(pm.targets, newTarget)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
|
|
||||||
pm.Lock()
|
|
||||||
defer pm.Unlock()
|
|
||||||
|
|
||||||
protocol = strings.ToLower(protocol)
|
|
||||||
if protocol != "tcp" && protocol != "udp" {
|
|
||||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, target := range pm.targets {
|
|
||||||
if target.Listen == listen &&
|
|
||||||
target.Port == port &&
|
|
||||||
strings.ToLower(target.Protocol) == protocol {
|
|
||||||
|
|
||||||
// Signal the serving goroutine to stop
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
// Channel is already closed, no need to close it again
|
|
||||||
default:
|
|
||||||
close(target.cancel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the appropriate listener/connection based on protocol
|
|
||||||
target.Lock()
|
|
||||||
switch protocol {
|
|
||||||
case "tcp":
|
|
||||||
if target.listener != nil {
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
// Listener was already closed by Stop()
|
|
||||||
default:
|
|
||||||
target.listener.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "udp":
|
|
||||||
if target.udpConn != nil {
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
// Connection was already closed by Stop()
|
|
||||||
default:
|
|
||||||
target.udpConn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
target.Unlock()
|
|
||||||
|
|
||||||
// Wait for the target to fully stop
|
|
||||||
<-target.done
|
|
||||||
|
|
||||||
// Remove the target from the slice
|
|
||||||
pm.targets = append(pm.targets[:i], pm.targets[i+1:]...)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) Start() error {
|
|
||||||
pm.RLock()
|
|
||||||
defer pm.RUnlock()
|
|
||||||
|
|
||||||
for i := range pm.targets {
|
|
||||||
target := &pm.targets[i]
|
|
||||||
|
|
||||||
target.Lock()
|
|
||||||
// If target is already running, skip it
|
|
||||||
if target.listener != nil || target.udpConn != nil {
|
|
||||||
target.Unlock()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark the target as starting by creating a nil listener/connection
|
|
||||||
// This prevents other goroutines from trying to start it
|
|
||||||
if strings.ToLower(target.Protocol) == "tcp" {
|
|
||||||
target.listener = nil
|
|
||||||
} else {
|
|
||||||
target.udpConn = nil
|
|
||||||
}
|
|
||||||
target.Unlock()
|
|
||||||
|
|
||||||
switch strings.ToLower(target.Protocol) {
|
|
||||||
case "tcp":
|
|
||||||
go pm.serveTCP(target)
|
|
||||||
case "udp":
|
|
||||||
go pm.serveUDP(target)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported protocol: %s", target.Protocol)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) Stop() error {
|
|
||||||
pm.Lock()
|
|
||||||
defer pm.Unlock()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := range pm.targets {
|
|
||||||
target := &pm.targets[i]
|
|
||||||
wg.Add(1)
|
|
||||||
go func(t *ProxyTarget) {
|
|
||||||
defer wg.Done()
|
|
||||||
close(t.cancel)
|
|
||||||
t.Lock()
|
|
||||||
if t.listener != nil {
|
|
||||||
t.listener.Close()
|
|
||||||
}
|
|
||||||
if t.udpConn != nil {
|
|
||||||
t.udpConn.Close()
|
|
||||||
}
|
|
||||||
t.Unlock()
|
|
||||||
// Wait for the target to fully stop
|
|
||||||
<-t.done
|
|
||||||
}(target)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
|
|
||||||
defer close(target.done) // Signal that this target is fully stopped
|
|
||||||
|
|
||||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{
|
|
||||||
IP: net.ParseIP(target.Listen),
|
|
||||||
Port: target.Port,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
target.Lock()
|
|
||||||
target.listener = listener
|
|
||||||
target.Unlock()
|
|
||||||
|
|
||||||
defer listener.Close()
|
|
||||||
logger.Info("TCP proxy listening on %s", listener.Addr())
|
|
||||||
|
|
||||||
var activeConns sync.WaitGroup
|
|
||||||
acceptDone := make(chan struct{})
|
|
||||||
|
|
||||||
// Goroutine to handle shutdown signal
|
|
||||||
go func() {
|
|
||||||
<-target.cancel
|
|
||||||
close(acceptDone)
|
|
||||||
listener.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
// Wait for active connections to finish
|
|
||||||
activeConns.Wait()
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
logger.Info("Failed to accept TCP connection: %v", err)
|
|
||||||
// Don't return here, try to accept new connections
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
activeConns.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer activeConns.Done()
|
|
||||||
pm.handleTCPConnection(conn, target.Target, acceptDone)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, done chan struct{}) {
|
|
||||||
defer clientConn.Close()
|
|
||||||
|
|
||||||
serverConn, err := net.Dial("tcp", target)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to connect to target %s: %v", target, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer serverConn.Close()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
// Client -> Server
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
io.Copy(serverConn, clientConn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Server -> Client
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
io.Copy(clientConn, serverConn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
|
|
||||||
defer close(target.done) // Signal that this target is fully stopped
|
|
||||||
|
|
||||||
addr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP(target.Listen),
|
|
||||||
Port: target.Port,
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := pm.tnet.ListenUDP(addr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to start UDP listener for %s:%d: %v", target.Listen, target.Port, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
target.Lock()
|
|
||||||
target.udpConn = conn
|
|
||||||
target.Unlock()
|
|
||||||
|
|
||||||
defer conn.Close()
|
|
||||||
logger.Info("UDP proxy listening on %s", conn.LocalAddr())
|
|
||||||
|
|
||||||
buffer := make([]byte, 65535)
|
|
||||||
var activeConns sync.WaitGroup
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
activeConns.Wait() // Wait for all active UDP handlers to complete
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
|
||||||
if err != nil {
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
activeConns.Wait()
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
logger.Info("Failed to read UDP packet: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
targetAddr, err := net.ResolveUDPAddr("udp", target.Target)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to resolve target address %s: %v", target.Target, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
activeConns.Add(1)
|
|
||||||
go func(data []byte, remote net.Addr) {
|
|
||||||
defer activeConns.Done()
|
|
||||||
targetConn, err := net.DialUDP("udp", nil, targetAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to connect to target %s: %v", target.Target, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer targetConn.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, err = targetConn.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to write to target: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response := make([]byte, 65535)
|
|
||||||
n, err := targetConn.Read(response)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to read response from target: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.WriteTo(response[:n], remote)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to write response to client: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(buffer[:n], remoteAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ProxyTarget struct {
|
|
||||||
Protocol string
|
|
||||||
Listen string
|
|
||||||
Port int
|
|
||||||
Target string
|
|
||||||
cancel chan struct{} // Channel to signal shutdown
|
|
||||||
done chan struct{} // Channel to signal completion
|
|
||||||
listener net.Listener // For TCP
|
|
||||||
udpConn net.PacketConn // For UDP
|
|
||||||
sync.Mutex // Protect access to connection
|
|
||||||
}
|
|
||||||
|
|
||||||
type ProxyManager struct {
|
|
||||||
targets []ProxyTarget
|
|
||||||
tnet *netstack.Net
|
|
||||||
log *log.Logger
|
|
||||||
sync.RWMutex // Protect access to targets slice
|
|
||||||
}
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 774 KiB |
54
service_unix.go
Normal file
54
service_unix.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Service management functions are not available on non-Windows platforms
|
||||||
|
func installService() error {
|
||||||
|
return fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeService() error {
|
||||||
|
return fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func startService(args []string) error {
|
||||||
|
_ = args // unused on Unix platforms
|
||||||
|
return fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopService() error {
|
||||||
|
return fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getServiceStatus() (string, error) {
|
||||||
|
return "", fmt.Errorf("service management is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func debugService(args []string) error {
|
||||||
|
_ = args // unused on Unix platforms
|
||||||
|
return fmt.Errorf("debug service is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWindowsService() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func runService(name string, isDebug bool, args []string) {
|
||||||
|
// No-op on non-Windows platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupWindowsEventLog() {
|
||||||
|
// No-op on non-Windows platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func watchLogFile(end bool) error {
|
||||||
|
return fmt.Errorf("watching log file is only available on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func showServiceConfig() {
|
||||||
|
fmt.Println("Service configuration is only available on Windows")
|
||||||
|
}
|
||||||
631
service_windows.go
Normal file
631
service_windows.go
Normal file
@@ -0,0 +1,631 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"golang.org/x/sys/windows/svc"
|
||||||
|
"golang.org/x/sys/windows/svc/debug"
|
||||||
|
"golang.org/x/sys/windows/svc/eventlog"
|
||||||
|
"golang.org/x/sys/windows/svc/mgr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
serviceName = "OlmWireguardService"
|
||||||
|
serviceDisplayName = "Olm WireGuard VPN Service"
|
||||||
|
serviceDescription = "Olm WireGuard VPN client service for secure network connectivity"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Global variable to store service arguments
|
||||||
|
var serviceArgs []string
|
||||||
|
|
||||||
|
// getServiceArgsPath returns the path where service arguments are stored
|
||||||
|
func getServiceArgsPath() string {
|
||||||
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm")
|
||||||
|
return filepath.Join(logDir, "service_args.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveServiceArgs saves the service arguments to a file
|
||||||
|
func saveServiceArgs(args []string) error {
|
||||||
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm")
|
||||||
|
err := os.MkdirAll(logDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create config directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
argsPath := getServiceArgsPath()
|
||||||
|
data, err := json.Marshal(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal service args: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.WriteFile(argsPath, data, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write service args: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadServiceArgs loads the service arguments from a file
|
||||||
|
func loadServiceArgs() ([]string, error) {
|
||||||
|
argsPath := getServiceArgsPath()
|
||||||
|
data, err := os.ReadFile(argsPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return []string{}, nil // Return empty args if file doesn't exist
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to read service args: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var args []string
|
||||||
|
err = json.Unmarshal(data, &args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal service args: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type olmService struct {
|
||||||
|
elog debug.Log
|
||||||
|
ctx context.Context
|
||||||
|
stop context.CancelFunc
|
||||||
|
args []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
|
||||||
|
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
||||||
|
changes <- svc.Status{State: svc.StartPending}
|
||||||
|
|
||||||
|
s.elog.Info(1, fmt.Sprintf("Service Execute called with args: %v", args))
|
||||||
|
|
||||||
|
// Load saved service arguments
|
||||||
|
savedArgs, err := loadServiceArgs()
|
||||||
|
if err != nil {
|
||||||
|
s.elog.Error(1, fmt.Sprintf("Failed to load service args: %v", err))
|
||||||
|
// Continue with empty args if loading fails
|
||||||
|
savedArgs = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine service start args with saved args, giving priority to service start args
|
||||||
|
finalArgs := []string{}
|
||||||
|
if len(args) > 0 {
|
||||||
|
// Skip the first arg which is typically the service name
|
||||||
|
if len(args) > 1 {
|
||||||
|
finalArgs = append(finalArgs, args[1:]...)
|
||||||
|
}
|
||||||
|
s.elog.Info(1, fmt.Sprintf("Using service start parameters: %v", finalArgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no service start parameters, use saved args
|
||||||
|
if len(finalArgs) == 0 && len(savedArgs) > 0 {
|
||||||
|
finalArgs = savedArgs
|
||||||
|
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
s.args = finalArgs
|
||||||
|
|
||||||
|
// Start the main olm functionality
|
||||||
|
olmDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.runOlm()
|
||||||
|
close(olmDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
|
||||||
|
s.elog.Info(1, "Service status set to Running")
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case c := <-r:
|
||||||
|
switch c.Cmd {
|
||||||
|
case svc.Interrogate:
|
||||||
|
changes <- c.CurrentStatus
|
||||||
|
case svc.Stop, svc.Shutdown:
|
||||||
|
s.elog.Info(1, "Service stopping")
|
||||||
|
changes <- svc.Status{State: svc.StopPending}
|
||||||
|
if s.stop != nil {
|
||||||
|
s.stop()
|
||||||
|
}
|
||||||
|
// Wait for main logic to finish or timeout
|
||||||
|
select {
|
||||||
|
case <-olmDone:
|
||||||
|
s.elog.Info(1, "Main logic finished gracefully")
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
s.elog.Info(1, "Timeout waiting for main logic to finish")
|
||||||
|
}
|
||||||
|
return false, 0
|
||||||
|
default:
|
||||||
|
s.elog.Error(1, fmt.Sprintf("Unexpected control request #%d", c))
|
||||||
|
}
|
||||||
|
case <-olmDone:
|
||||||
|
s.elog.Info(1, "Main olm logic completed, stopping service")
|
||||||
|
changes <- svc.Status{State: svc.StopPending}
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *olmService) runOlm() {
|
||||||
|
// Create a context that can be cancelled when the service stops
|
||||||
|
s.ctx, s.stop = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// Setup logging for service mode
|
||||||
|
s.elog.Info(1, "Starting Olm main logic")
|
||||||
|
|
||||||
|
// Run the main olm logic and wait for it to complete
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
s.elog.Error(1, fmt.Sprintf("Olm panic: %v", r))
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Call the main olm function with stored arguments
|
||||||
|
runOlmMainWithArgs(s.ctx, s.args)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for either context cancellation or main logic completion
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
s.elog.Info(1, "Olm service context cancelled")
|
||||||
|
case <-done:
|
||||||
|
s.elog.Info(1, "Olm main logic completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runService(name string, isDebug bool, args []string) {
|
||||||
|
var err error
|
||||||
|
var elog debug.Log
|
||||||
|
|
||||||
|
if isDebug {
|
||||||
|
elog = debug.New(name)
|
||||||
|
fmt.Printf("Starting %s service in debug mode\n", name)
|
||||||
|
} else {
|
||||||
|
elog, err = eventlog.Open(name)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to open event log: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer elog.Close()
|
||||||
|
|
||||||
|
elog.Info(1, fmt.Sprintf("Starting %s service", name))
|
||||||
|
run := svc.Run
|
||||||
|
if isDebug {
|
||||||
|
run = debug.Run
|
||||||
|
}
|
||||||
|
|
||||||
|
service := &olmService{elog: elog, args: args}
|
||||||
|
err = run(name, service)
|
||||||
|
if err != nil {
|
||||||
|
elog.Error(1, fmt.Sprintf("%s service failed: %v", name, err))
|
||||||
|
if isDebug {
|
||||||
|
fmt.Printf("Service failed: %v\n", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
elog.Info(1, fmt.Sprintf("%s service stopped", name))
|
||||||
|
if isDebug {
|
||||||
|
fmt.Printf("%s service stopped\n", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func installService() error {
|
||||||
|
exepath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get executable path: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err == nil {
|
||||||
|
s.Close()
|
||||||
|
return fmt.Errorf("service %s already exists", serviceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := mgr.Config{
|
||||||
|
ServiceType: 0x10, // SERVICE_WIN32_OWN_PROCESS
|
||||||
|
StartType: mgr.StartManual,
|
||||||
|
ErrorControl: mgr.ErrorNormal,
|
||||||
|
DisplayName: serviceDisplayName,
|
||||||
|
Description: serviceDescription,
|
||||||
|
BinaryPathName: exepath,
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err = m.CreateService(serviceName, exepath, config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
err = eventlog.InstallAsEventCreate(serviceName, eventlog.Error|eventlog.Warning|eventlog.Info)
|
||||||
|
if err != nil {
|
||||||
|
s.Delete()
|
||||||
|
return fmt.Errorf("failed to install event log: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeService() error {
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("service %s is not installed", serviceName)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
// Stop the service if it's running
|
||||||
|
status, err := s.Query()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query service status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.State != svc.Stopped {
|
||||||
|
_, err = s.Control(svc.Stop)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to stop service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for service to stop
|
||||||
|
timeout := time.Now().Add(30 * time.Second)
|
||||||
|
for status.State != svc.Stopped {
|
||||||
|
if timeout.Before(time.Now()) {
|
||||||
|
return fmt.Errorf("timeout waiting for service to stop")
|
||||||
|
}
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
status, err = s.Query()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query service status: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.Delete()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = eventlog.Remove(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove event log: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startService(args []string) error {
|
||||||
|
// Save the service arguments as backup
|
||||||
|
if len(args) > 0 {
|
||||||
|
err := saveServiceArgs(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to save service args: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("service %s is not installed", serviceName)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
// Pass arguments directly to the service start call
|
||||||
|
err = s.Start(args...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopService() error {
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("service %s is not installed", serviceName)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
status, err := s.Control(svc.Stop)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to stop service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.Now().Add(30 * time.Second)
|
||||||
|
for status.State != svc.Stopped {
|
||||||
|
if timeout.Before(time.Now()) {
|
||||||
|
return fmt.Errorf("timeout waiting for service to stop")
|
||||||
|
}
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
status, err = s.Query()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query service status: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func debugService(args []string) error {
|
||||||
|
// Save the service arguments before starting
|
||||||
|
if len(args) > 0 {
|
||||||
|
err := saveServiceArgs(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to save service args: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the service with the provided arguments
|
||||||
|
err := startService(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watch the log file
|
||||||
|
return watchLogFile(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func watchLogFile(end bool) error {
|
||||||
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
||||||
|
logPath := filepath.Join(logDir, "olm.log")
|
||||||
|
|
||||||
|
// Ensure the log directory exists
|
||||||
|
err := os.MkdirAll(logDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create log directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the log file to be created if it doesn't exist
|
||||||
|
var file *os.File
|
||||||
|
for i := 0; i < 30; i++ { // Wait up to 15 seconds
|
||||||
|
file, err = os.Open(logPath)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
fmt.Printf("Waiting for log file to be created...\n")
|
||||||
|
}
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open log file after waiting: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Seek to the end of the file to only show new logs
|
||||||
|
_, err = file.Seek(0, 2)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to seek to end of file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up signal handling for graceful exit
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Create a ticker to check for new content
|
||||||
|
ticker := time.NewTicker(500 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
buffer := make([]byte, 4096)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-sigCh:
|
||||||
|
fmt.Printf("\n\nStopping log watch...\n")
|
||||||
|
// stop the service if needed
|
||||||
|
if end {
|
||||||
|
if err := stopService(); err != nil {
|
||||||
|
fmt.Printf("Failed to stop service: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf("Log watch stopped.\n")
|
||||||
|
return nil
|
||||||
|
case <-ticker.C:
|
||||||
|
// Read new content
|
||||||
|
n, err := file.Read(buffer)
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
// Try to reopen the file in case it was recreated
|
||||||
|
file.Close()
|
||||||
|
file, err = os.Open(logPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error reopening log file: %v", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > 0 {
|
||||||
|
// Print the new content
|
||||||
|
fmt.Print(string(buffer[:n]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getServiceStatus() (string, error) {
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to connect to service manager: %v", err)
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return "Not Installed", nil
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
status, err := s.Query()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to query service status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch status.State {
|
||||||
|
case svc.Stopped:
|
||||||
|
return "Stopped", nil
|
||||||
|
case svc.StartPending:
|
||||||
|
return "Starting", nil
|
||||||
|
case svc.StopPending:
|
||||||
|
return "Stopping", nil
|
||||||
|
case svc.Running:
|
||||||
|
return "Running", nil
|
||||||
|
case svc.ContinuePending:
|
||||||
|
return "Continue Pending", nil
|
||||||
|
case svc.PausePending:
|
||||||
|
return "Pause Pending", nil
|
||||||
|
case svc.Paused:
|
||||||
|
return "Paused", nil
|
||||||
|
default:
|
||||||
|
return "Unknown", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// showServiceConfig displays current saved service configuration
|
||||||
|
func showServiceConfig() {
|
||||||
|
configPath := getServiceArgsPath()
|
||||||
|
fmt.Printf("Service configuration file: %s\n", configPath)
|
||||||
|
|
||||||
|
args, err := loadServiceArgs()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("No saved configuration found or error loading: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) == 0 {
|
||||||
|
fmt.Println("No saved service arguments found")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Saved service arguments: %v\n", args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWindowsService() bool {
|
||||||
|
isWindowsService, err := svc.IsWindowsService()
|
||||||
|
return err == nil && isWindowsService
|
||||||
|
}
|
||||||
|
|
||||||
|
// rotateLogFile handles daily log rotation
|
||||||
|
func rotateLogFile(logDir string, logFile string) error {
|
||||||
|
// Get current log file info
|
||||||
|
info, err := os.Stat(logFile)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil // No current log file to rotate
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to stat log file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if log file is from today
|
||||||
|
now := time.Now()
|
||||||
|
fileTime := info.ModTime()
|
||||||
|
|
||||||
|
// If the log file is from today, no rotation needed
|
||||||
|
if now.Year() == fileTime.Year() && now.YearDay() == fileTime.YearDay() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create rotated filename with date
|
||||||
|
rotatedName := fmt.Sprintf("olm-%s.log", fileTime.Format("2006-01-02"))
|
||||||
|
rotatedPath := filepath.Join(logDir, rotatedName)
|
||||||
|
|
||||||
|
// Rename current log file to dated filename
|
||||||
|
err = os.Rename(logFile, rotatedPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to rotate log file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up old log files (keep last 30 days)
|
||||||
|
cleanupOldLogFiles(logDir, 30)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupOldLogFiles removes log files older than specified days
|
||||||
|
func cleanupOldLogFiles(logDir string, daysToKeep int) {
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -daysToKeep)
|
||||||
|
|
||||||
|
files, err := os.ReadDir(logDir)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
if !file.IsDir() && strings.HasPrefix(file.Name(), "olm-") && strings.HasSuffix(file.Name(), ".log") {
|
||||||
|
filePath := filepath.Join(logDir, file.Name())
|
||||||
|
info, err := file.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.ModTime().Before(cutoff) {
|
||||||
|
os.Remove(filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupWindowsEventLog() {
|
||||||
|
// Create log directory if it doesn't exist
|
||||||
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
||||||
|
err := os.MkdirAll(logDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to create log directory: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logFile := filepath.Join(logDir, "olm.log")
|
||||||
|
|
||||||
|
// Rotate log file if needed
|
||||||
|
err = rotateLogFile(logDir, logFile)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to rotate log file: %v\n", err)
|
||||||
|
// Continue anyway to create new log file
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to open log file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the custom logger output
|
||||||
|
logger.GetLogger().SetOutput(file)
|
||||||
|
|
||||||
|
log.Printf("Olm service logging initialized - log file: %s", logFile)
|
||||||
|
}
|
||||||
@@ -2,38 +2,81 @@ package websocket
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"software.sslmate.com/src/go-pkcs12"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type TokenResponse struct {
|
||||||
conn *websocket.Conn
|
Data struct {
|
||||||
config *Config
|
Token string `json:"token"`
|
||||||
baseURL string
|
} `json:"data"`
|
||||||
handlers map[string]MessageHandler
|
Success bool `json:"success"`
|
||||||
done chan struct{}
|
Message string `json:"message"`
|
||||||
handlersMux sync.RWMutex
|
}
|
||||||
|
|
||||||
|
type WSMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is not json anymore
|
||||||
|
type Config struct {
|
||||||
|
ID string
|
||||||
|
Secret string
|
||||||
|
Endpoint string
|
||||||
|
TlsClientCert string // legacy PKCS12 file path
|
||||||
|
UserToken string // optional user token for websocket authentication
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
config *Config
|
||||||
|
conn *websocket.Conn
|
||||||
|
baseURL string
|
||||||
|
handlers map[string]MessageHandler
|
||||||
|
done chan struct{}
|
||||||
|
handlersMux sync.RWMutex
|
||||||
reconnectInterval time.Duration
|
reconnectInterval time.Duration
|
||||||
isConnected bool
|
isConnected bool
|
||||||
reconnectMux sync.RWMutex
|
reconnectMux sync.RWMutex
|
||||||
|
pingInterval time.Duration
|
||||||
onConnect func() error
|
pingTimeout time.Duration
|
||||||
|
onConnect func() error
|
||||||
|
onTokenUpdate func(token string)
|
||||||
|
writeMux sync.Mutex
|
||||||
|
clientType string // Type of client (e.g., "newt", "olm")
|
||||||
|
tlsConfig TLSConfig
|
||||||
|
configNeedsSave bool // Flag to track if config needs to be saved
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOption func(*Client)
|
type ClientOption func(*Client)
|
||||||
|
|
||||||
type MessageHandler func(message WSMessage)
|
type MessageHandler func(message WSMessage)
|
||||||
|
|
||||||
|
// TLSConfig holds TLS configuration options
|
||||||
|
type TLSConfig struct {
|
||||||
|
// New separate certificate support
|
||||||
|
ClientCertFile string
|
||||||
|
ClientKeyFile string
|
||||||
|
CAFiles []string
|
||||||
|
|
||||||
|
// Existing PKCS12 support (deprecated)
|
||||||
|
PKCS12File string
|
||||||
|
}
|
||||||
|
|
||||||
// WithBaseURL sets the base URL for the client
|
// WithBaseURL sets the base URL for the client
|
||||||
func WithBaseURL(url string) ClientOption {
|
func WithBaseURL(url string) ClientOption {
|
||||||
return func(c *Client) {
|
return func(c *Client) {
|
||||||
@@ -41,16 +84,32 @@ func WithBaseURL(url string) ClientOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithTLSConfig sets the TLS configuration for the client
|
||||||
|
func WithTLSConfig(config TLSConfig) ClientOption {
|
||||||
|
return func(c *Client) {
|
||||||
|
c.tlsConfig = config
|
||||||
|
// For backward compatibility, also set the legacy field
|
||||||
|
if config.PKCS12File != "" {
|
||||||
|
c.config.TlsClientCert = config.PKCS12File
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) OnConnect(callback func() error) {
|
func (c *Client) OnConnect(callback func() error) {
|
||||||
c.onConnect = callback
|
c.onConnect = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new Newt client
|
func (c *Client) OnTokenUpdate(callback func(token string)) {
|
||||||
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
|
c.onTokenUpdate = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new websocket client
|
||||||
|
func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
NewtID: newtID,
|
ID: ID,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
|
UserToken: userToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &Client{
|
client := &Client{
|
||||||
@@ -58,39 +117,59 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
|
|||||||
baseURL: endpoint, // default value
|
baseURL: endpoint, // default value
|
||||||
handlers: make(map[string]MessageHandler),
|
handlers: make(map[string]MessageHandler),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
reconnectInterval: 10 * time.Second,
|
reconnectInterval: 3 * time.Second,
|
||||||
isConnected: false,
|
isConnected: false,
|
||||||
|
pingInterval: pingInterval,
|
||||||
|
pingTimeout: pingTimeout,
|
||||||
|
clientType: "olm",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options before loading config
|
// Apply options before loading config
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
opt(client)
|
opt(client)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load existing config if available
|
|
||||||
if err := client.loadConfig(); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) GetConfig() *Config {
|
||||||
|
return c.config
|
||||||
|
}
|
||||||
|
|
||||||
// Connect establishes the WebSocket connection
|
// Connect establishes the WebSocket connection
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
go c.connectWithRetry()
|
go c.connectWithRetry()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the WebSocket connection
|
// Close closes the WebSocket connection gracefully
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
close(c.done)
|
// Signal shutdown to all goroutines first
|
||||||
if c.conn != nil {
|
select {
|
||||||
return c.conn.Close()
|
case <-c.done:
|
||||||
|
// Already closed
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
close(c.done)
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop the ping monitor
|
// Set connection status to false
|
||||||
c.setConnected(false)
|
c.setConnected(false)
|
||||||
|
|
||||||
|
// Close the WebSocket connection gracefully
|
||||||
|
if c.conn != nil {
|
||||||
|
// Send close message
|
||||||
|
c.writeMux.Lock()
|
||||||
|
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||||
|
c.writeMux.Unlock()
|
||||||
|
|
||||||
|
// Close the connection
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,9 +184,49 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
|||||||
Data: data,
|
Data: data,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Debug("Sending message: %s, data: %+v", messageType, data)
|
||||||
|
|
||||||
|
c.writeMux.Lock()
|
||||||
|
defer c.writeMux.Unlock()
|
||||||
return c.conn.WriteJSON(msg)
|
return c.conn.WriteJSON(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
|
||||||
|
stopChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
count := 0
|
||||||
|
maxAttempts := 10
|
||||||
|
|
||||||
|
err := c.SendMessage(messageType, data) // Send immediately
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send initial message: %v", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
if count >= maxAttempts {
|
||||||
|
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = c.SendMessage(messageType, data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send message: %v", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
case <-stopChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
close(stopChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterHandler registers a handler for a specific message type
|
// RegisterHandler registers a handler for a specific message type
|
||||||
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
||||||
c.handlersMux.Lock()
|
c.handlersMux.Lock()
|
||||||
@@ -115,30 +234,6 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
|||||||
c.handlers[messageType] = handler
|
c.handlers[messageType] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// readPump pumps messages from the WebSocket connection
|
|
||||||
func (c *Client) readPump() {
|
|
||||||
defer c.conn.Close()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
var msg WSMessage
|
|
||||||
err := c.conn.ReadJSON(&msg)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.handlersMux.RLock()
|
|
||||||
if handler, ok := c.handlers[msg.Type]; ok {
|
|
||||||
handler(msg)
|
|
||||||
}
|
|
||||||
c.handlersMux.RUnlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) getToken() (string, error) {
|
func (c *Client) getToken() (string, error) {
|
||||||
// Parse the base URL to ensure we have the correct hostname
|
// Parse the base URL to ensure we have the correct hostname
|
||||||
baseURL, err := url.Parse(c.baseURL)
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
@@ -149,57 +244,33 @@ func (c *Client) getToken() (string, error) {
|
|||||||
// Ensure we have the base URL without trailing slashes
|
// Ensure we have the base URL without trailing slashes
|
||||||
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
||||||
|
|
||||||
// If we already have a token, try to use it
|
var tlsConfig *tls.Config = nil
|
||||||
if c.config.Token != "" {
|
|
||||||
tokenCheckData := map[string]interface{}{
|
// Use new TLS configuration method
|
||||||
"newtId": c.config.NewtID,
|
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||||
"secret": c.config.Secret,
|
tlsConfig, err = c.setupTLS()
|
||||||
"token": c.config.Token,
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(tokenCheckData)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to marshal token check data: %w", err)
|
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new request
|
|
||||||
req, err := http.NewRequest(
|
|
||||||
"POST",
|
|
||||||
baseEndpoint+"/api/v1/auth/newt/get-token",
|
|
||||||
bytes.NewBuffer(jsonData),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to create request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set headers
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
|
||||||
|
|
||||||
// Make the request
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to check token validity: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var tokenResp TokenResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
|
||||||
return "", fmt.Errorf("failed to decode token check response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If token is still valid, return it
|
|
||||||
if tokenResp.Success && tokenResp.Message == "Token session already valid" {
|
|
||||||
return c.config.Token, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a new token
|
// Check for environment variable to skip TLS verification
|
||||||
tokenData := map[string]interface{}{
|
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||||
"newtId": c.config.NewtID,
|
if tlsConfig == nil {
|
||||||
|
tlsConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenData map[string]interface{}
|
||||||
|
|
||||||
|
tokenData = map[string]interface{}{
|
||||||
|
"olmId": c.config.ID,
|
||||||
"secret": c.config.Secret,
|
"secret": c.config.Secret,
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(tokenData)
|
jsonData, err := json.Marshal(tokenData)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to marshal token request data: %w", err)
|
return "", fmt.Errorf("failed to marshal token request data: %w", err)
|
||||||
}
|
}
|
||||||
@@ -207,7 +278,7 @@ func (c *Client) getToken() (string, error) {
|
|||||||
// Create a new request
|
// Create a new request
|
||||||
req, err := http.NewRequest(
|
req, err := http.NewRequest(
|
||||||
"POST",
|
"POST",
|
||||||
baseEndpoint+"/api/v1/auth/newt/get-token",
|
baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
|
||||||
bytes.NewBuffer(jsonData),
|
bytes.NewBuffer(jsonData),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -220,14 +291,26 @@ func (c *Client) getToken() (string, error) {
|
|||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
if tlsConfig != nil {
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to request new token: %w", err)
|
return "", fmt.Errorf("failed to request new token: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
|
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
var tokenResp TokenResponse
|
var tokenResp TokenResponse
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||||
|
logger.Error("Failed to decode token response.")
|
||||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,6 +322,8 @@ func (c *Client) getToken() (string, error) {
|
|||||||
return "", fmt.Errorf("received empty token from server")
|
return "", fmt.Errorf("received empty token from server")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
||||||
|
|
||||||
return tokenResp.Data.Token, nil
|
return tokenResp.Data.Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,6 +351,10 @@ func (c *Client) establishConnection() error {
|
|||||||
return fmt.Errorf("failed to get token: %w", err)
|
return fmt.Errorf("failed to get token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.onTokenUpdate != nil {
|
||||||
|
c.onTokenUpdate(token)
|
||||||
|
}
|
||||||
|
|
||||||
// Parse the base URL to determine protocol and hostname
|
// Parse the base URL to determine protocol and hostname
|
||||||
baseURL, err := url.Parse(c.baseURL)
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -288,10 +377,35 @@ func (c *Client) establishConnection() error {
|
|||||||
// Add token to query parameters
|
// Add token to query parameters
|
||||||
q := u.Query()
|
q := u.Query()
|
||||||
q.Set("token", token)
|
q.Set("token", token)
|
||||||
|
q.Set("clientType", c.clientType)
|
||||||
|
if c.config.UserToken != "" {
|
||||||
|
q.Set("userToken", c.config.UserToken)
|
||||||
|
}
|
||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
// Connect to WebSocket
|
// Connect to WebSocket
|
||||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
dialer := websocket.DefaultDialer
|
||||||
|
|
||||||
|
// Use new TLS configuration method
|
||||||
|
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||||
|
logger.Info("Setting up TLS configuration for WebSocket connection")
|
||||||
|
tlsConfig, err := c.setupTLS()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||||
|
}
|
||||||
|
dialer.TLSClientConfig = tlsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for environment variable to skip TLS verification for WebSocket connection
|
||||||
|
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||||
|
if dialer.TLSClientConfig == nil {
|
||||||
|
dialer.TLSClientConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
dialer.TLSClientConfig.InsecureSkipVerify = true
|
||||||
|
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := dialer.Dial(u.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||||
}
|
}
|
||||||
@@ -301,8 +415,8 @@ func (c *Client) establishConnection() error {
|
|||||||
|
|
||||||
// Start the ping monitor
|
// Start the ping monitor
|
||||||
go c.pingMonitor()
|
go c.pingMonitor()
|
||||||
// Start the read pump
|
// Start the read pump with disconnect detection
|
||||||
go c.readPump()
|
go c.readPumpWithDisconnectDetection()
|
||||||
|
|
||||||
if c.onConnect != nil {
|
if c.onConnect != nil {
|
||||||
if err := c.onConnect(); err != nil {
|
if err := c.onConnect(); err != nil {
|
||||||
@@ -313,8 +427,72 @@ func (c *Client) establishConnection() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupTLS configures TLS based on the TLS configuration
|
||||||
|
func (c *Client) setupTLS() (*tls.Config, error) {
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
// Handle new separate certificate configuration
|
||||||
|
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
|
||||||
|
logger.Info("Loading separate certificate files for mTLS")
|
||||||
|
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
|
||||||
|
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
|
||||||
|
|
||||||
|
// Load client certificate and key
|
||||||
|
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load client certificate pair: %w", err)
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
|
||||||
|
// Load CA certificates for remote validation if specified
|
||||||
|
if len(c.tlsConfig.CAFiles) > 0 {
|
||||||
|
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
for _, caFile := range c.tlsConfig.CAFiles {
|
||||||
|
caCert, err := os.ReadFile(caFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as PEM first, then DER
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
// If PEM parsing failed, try DER
|
||||||
|
cert, err := x509.ParseCertificate(caCert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err)
|
||||||
|
}
|
||||||
|
caCertPool.AddCert(cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to existing PKCS12 implementation for backward compatibility
|
||||||
|
if c.tlsConfig.PKCS12File != "" {
|
||||||
|
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
|
||||||
|
return c.setupPKCS12TLS()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy fallback using config.TlsClientCert
|
||||||
|
if c.config.TlsClientCert != "" {
|
||||||
|
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
||||||
|
return loadClientCertificate(c.config.TlsClientCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupPKCS12TLS loads TLS configuration from PKCS12 file
|
||||||
|
func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
|
||||||
|
return loadClientCertificate(c.tlsConfig.PKCS12File)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pingMonitor sends pings at a short interval and triggers reconnect on failure
|
||||||
func (c *Client) pingMonitor() {
|
func (c *Client) pingMonitor() {
|
||||||
ticker := time.NewTicker(30 * time.Second)
|
ticker := time.NewTicker(c.pingInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -322,11 +500,74 @@ func (c *Client) pingMonitor() {
|
|||||||
case <-c.done:
|
case <-c.done:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
|
if c.conn == nil {
|
||||||
logger.Error("Ping failed: %v", err)
|
|
||||||
c.reconnect()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.writeMux.Lock()
|
||||||
|
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
|
||||||
|
c.writeMux.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
// Check if we're shutting down before logging error and reconnecting
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Expected during shutdown
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
logger.Error("Ping failed: %v", err)
|
||||||
|
c.reconnect()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
|
||||||
|
func (c *Client) readPumpWithDisconnectDetection() {
|
||||||
|
defer func() {
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
// Only attempt reconnect if we're not shutting down
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Shutting down, don't reconnect
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
c.reconnect()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
var msg WSMessage
|
||||||
|
err := c.conn.ReadJSON(&msg)
|
||||||
|
if err != nil {
|
||||||
|
// Check if we're shutting down before logging error
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Expected during shutdown, don't log as error
|
||||||
|
logger.Debug("WebSocket connection closed during shutdown")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Unexpected error during normal operation
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
||||||
|
logger.Error("WebSocket read error: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("WebSocket connection closed: %v", err)
|
||||||
|
}
|
||||||
|
return // triggers reconnect via defer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.handlersMux.RLock()
|
||||||
|
if handler, ok := c.handlers[msg.Type]; ok {
|
||||||
|
handler(msg)
|
||||||
|
}
|
||||||
|
c.handlersMux.RUnlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -335,9 +576,16 @@ func (c *Client) reconnect() {
|
|||||||
c.setConnected(false)
|
c.setConnected(false)
|
||||||
if c.conn != nil {
|
if c.conn != nil {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
go c.connectWithRetry()
|
// Only reconnect if we're not shutting down
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
go c.connectWithRetry()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) setConnected(status bool) {
|
func (c *Client) setConnected(status bool) {
|
||||||
@@ -345,3 +593,42 @@ func (c *Client) setConnected(status bool) {
|
|||||||
defer c.reconnectMux.Unlock()
|
defer c.reconnectMux.Unlock()
|
||||||
c.isConnected = status
|
c.isConnected = status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
|
||||||
|
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||||
|
logger.Info("Loading tls-client-cert %s", p12Path)
|
||||||
|
// Read the PKCS12 file
|
||||||
|
p12Data, err := os.ReadFile(p12Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse PKCS12 with empty password for non-encrypted files
|
||||||
|
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate
|
||||||
|
cert := tls.Certificate{
|
||||||
|
Certificate: [][]byte{certificate.Raw},
|
||||||
|
PrivateKey: privateKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional: Add CA certificates if present
|
||||||
|
rootCAs, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
|
||||||
|
}
|
||||||
|
if len(caCerts) > 0 {
|
||||||
|
for _, caCert := range caCerts {
|
||||||
|
rootCAs.AddCert(caCert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS configuration
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getConfigPath() string {
|
|
||||||
var configDir string
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "newt-client")
|
|
||||||
case "windows":
|
|
||||||
configDir = filepath.Join(os.Getenv("APPDATA"), "newt-client")
|
|
||||||
default: // linux and others
|
|
||||||
configDir = filepath.Join(os.Getenv("HOME"), ".config", "newt-client")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
|
||||||
log.Printf("Failed to create config directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return filepath.Join(configDir, "config.json")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) loadConfig() error {
|
|
||||||
if c.config.NewtID != "" && c.config.Secret != "" && c.config.Endpoint != "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
configPath := getConfigPath()
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var config Config
|
|
||||||
if err := json.Unmarshal(data, &config); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.NewtID == "" {
|
|
||||||
c.config.NewtID = config.NewtID
|
|
||||||
}
|
|
||||||
if c.config.Token == "" {
|
|
||||||
c.config.Token = config.Token
|
|
||||||
}
|
|
||||||
if c.config.Secret == "" {
|
|
||||||
c.config.Secret = config.Secret
|
|
||||||
}
|
|
||||||
if c.config.Endpoint == "" {
|
|
||||||
c.config.Endpoint = config.Endpoint
|
|
||||||
c.baseURL = config.Endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) saveConfig() error {
|
|
||||||
configPath := getConfigPath()
|
|
||||||
data, err := json.MarshalIndent(c.config, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return os.WriteFile(configPath, data, 0644)
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package websocket
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
NewtID string `json:"newtId"`
|
|
||||||
Secret string `json:"secret"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TokenResponse struct {
|
|
||||||
Data struct {
|
|
||||||
Token string `json:"token"`
|
|
||||||
} `json:"data"`
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type WSMessage struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Data interface{} `json:"data"`
|
|
||||||
}
|
|
||||||
260
wgtester/wgtester.go
Normal file
260
wgtester/wgtester.go
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
package wgtester
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Magic bytes to identify our packets
|
||||||
|
magicHeader uint32 = 0xDEADBEEF
|
||||||
|
// Request packet type
|
||||||
|
packetTypeRequest uint8 = 1
|
||||||
|
// Response packet type
|
||||||
|
packetTypeResponse uint8 = 2
|
||||||
|
// Packet format:
|
||||||
|
// - 4 bytes: magic header (0xDEADBEEF)
|
||||||
|
// - 1 byte: packet type (1 = request, 2 = response)
|
||||||
|
// - 8 bytes: timestamp (for round-trip timing)
|
||||||
|
packetSize = 13
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client handles checking connectivity to a server
|
||||||
|
type Client struct {
|
||||||
|
conn *net.UDPConn
|
||||||
|
serverAddr string
|
||||||
|
monitorRunning bool
|
||||||
|
monitorLock sync.Mutex
|
||||||
|
connLock sync.Mutex // Protects connection operations
|
||||||
|
shutdownCh chan struct{}
|
||||||
|
packetInterval time.Duration
|
||||||
|
timeout time.Duration
|
||||||
|
maxAttempts int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionStatus represents the current connection state
|
||||||
|
type ConnectionStatus struct {
|
||||||
|
Connected bool
|
||||||
|
RTT time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new connection test client
|
||||||
|
func NewClient(serverAddr string) (*Client, error) {
|
||||||
|
return &Client{
|
||||||
|
serverAddr: serverAddr,
|
||||||
|
shutdownCh: make(chan struct{}),
|
||||||
|
packetInterval: 2 * time.Second,
|
||||||
|
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||||
|
maxAttempts: 3, // Default max attempts
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPacketInterval changes how frequently packets are sent in monitor mode
|
||||||
|
func (c *Client) SetPacketInterval(interval time.Duration) {
|
||||||
|
c.packetInterval = interval
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTimeout changes the timeout for waiting for responses
|
||||||
|
func (c *Client) SetTimeout(timeout time.Duration) {
|
||||||
|
c.timeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||||
|
func (c *Client) SetMaxAttempts(attempts int) {
|
||||||
|
c.maxAttempts = attempts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cleans up client resources
|
||||||
|
func (c *Client) Close() {
|
||||||
|
c.StopMonitor()
|
||||||
|
|
||||||
|
c.connLock.Lock()
|
||||||
|
defer c.connLock.Unlock()
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureConnection makes sure we have an active UDP connection
|
||||||
|
func (c *Client) ensureConnection() error {
|
||||||
|
c.connLock.Lock()
|
||||||
|
defer c.connLock.Unlock()
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.conn, err = net.DialUDP("udp", nil, serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnection checks if the connection to the server is working
|
||||||
|
// Returns true if connected, false otherwise
|
||||||
|
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||||
|
if err := c.ensureConnection(); err != nil {
|
||||||
|
logger.Warn("Failed to ensure connection: %v", err)
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare packet buffer
|
||||||
|
packet := make([]byte, packetSize)
|
||||||
|
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
|
||||||
|
packet[4] = packetTypeRequest
|
||||||
|
|
||||||
|
// Send multiple attempts as specified
|
||||||
|
for attempt := 0; attempt < c.maxAttempts; attempt++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, 0
|
||||||
|
default:
|
||||||
|
// Add current timestamp to packet
|
||||||
|
timestamp := time.Now().UnixNano()
|
||||||
|
binary.BigEndian.PutUint64(packet[5:13], uint64(timestamp))
|
||||||
|
|
||||||
|
// Lock the connection for the entire send/receive operation
|
||||||
|
c.connLock.Lock()
|
||||||
|
|
||||||
|
// Check if connection is still valid after acquiring lock
|
||||||
|
if c.conn == nil {
|
||||||
|
c.connLock.Unlock()
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
|
||||||
|
_, err := c.conn.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
c.connLock.Unlock()
|
||||||
|
logger.Info("Error sending packet: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Debug("Successfully sent monitor packet")
|
||||||
|
|
||||||
|
// Set read deadline
|
||||||
|
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||||
|
|
||||||
|
// Wait for response
|
||||||
|
responseBuffer := make([]byte, packetSize)
|
||||||
|
n, err := c.conn.Read(responseBuffer)
|
||||||
|
c.connLock.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
// Timeout, try next attempt
|
||||||
|
time.Sleep(100 * time.Millisecond) // Brief pause between attempts
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Error("Error reading response: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != packetSize {
|
||||||
|
continue // Malformed packet
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
magic := binary.BigEndian.Uint32(responseBuffer[0:4])
|
||||||
|
packetType := responseBuffer[4]
|
||||||
|
if magic != magicHeader || packetType != packetTypeResponse {
|
||||||
|
continue // Not our response
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the original timestamp and calculate RTT
|
||||||
|
sentTimestamp := int64(binary.BigEndian.Uint64(responseBuffer[5:13]))
|
||||||
|
rtt := time.Duration(time.Now().UnixNano() - sentTimestamp)
|
||||||
|
|
||||||
|
return true, rtt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnectionWithTimeout tries to test connection with a timeout
|
||||||
|
// Returns true if connected, false otherwise
|
||||||
|
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
return c.TestConnection(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MonitorCallback is the function type for connection status change callbacks
|
||||||
|
type MonitorCallback func(status ConnectionStatus)
|
||||||
|
|
||||||
|
// StartMonitor begins monitoring the connection and calls the callback
|
||||||
|
// when the connection status changes
|
||||||
|
func (c *Client) StartMonitor(callback MonitorCallback) error {
|
||||||
|
c.monitorLock.Lock()
|
||||||
|
defer c.monitorLock.Unlock()
|
||||||
|
|
||||||
|
if c.monitorRunning {
|
||||||
|
logger.Info("Monitor already running")
|
||||||
|
return nil // Already running
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ensureConnection(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.monitorRunning = true
|
||||||
|
c.shutdownCh = make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var lastConnected bool
|
||||||
|
firstRun := true
|
||||||
|
|
||||||
|
ticker := time.NewTicker(c.packetInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.shutdownCh:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||||
|
connected, rtt := c.TestConnection(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Callback if status changed or it's the first check
|
||||||
|
if connected != lastConnected || firstRun {
|
||||||
|
callback(ConnectionStatus{
|
||||||
|
Connected: connected,
|
||||||
|
RTT: rtt,
|
||||||
|
})
|
||||||
|
lastConnected = connected
|
||||||
|
firstRun = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopMonitor stops the connection monitoring
|
||||||
|
func (c *Client) StopMonitor() {
|
||||||
|
c.monitorLock.Lock()
|
||||||
|
defer c.monitorLock.Unlock()
|
||||||
|
|
||||||
|
if !c.monitorRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(c.shutdownCh)
|
||||||
|
c.monitorRunning = false
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user