Mirror of https://github.com/fosrl/olm.git (synced 2026-02-08 14:06:41 +00:00)

# Compare commits: 1.2.0-rc.0...dev

124 commits
Commits in this range:

```
af973b2440  dd9bff9a4b  1be5e454ba  4a25a0d413  7fc3c7088e  1869e70894
79783cc3dc  584298e3bd  f683afa647  ba2631d388  6ae4e2b691  51eee9dcf5
660e9e0e35  4ef6089053  c4e297cc96  e3f5497176  6a5dcc01a6  18b6d3bb0f
ccbfdc5265  ab04537278  29c36c9837  c47e9bf547  abb682c935  79e8a4a8bb
f2e81c024a  6d10650e70  a81c683c66  25cb50901e  a8e0844758  8b9ee6f26a
82e8fcc3a7  e2b7777ba7  4e4d1a39f6  17dc1b0be1  a06436eeab  a83cc2a3a3
d56537d0fd  31bb483e40  cd91ae6e3a  a9ec1e61d3  a13010c4af  cfac3cdd53
5ecba61718  2ea12ce258  0b46289136  71044165d0  eafd816159  e1a687407e
bd8031651e  a63439543d  90cd6e7f6e  ea4a63c9b3  e047330ffd  9dcc0796a6
4b6999e06a  69952ee5c5  3710880ce0  17b75bf58f  3ba1714524  3470da76fc
c86df2c041  0e8315b149  2ab9790588  1ecb97306f  15e96a779c  dada0cc124
9c0b4fcd5f  8a788ef238  20e0c18845  5b637bb4ca  c565a46a6f  7b7eae617a
1ed27fec1a  83edde3449  1b43f029a9  aeb908b68c  f08b17c7bd  cce8742490
c56696bab1  7bb004cf50  28910ce188  f8dc134210  148f5fde23  b76259bc31
88cc57bcef  385c64c364  0b05497c25  4e3e824276  effc1a31ac  03051a37fe
8cf2a28b6f  9f3422de1b  e6d0e9bb13  da0ad21fd4  2940f16f19  44c8d871c2
96a88057f9  d96fe6391e  fe7fd31955  86b19f243e  d0940d03c4  5a51753dbf
70be82d68a  dde79bb2dc  3822b1a065  8b68f00f59  fe197f0a0b  675c934ce1
708c761fa6  78dc6508a4  7f6c824122  9ba3569573  fd38f4cc59  c5d5fcedd9
13c0a082b5  48962d4b65  c469707986  13c40f6b2c  6071be0d08  4b269782ea
518bf0e36a  c80bb9740a  3ceef1ef74  acb0b4a9a5
```
## .github/dependabot.yml (vendored, 14 lines changed)

```diff
@@ -5,20 +5,10 @@ updates:
     schedule:
       interval: "daily"
     groups:
-      dev-patch-updates:
-        dependency-type: "development"
+      patch-updates:
         update-types:
           - "patch"
-      dev-minor-updates:
-        dependency-type: "development"
-        update-types:
-          - "minor"
-      prod-patch-updates:
-        dependency-type: "production"
-        update-types:
-          - "patch"
-      prod-minor-updates:
-        dependency-type: "production"
+      minor-updates:
         update-types:
           - "minor"
```
## .github/workflows/cicd.yml (vendored, 647 lines changed)

```diff
@@ -1,60 +1,615 @@
 name: CI/CD Pipeline

+permissions:
+  contents: write # gh-release
+  packages: write # GHCR push
+  id-token: write # keyless signatures & attestations
+  attestations: write # actions/attest-build-provenance
+  security-events: write # upload-sarif
+  actions: read
+
 on:
-  push:
-    tags:
-      - "*"
+  push:
+    tags:
+      - "[0-9]+.[0-9]+.[0-9]+"
+      - "[0-9]+.[0-9]+.[0-9]+-rc.[0-9]+"
+
+  workflow_dispatch:
+    inputs:
+      version:
+        description: "SemVer version to release (e.g., 1.2.3, no leading 'v')"
+        required: true
+        type: string
+      publish_latest:
+        description: "Also publish the 'latest' image tag"
+        required: true
+        type: boolean
+        default: false
+      publish_minor:
+        description: "Also publish the 'major.minor' image tag (e.g., 1.2)"
+        required: true
+        type: boolean
+        default: false
+      target_branch:
+        description: "Branch to tag"
+        required: false
+        default: "main"
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event_name == 'workflow_dispatch' && github.event.inputs.version || github.ref_name }}
+  cancel-in-progress: true

 jobs:
-  release:
-    name: Build and Release
-    runs-on: amd64-runner
+  prepare:
+    if: github.event_name == 'workflow_dispatch'
+    name: Prepare release (create tag)
+    runs-on: ubuntu-24.04
+    permissions:
+      contents: write
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+        with:
+          fetch-depth: 0

-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v6
+      - name: Validate version input
+        shell: bash
+        env:
+          INPUT_VERSION: ${{ inputs.version }}
+        run: |
+          set -euo pipefail
+          if ! [[ "$INPUT_VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-rc\.[0-9]+)?$ ]]; then
+            echo "Invalid version: $INPUT_VERSION (expected X.Y.Z or X.Y.Z-rc.N)" >&2
+            exit 1
+          fi
+
+      - name: Create and push tag
+        shell: bash
+        env:
+          TARGET_BRANCH: ${{ inputs.target_branch }}
+          VERSION: ${{ inputs.version }}
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        run: |
+          set -euo pipefail
+          git config user.name "github-actions[bot]"
+          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
+          git fetch --prune origin
+          git checkout "$TARGET_BRANCH"
+          git pull --ff-only origin "$TARGET_BRANCH"
+          if git rev-parse -q --verify "refs/tags/$VERSION" >/dev/null; then
+            echo "Tag $VERSION already exists" >&2
+            exit 1
+          fi
+          git tag -a "$VERSION" -m "Release $VERSION"
+          git push origin "refs/tags/$VERSION"
+
+  release:
+    if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && github.actor != 'github-actions[bot]') }}
+    name: Build and Release
+    runs-on: ubuntu-24.04
+    timeout-minutes: 120
+    env:
+      DOCKERHUB_IMAGE: docker.io/fosrl/${{ github.event.repository.name }}
+      GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}

-      - name: Set up QEMU
-        uses: docker/setup-qemu-action@v3
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+        with:
+          fetch-depth: 0

-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+      - name: Capture created timestamp
+        run: echo "IMAGE_CREATED=$(date -u +%Y-%m-%dT%H:%M:%SZ)" >> $GITHUB_ENV
+        shell: bash

-      - name: Log in to Docker Hub
-        uses: docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKER_HUB_USERNAME }}
-          password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0

-      - name: Extract tag name
-        id: get-tag
-        run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0

-      - name: Install Go
-        uses: actions/setup-go@v6
-        with:
-          go-version: 1.25
+      - name: Log in to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
+        with:
+          registry: docker.io
+          username: ${{ secrets.DOCKER_HUB_USERNAME }}
+          password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}

-      - name: Update version in main.go
-        run: |
-          TAG=${{ env.TAG }}
-          if [ -f main.go ]; then
-            sed -i 's/version_replaceme/'"$TAG"'/' main.go
-            echo "Updated main.go with version $TAG"
-          else
-            echo "main.go not found"
-          fi
-      - name: Build and push Docker images
-        run: |
-          TAG=${{ env.TAG }}
-          make docker-build-release tag=$TAG
+      - name: Log in to GHCR
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}

-      - name: Build binaries
-        run: |
-          make go-build-release
+      - name: Normalize image names to lowercase
+        run: |
+          set -euo pipefail
+          echo "GHCR_IMAGE=${GHCR_IMAGE,,}" >> "$GITHUB_ENV"
+          echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
+        shell: bash

-      - name: Upload artifacts from /bin
-        uses: actions/upload-artifact@v5
-        with:
-          name: binaries
-          path: bin/
+      - name: Extract tag name
+        env:
+          EVENT_NAME: ${{ github.event_name }}
+          INPUT_VERSION: ${{ inputs.version }}
+        run: |
+          if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
+            echo "TAG=${INPUT_VERSION}" >> $GITHUB_ENV
+          else
+            echo "TAG=${{ github.ref_name }}" >> $GITHUB_ENV
+          fi
+        shell: bash
+
+      - name: Validate pushed tag format (no leading 'v')
+        if: ${{ github.event_name == 'push' }}
+        shell: bash
+        env:
+          TAG_GOT: ${{ env.TAG }}
+        run: |
+          set -euo pipefail
+          if [[ "$TAG_GOT" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-rc\.[0-9]+)?$ ]]; then
+            echo "Tag OK: $TAG_GOT"
+            exit 0
+          fi
+          echo "ERROR: Tag '$TAG_GOT' is not allowed. Use 'X.Y.Z' or 'X.Y.Z-rc.N' (no leading 'v')." >&2
+          exit 1
+
+      - name: Wait for tag to be visible (dispatch only)
+        if: ${{ github.event_name == 'workflow_dispatch' }}
+        run: |
+          set -euo pipefail
+          for i in {1..90}; do
+            if git ls-remote --tags origin "refs/tags/${TAG}" | grep -qE "refs/tags/${TAG}$"; then
+              echo "Tag ${TAG} is visible on origin"; exit 0
+            fi
+            echo "Tag not yet visible, retrying... ($i/90)"
+            sleep 2
+          done
+          echo "Tag ${TAG} not visible after waiting"; exit 1
+        shell: bash
+
+      - name: Update version in main.go
+        run: |
+          TAG=${{ env.TAG }}
+          if [ -f main.go ]; then
+            sed -i 's/version_replaceme/'"$TAG"'/' main.go
+            echo "Updated main.go with version $TAG"
+          else
+            echo "main.go not found"
+          fi
+
+      - name: Ensure repository is at the tagged commit (dispatch only)
+        if: ${{ github.event_name == 'workflow_dispatch' }}
+        run: |
+          set -euo pipefail
+          git fetch --tags --force
+          git checkout "refs/tags/${TAG}"
+          echo "Checked out $(git rev-parse --short HEAD) for tag ${TAG}"
+        shell: bash
+
+      - name: Detect release candidate (rc)
+        run: |
+          set -euo pipefail
+          if [[ "${TAG}" =~ ^[0-9]+\.[0-9]+\.[0-9]+-rc\.[0-9]+$ ]]; then
+            echo "IS_RC=true" >> $GITHUB_ENV
+          else
+            echo "IS_RC=false" >> $GITHUB_ENV
+          fi
+        shell: bash
+
+      - name: Install Go
+        uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
+        with:
+          go-version-file: go.mod
+
+      - name: Resolve publish-latest flag
+        env:
+          EVENT_NAME: ${{ github.event_name }}
+          PL_INPUT: ${{ inputs.publish_latest }}
+          PL_VAR: ${{ vars.PUBLISH_LATEST }}
+        run: |
+          set -euo pipefail
+          val="false"
+          if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
+            if [ "${PL_INPUT}" = "true" ]; then val="true"; fi
+          else
+            if [ "${PL_VAR}" = "true" ]; then val="true"; fi
+          fi
+          echo "PUBLISH_LATEST=$val" >> $GITHUB_ENV
+        shell: bash
+
+      - name: Resolve publish-minor flag
+        env:
+          EVENT_NAME: ${{ github.event_name }}
+          PM_INPUT: ${{ inputs.publish_minor }}
+          PM_VAR: ${{ vars.PUBLISH_MINOR }}
+        run: |
+          set -euo pipefail
+          val="false"
+          if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
+            if [ "${PM_INPUT}" = "true" ]; then val="true"; fi
+          else
+            if [ "${PM_VAR}" = "true" ]; then val="true"; fi
+          fi
+          echo "PUBLISH_MINOR=$val" >> $GITHUB_ENV
+        shell: bash
+
+      - name: Cache Go modules
+        if: ${{ hashFiles('**/go.sum') != '' }}
+        uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
+        with:
+          path: |
+            ~/.cache/go-build
+            ~/go/pkg/mod
+          key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
+          restore-keys: |
+            ${{ runner.os }}-go-
+
+      - name: Go vet & test
+        if: ${{ hashFiles('**/go.mod') != '' }}
+        run: |
+          go version
+          go vet ./...
+          go test ./... -race -covermode=atomic
+        shell: bash
+
+      - name: Resolve license fallback
+        run: echo "IMAGE_LICENSE=${{ github.event.repository.license.spdx_id || 'NOASSERTION' }}" >> $GITHUB_ENV
+        shell: bash
+
+      - name: Resolve registries list (GHCR always, Docker Hub only if creds)
+        shell: bash
+        run: |
+          set -euo pipefail
+          images="${GHCR_IMAGE}"
+          if [ -n "${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}" ] && [ -n "${{ secrets.DOCKER_HUB_USERNAME }}" ]; then
+            images="${images}\n${DOCKERHUB_IMAGE}"
+          fi
+          {
+            echo 'IMAGE_LIST<<EOF'
+            echo -e "$images"
+            echo 'EOF'
+          } >> "$GITHUB_ENV"
+
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5.10.0
+        with:
+          images: ${{ env.IMAGE_LIST }}
+          tags: |
+            type=semver,pattern={{version}},value=${{ env.TAG }}
+            type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }}
+            type=raw,value=latest,enable=${{ env.IS_RC != 'true' }}
+          flavor: |
+            latest=false
+          labels: |
+            org.opencontainers.image.title=${{ github.event.repository.name }}
+            org.opencontainers.image.version=${{ env.TAG }}
+            org.opencontainers.image.revision=${{ github.sha }}
+            org.opencontainers.image.source=${{ github.event.repository.html_url }}
+            org.opencontainers.image.url=${{ github.event.repository.html_url }}
+            org.opencontainers.image.documentation=${{ github.event.repository.html_url }}
+            org.opencontainers.image.description=${{ github.event.repository.description }}
+            org.opencontainers.image.licenses=${{ env.IMAGE_LICENSE }}
+            org.opencontainers.image.created=${{ env.IMAGE_CREATED }}
+            org.opencontainers.image.ref.name=${{ env.TAG }}
+            org.opencontainers.image.authors=${{ github.repository_owner }}
+
+      - name: Echo build config (non-secret)
+        shell: bash
+        env:
+          IMAGE_TITLE: ${{ github.event.repository.name }}
+          IMAGE_VERSION: ${{ env.TAG }}
+          IMAGE_REVISION: ${{ github.sha }}
+          IMAGE_SOURCE_URL: ${{ github.event.repository.html_url }}
+          IMAGE_URL: ${{ github.event.repository.html_url }}
+          IMAGE_DESCRIPTION: ${{ github.event.repository.description }}
+          IMAGE_LICENSE: ${{ env.IMAGE_LICENSE }}
+          DOCKERHUB_IMAGE: ${{ env.DOCKERHUB_IMAGE }}
+          GHCR_IMAGE: ${{ env.GHCR_IMAGE }}
+          DOCKER_HUB_USER: ${{ secrets.DOCKER_HUB_USERNAME }}
+          REPO: ${{ github.repository }}
+          OWNER: ${{ github.repository_owner }}
+          WORKFLOW_REF: ${{ github.workflow_ref }}
+          REF: ${{ github.ref }}
+          REF_NAME: ${{ github.ref_name }}
+          RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
+        run: |
+          set -euo pipefail
+          echo "=== OCI Label Values ==="
+          echo "org.opencontainers.image.title=${IMAGE_TITLE}"
+          echo "org.opencontainers.image.version=${IMAGE_VERSION}"
+          echo "org.opencontainers.image.revision=${IMAGE_REVISION}"
+          echo "org.opencontainers.image.source=${IMAGE_SOURCE_URL}"
+          echo "org.opencontainers.image.url=${IMAGE_URL}"
+          echo "org.opencontainers.image.description=${IMAGE_DESCRIPTION}"
+          echo "org.opencontainers.image.licenses=${IMAGE_LICENSE}"
+          echo
+          echo "=== Images ==="
+          echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE}"
+          echo "GHCR_IMAGE=${GHCR_IMAGE}"
+          echo "DOCKER_HUB_USERNAME=${DOCKER_HUB_USER}"
+          echo
+          echo "=== GitHub Context ==="
+          echo "repository=${REPO}"
+          echo "owner=${OWNER}"
+          echo "workflow_ref=${WORKFLOW_REF}"
+          echo "ref=${REF}"
+          echo "ref_name=${REF_NAME}"
+          echo "run_url=${RUN_URL}"
+          echo
+          echo "=== docker/metadata-action outputs (Tags/Labels), raw ==="
+          echo "::group::tags"
+          echo "${{ steps.meta.outputs.tags }}"
+          echo "::endgroup::"
+          echo "::group::labels"
+          echo "${{ steps.meta.outputs.labels }}"
+          echo "::endgroup::"
+
+      - name: Build and push (Docker Hub + GHCR)
+        id: build
+        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0
+        with:
+          context: .
+          push: true
+          platforms: linux/amd64,linux/arm64,linux/arm/v7
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          cache-from: type=gha,scope=${{ github.repository }}
+          cache-to: type=gha,mode=max,scope=${{ github.repository }}
+          provenance: mode=max
+          sbom: true
+
+      - name: Compute image digest refs
+        run: |
+          echo "DIGEST=${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
+          echo "GHCR_REF=$GHCR_IMAGE@${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
+          echo "DH_REF=$DOCKERHUB_IMAGE@${{ steps.build.outputs.digest }}" >> $GITHUB_ENV
+          echo "Built digest: ${{ steps.build.outputs.digest }}"
+        shell: bash
+
+      - name: Attest build provenance (GHCR)
+        id: attest-ghcr
+        uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
+        with:
+          subject-name: ${{ env.GHCR_IMAGE }}
+          subject-digest: ${{ steps.build.outputs.digest }}
+          push-to-registry: true
+          show-summary: true
+
+      - name: Attest build provenance (Docker Hub)
+        continue-on-error: true
+        id: attest-dh
+        uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
+        with:
+          subject-name: index.docker.io/fosrl/${{ github.event.repository.name }}
+          subject-digest: ${{ steps.build.outputs.digest }}
+          push-to-registry: true
+          show-summary: true
+
+      - name: Install cosign
+        uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
+        with:
+          cosign-release: 'v3.0.2'
+
+      - name: Sanity check cosign private key
+        env:
+          COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
+          COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
+        run: |
+          set -euo pipefail
+          cosign public-key --key env://COSIGN_PRIVATE_KEY >/dev/null
+        shell: bash
+
+      - name: Sign GHCR image (digest) with key (recursive)
+        env:
+          COSIGN_YES: "true"
+          COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
+          COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
+        run: |
+          set -euo pipefail
+          echo "Signing ${GHCR_REF} (digest) recursively with provided key"
+          cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${GHCR_REF}"
+          echo "Waiting 30 seconds for signatures to propagate..."
+        shell: bash
+
+      - name: Generate SBOM (SPDX JSON)
+        uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.1
+        with:
+          image-ref: ${{ env.GHCR_IMAGE }}@${{ steps.build.outputs.digest }}
+          format: spdx-json
+          output: sbom.spdx.json
+
+      - name: Validate SBOM JSON
+        run: jq -e . sbom.spdx.json >/dev/null
+        shell: bash
+
+      - name: Minify SBOM JSON (optional hardening)
+        run: jq -c . sbom.spdx.json > sbom.min.json && mv sbom.min.json sbom.spdx.json
+        shell: bash
+
+      - name: Create SBOM attestation (GHCR, private key)
+        env:
+          COSIGN_YES: "true"
+          COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
+          COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
+        run: |
+          set -euo pipefail
+          cosign attest \
+            --key env://COSIGN_PRIVATE_KEY \
+            --type spdxjson \
+            --predicate sbom.spdx.json \
+            "${GHCR_REF}"
+        shell: bash
+
+      - name: Create SBOM attestation (Docker Hub, private key)
+        continue-on-error: true
+        env:
+          COSIGN_YES: "true"
+          COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
+          COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
+          COSIGN_DOCKER_MEDIA_TYPES: "1"
+        run: |
+          set -euo pipefail
+          cosign attest \
+            --key env://COSIGN_PRIVATE_KEY \
+            --type spdxjson \
+            --predicate sbom.spdx.json \
+            "${DH_REF}"
+        shell: bash
+
+      - name: Keyless sign & verify GHCR digest (OIDC)
+        env:
+          COSIGN_YES: "true"
+          WORKFLOW_REF: ${{ github.workflow_ref }} # owner/repo/.github/workflows/<file>@refs/tags/<tag>
+          ISSUER: https://token.actions.githubusercontent.com
+        run: |
+          set -euo pipefail
+          echo "Keyless signing ${GHCR_REF}"
+          cosign sign --rekor-url https://rekor.sigstore.dev --recursive "${GHCR_REF}"
+          echo "Verify keyless (OIDC) signature policy on ${GHCR_REF}"
+          cosign verify \
+            --certificate-oidc-issuer "${ISSUER}" \
+            --certificate-identity "https://github.com/${WORKFLOW_REF}" \
+            "${GHCR_REF}" -o text
+        shell: bash
+
+      - name: Sign Docker Hub image (digest) with key (recursive)
+        continue-on-error: true
+        env:
+          COSIGN_YES: "true"
+          COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
+          COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
+          COSIGN_DOCKER_MEDIA_TYPES: "1"
+        run: |
+          set -euo pipefail
+          echo "Signing ${DH_REF} (digest) recursively with provided key (Docker media types fallback)"
+          cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${DH_REF}"
+        shell: bash
+
+      - name: Keyless sign & verify Docker Hub digest (OIDC)
+        continue-on-error: true
+        env:
+          COSIGN_YES: "true"
+          ISSUER: https://token.actions.githubusercontent.com
+          COSIGN_DOCKER_MEDIA_TYPES: "1"
+        run: |
+          set -euo pipefail
+          echo "Keyless signing ${DH_REF} (force public-good Rekor)"
+          cosign sign --rekor-url https://rekor.sigstore.dev --recursive "${DH_REF}"
+          echo "Keyless verify via Rekor (strict identity)"
+          if ! cosign verify \
+            --rekor-url https://rekor.sigstore.dev \
+            --certificate-oidc-issuer "${ISSUER}" \
+            --certificate-identity "https://github.com/${{ github.workflow_ref }}" \
+            "${DH_REF}" -o text; then
+            echo "Rekor verify failed, retry offline bundle verify (no Rekor)"
+            if ! cosign verify \
+              --offline \
+              --certificate-oidc-issuer "${ISSUER}" \
+              --certificate-identity "https://github.com/${{ github.workflow_ref }}" \
+              "${DH_REF}" -o text; then
+              echo "Offline bundle verify failed, ignore tlog (TEMP for debugging)"
+              cosign verify \
+                --insecure-ignore-tlog=true \
+                --certificate-oidc-issuer "${ISSUER}" \
+                --certificate-identity "https://github.com/${{ github.workflow_ref }}" \
+                "${DH_REF}" -o text || true
+            fi
+          fi
+
+      - name: Verify signature (public key) GHCR digest + tag
+        env:
+          COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
+        run: |
+          set -euo pipefail
+          TAG_VAR="${TAG}"
+          echo "Verifying (digest) ${GHCR_REF}"
+          cosign verify --key env://COSIGN_PUBLIC_KEY "$GHCR_REF" -o text
+          echo "Verifying (tag) $GHCR_IMAGE:$TAG_VAR"
+          cosign verify --key env://COSIGN_PUBLIC_KEY "$GHCR_IMAGE:$TAG_VAR" -o text
+        shell: bash
+
+      - name: Verify SBOM attestation (GHCR)
+        env:
+          COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
+        run: cosign verify-attestation --key env://COSIGN_PUBLIC_KEY --type spdxjson "$GHCR_REF" -o text
+        shell: bash
+
+      - name: Verify SLSA provenance (GHCR)
+        env:
+          ISSUER: https://token.actions.githubusercontent.com
+          WFREF: ${{ github.workflow_ref }}
+        run: |
+          set -euo pipefail
+          # (optional) show which predicate types are present to aid debugging
+          cosign download attestation "$GHCR_REF" \
+            | jq -r '.payload | @base64d | fromjson | .predicateType' | sort -u || true
+          # Verify the SLSA v1 provenance attestation (predicate URL)
+          cosign verify-attestation \
+            --type 'https://slsa.dev/provenance/v1' \
+            --certificate-oidc-issuer "$ISSUER" \
+            --certificate-identity "https://github.com/${WFREF}" \
+            --rekor-url https://rekor.sigstore.dev \
+            "$GHCR_REF" -o text
+        shell: bash
+
+      - name: Verify signature (public key) Docker Hub digest
+        continue-on-error: true
+        env:
+          COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
+          COSIGN_DOCKER_MEDIA_TYPES: "1"
+        run: |
+          set -euo pipefail
+          echo "Verifying (digest) ${DH_REF} with Docker media types"
+          cosign verify --key env://COSIGN_PUBLIC_KEY "${DH_REF}" -o text
+        shell: bash
+
+      - name: Verify signature (public key) Docker Hub tag
+        continue-on-error: true
+        env:
+          COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
+          COSIGN_DOCKER_MEDIA_TYPES: "1"
+        run: |
+          set -euo pipefail
+          echo "Verifying (tag) $DOCKERHUB_IMAGE:$TAG with Docker media types"
+          cosign verify --key env://COSIGN_PUBLIC_KEY "$DOCKERHUB_IMAGE:$TAG" -o text
+        shell: bash
+
+      # - name: Trivy scan (GHCR image)
+      #   id: trivy
+      #   uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.1
+      #   with:
+      #     image-ref: ${{ env.GHCR_IMAGE }}@${{ steps.build.outputs.digest }}
+      #     format: sarif
+      #     output: trivy-ghcr.sarif
+      #     ignore-unfixed: true
+      #     vuln-type: os,library
+      #     severity: CRITICAL,HIGH
+      #     exit-code: ${{ (vars.TRIVY_FAIL || '0') }}
+
+      # - name: Upload SARIF
+      #   if: ${{ always() && hashFiles('trivy-ghcr.sarif') != '' }}
+      #   uses: github/codeql-action/upload-sarif@fdbfb4d2750291e159f0156def62b853c2798ca2 # v4.31.5
+      #   with:
+      #     sarif_file: trivy-ghcr.sarif
+      #     category: Image Vulnerability Scan
+
+      - name: Build binaries
+        env:
+          CGO_ENABLED: "0"
+          GOFLAGS: "-trimpath"
+        run: |
+          set -euo pipefail
+          TAG_VAR="${TAG}"
+          make go-build-release tag=$TAG_VAR
+        shell: bash
+
+      - name: Create GitHub Release
+        uses: softprops/action-gh-release@a06a81a03ee405af7f2048a818ed3f03bbf83c7b # v2.5.0
+        with:
+          tag_name: ${{ env.TAG }}
+          generate_release_notes: true
+          prerelease: ${{ env.IS_RC == 'true' }}
+          files: |
+            bin/*
+          fail_on_unmatched_files: true
+          draft: true
+          body: |
+            ## Container Images
+            - GHCR: `${{ env.GHCR_REF }}`
+            - Docker Hub: `${{ env.DH_REF || 'N/A' }}`
+            **Digest:** `${{ steps.build.outputs.digest }}`
```
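Because the workflow both key-signs and keyless-signs the GHCR image, a consumer can check a pulled image before running it. A minimal sketch, assuming cosign is installed locally and the release was built by this workflow from a tag (the identity follows the `owner/repo/.github/workflows/<file>@refs/tags/<tag>` pattern noted in the workflow; the tag value below is illustrative):

```bash
# Hypothetical end-user verification of a published release image
IMAGE="ghcr.io/fosrl/olm:1.2.0"
cosign verify \
  --certificate-oidc-issuer "https://token.actions.githubusercontent.com" \
  --certificate-identity "https://github.com/fosrl/olm/.github/workflows/cicd.yml@refs/tags/1.2.0" \
  "$IMAGE" -o text
```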
## .github/workflows/stale-bot.yml (vendored, new file, 37 lines)

```yaml
name: Mark and Close Stale Issues

on:
  schedule:
    - cron: '0 0 * * *'
  workflow_dispatch: # Allow manual trigger

permissions:
  contents: write # only for delete-branch option
  issues: write
  pull-requests: write

jobs:
  stale:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
        with:
          days-before-stale: 14
          days-before-close: 14
          stale-issue-message: 'This issue has been automatically marked as stale due to 14 days of inactivity. It will be closed in 14 days if no further activity occurs.'
          close-issue-message: 'This issue has been automatically closed due to inactivity. If you believe this is still relevant, please open a new issue with up-to-date information.'
          stale-issue-label: 'stale'

          exempt-issue-labels: 'needs investigating, networking, new feature, reverse proxy, bug, api, authentication, documentation, enhancement, help wanted, good first issue, question'

          exempt-all-issue-assignees: true

          only-labels: ''
          exempt-pr-labels: ''
          days-before-pr-stale: -1
          days-before-pr-close: -1

          operations-per-run: 100
          remove-stale-when-updated: true
          delete-branch: false
          enable-statistics: true
```
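The `workflow_dispatch` trigger means the sweep can also be kicked off by hand rather than waiting for the nightly cron. A sketch using the GitHub CLI (assumes `gh` is installed and authenticated; not part of the repo docs):

```bash
# Manually trigger the stale-issue sweep
gh workflow run stale-bot.yml --repo fosrl/olm
```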
## .github/workflows/test.yml (vendored, 38 lines changed)

```diff
@@ -1,5 +1,8 @@
 name: Run Tests

+permissions:
+  contents: read
+
 on:
   pull_request:
     branches:
@@ -7,28 +10,33 @@ on:
       - dev

 jobs:
-  test:
+  build-go:
     runs-on: ubuntu-latest

     steps:
-      - uses: actions/checkout@v6
-
-      - name: Clone fosrl/newt
-        uses: actions/checkout@v6
-        with:
-          repository: fosrl/newt
-          path: ../newt
+      - name: Checkout repository
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1

       - name: Set up Go
-        uses: actions/setup-go@v6
+        uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
         with:
           go-version: 1.25

       - name: Build go
         run: go build

-      - name: Build Docker image
-        run: make build
+      - name: Build binaries
+        run: make go-build-release
+
+  build-docker:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
+
+      - name: Build Docker image
+        run: make docker-build-dev
```
## API.md (new file, 393 lines)

## API

Olm can be controlled with an embedded API server when using `--enable-api`. This allows you to start it as a daemon and trigger it with the following endpoints. The API can listen on either a TCP address or a Unix socket/Windows named pipe.

### Socket vs TCP

When `--enable-api` is used, Olm can listen on a TCP address configured via `--http-addr` (for example `:9452`). Alternatively, Olm can listen on a Unix socket (Linux/macOS) or a Windows named pipe for local-only communication with better security, configured via `--socket-path` (for example `/var/run/olm.sock`).

**Unix Socket (Linux/macOS):**
- Socket path example: `/var/run/olm/olm.sock`
- The directory is created automatically if it doesn't exist
- Socket permissions are set to `0666` to allow access
- Existing socket files are automatically removed on startup
- The socket file is cleaned up when Olm stops

**Windows Named Pipe:**
- Pipe path example: `\\.\pipe\olm`
- If the path doesn't start with `\`, it's automatically prefixed with `\\.\pipe\`
- The security descriptor grants full access to Everyone and the current owner
- Named pipes are automatically cleaned up by Windows

**Connecting to the Socket:**

```bash
# Linux/macOS - using curl with a Unix socket
curl --unix-socket /var/run/olm/olm.sock http://localhost/status
```

---

### POST /connect
Initiates a new connection request to a Pangolin server.

**Request Body:**
```json
{
  "id": "string",
  "secret": "string",
  "endpoint": "string",
  "userToken": "string",
  "mtu": 1280,
  "dns": "8.8.8.8",
  "dnsProxyIP": "string",
  "upstreamDNS": ["8.8.8.8:53", "1.1.1.1:53"],
  "interfaceName": "olm",
  "holepunch": false,
  "tlsClientCert": "string",
  "pingInterval": "3s",
  "pingTimeout": "5s",
  "orgId": "string",
  "fingerprint": {
    "username": "string",
    "hostname": "string",
    "platform": "string",
    "osVersion": "string",
    "kernelVersion": "string",
    "arch": "string",
    "deviceModel": "string",
    "serialNumber": "string"
  },
  "postures": {}
}
```

**Required Fields:**
- `id`: Olm ID generated by Pangolin
- `secret`: Authentication secret for the Olm ID
- `endpoint`: Target Pangolin endpoint URL

**Optional Fields:**
- `userToken`: User authentication token
- `mtu`: MTU for the internal WireGuard interface (default: 1280)
- `dns`: DNS server to use for resolving the endpoint
- `dnsProxyIP`: DNS proxy IP address
- `upstreamDNS`: Array of upstream DNS servers
- `interfaceName`: Name of the WireGuard interface (default: olm)
- `holepunch`: Enable NAT hole punching (default: false)
- `tlsClientCert`: TLS client certificate
- `pingInterval`: Interval for pinging the server (default: 3s)
- `pingTimeout`: Timeout for each ping (default: 5s)
- `orgId`: Organization ID to connect to
- `fingerprint`: Device fingerprinting information (should be set before connecting)
  - `username`: Current username on the device
  - `hostname`: Device hostname
  - `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
  - `osVersion`: Operating system version
  - `kernelVersion`: Kernel version
  - `arch`: System architecture (e.g., amd64, arm64)
  - `deviceModel`: Device model identifier
  - `serialNumber`: Device serial number
- `postures`: Device posture/security information

**Response:**
- **Status Code:** `202 Accepted`
- **Content-Type:** `application/json`

```json
{
  "status": "connection request accepted"
}
```

**Error Responses:**
- `405 Method Not Allowed` - Non-POST requests
- `400 Bad Request` - Invalid JSON or missing required fields
- `409 Conflict` - Already connected to a server (disconnect first)
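Because `/connect` returns `202 Accepted` as soon as the request is queued, a caller that needs to know when the tunnel is actually up can poll `GET /status` afterwards. A small sketch (assumes `jq` is installed; endpoint, ID, and secret values are illustrative):

```bash
# Request a connection, then poll /status until Olm reports connected
curl -s -X POST http://localhost:9452/connect \
  -H "Content-Type: application/json" \
  -d '{"id": "31frd0uzbjvp721", "secret": "...", "endpoint": "https://example.com"}'

until curl -s http://localhost:9452/status | jq -e '.connected == true' >/dev/null; do
  sleep 1
done
echo "connected"
```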
---

### GET /status
Returns the current connection status, registration state, and peer information.

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "connected": true,
  "registered": true,
  "terminated": false,
  "version": "1.0.0",
  "agent": "olm",
  "orgId": "org_123",
  "peers": {
    "10": {
      "siteId": 10,
      "name": "Site A",
      "connected": true,
      "rtt": 145338339,
      "lastSeen": "2025-08-13T14:39:17.208334428-07:00",
      "endpoint": "p.fosrl.io:21820",
      "isRelay": true,
      "peerAddress": "100.89.128.5",
      "holepunchConnected": false
    },
    "8": {
      "siteId": 8,
      "name": "Site B",
      "connected": false,
      "rtt": 0,
      "lastSeen": "2025-08-13T14:39:19.663823645-07:00",
      "endpoint": "p.fosrl.io:21820",
      "isRelay": true,
      "peerAddress": "100.89.128.10",
      "holepunchConnected": false
    }
  },
  "networkSettings": {
    "tunnelIP": "100.89.128.3/20"
  }
}
```

**Fields:**
- `connected`: Boolean indicating if connected to Pangolin
- `registered`: Boolean indicating if registered with the server
- `terminated`: Boolean indicating if the connection was terminated
- `version`: Olm version string
- `agent`: Agent identifier
- `orgId`: Current organization ID
- `peers`: Map of peer statuses by site ID
  - `siteId`: Peer site identifier
  - `name`: Site name
  - `connected`: Boolean peer connection state
  - `rtt`: Peer round-trip time (integer, nanoseconds)
  - `lastSeen`: Last time peer was seen (RFC3339 timestamp)
  - `endpoint`: Peer endpoint address
  - `isRelay`: Whether the peer is relayed (true) or direct (false)
  - `peerAddress`: Peer's IP address in the tunnel
  - `holepunchConnected`: Whether holepunch connection is established
- `networkSettings`: Current network configuration including tunnel IP

**Error Responses:**
- `405 Method Not Allowed` - Non-GET requests
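Since `rtt` is reported in nanoseconds, it usually wants converting before display. A quick sketch with `jq` (assumed installed) that summarizes each peer from the `/status` payload:

```bash
# Per-peer summary: name, connected flag, and RTT converted from nanoseconds to milliseconds
curl -s http://localhost:9452/status \
  | jq -r '.peers[] | "\(.name)\tconnected=\(.connected)\trtt_ms=\(.rtt / 1000000)"'
```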
---

### POST /disconnect
Disconnects from the current Pangolin server and tears down the WireGuard tunnel.

**Request Body:** None required

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "status": "disconnect initiated"
}
```

**Error Responses:**
- `405 Method Not Allowed` - Non-POST requests
- `409 Conflict` - Not currently connected to a server

---

### POST /switch-org
Switches to a different organization while maintaining the connection.

**Request Body:**
```json
{
  "orgId": "string"
}
```

**Required Fields:**
- `orgId`: The organization ID to switch to

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "status": "org switch request accepted"
}
```

**Error Responses:**
- `405 Method Not Allowed` - Non-POST requests
- `400 Bad Request` - Invalid JSON or missing orgId field
- `500 Internal Server Error` - Org switch failed
---

### PUT /metadata
Updates device fingerprinting and posture information. This endpoint can be called at any time to update metadata, but it's recommended to provide this information in the initial `/connect` request or immediately before connecting.

**Request Body:**
```json
{
  "fingerprint": {
    "username": "string",
    "hostname": "string",
    "platform": "string",
    "osVersion": "string",
    "kernelVersion": "string",
    "arch": "string",
    "deviceModel": "string",
    "serialNumber": "string"
  },
  "postures": {}
}
```

**Optional Fields:**
- `fingerprint`: Device fingerprinting information
  - `username`: Current username on the device
  - `hostname`: Device hostname
  - `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
  - `osVersion`: Operating system version
  - `kernelVersion`: Kernel version
  - `arch`: System architecture (e.g., amd64, arm64)
  - `deviceModel`: Device model identifier
  - `serialNumber`: Device serial number
- `postures`: Device posture/security information (object with arbitrary key-value pairs)

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "status": "metadata updated"
}
```

**Error Responses:**
- `405 Method Not Allowed` - Non-PUT requests
- `400 Bad Request` - Invalid JSON

**Note:** It's recommended to call this endpoint BEFORE `/connect` to ensure fingerprinting information is available during the initial connection handshake.
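One way to populate the fingerprint is straight from the host. A sketch for Linux/macOS using standard tools plus `jq` (assumed installed); `deviceModel` and `serialNumber` are omitted because there is no portable way to read them, and note that `uname -s` yields `darwin` rather than `macos`, so map the platform value if needed:

```bash
# Build a fingerprint from the local machine and push it before calling /connect
curl -s -X PUT http://localhost:9452/metadata \
  -H "Content-Type: application/json" \
  -d "$(jq -n \
        --arg username "$(whoami)" \
        --arg hostname "$(hostname)" \
        --arg platform "$(uname -s | tr '[:upper:]' '[:lower:]')" \
        --arg osVersion "$(uname -r)" \
        --arg arch "$(uname -m)" \
        '{fingerprint: {username: $username, hostname: $hostname, platform: $platform, osVersion: $osVersion, arch: $arch}}')"
```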
---

### POST /exit
Initiates a graceful shutdown of the Olm process.

**Request Body:** None required

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "status": "shutdown initiated"
}
```

**Note:** The response is sent before shutdown begins. There is a 100ms delay before the actual shutdown to ensure the response is delivered.

**Error Responses:**
- `405 Method Not Allowed` - Non-POST requests

---

### GET /health
Simple health check endpoint to verify the API server is running.

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "status": "ok"
}
```

**Error Responses:**
- `405 Method Not Allowed` - Non-GET requests

---

## Usage Examples

### Update metadata before connecting (recommended)
```bash
curl -X PUT http://localhost:9452/metadata \
  -H "Content-Type: application/json" \
  -d '{
    "fingerprint": {
      "username": "john",
      "hostname": "johns-laptop",
      "platform": "macos",
      "osVersion": "14.2.1",
      "arch": "arm64",
      "deviceModel": "MacBookPro18,3"
    }
  }'
```

### Connect to a peer
```bash
curl -X POST http://localhost:9452/connect \
  -H "Content-Type: application/json" \
  -d '{
    "id": "31frd0uzbjvp721",
    "secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
    "endpoint": "https://example.com"
  }'
```

### Connect with additional options
```bash
curl -X POST http://localhost:9452/connect \
  -H "Content-Type: application/json" \
  -d '{
    "id": "31frd0uzbjvp721",
    "secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
    "endpoint": "https://example.com",
    "mtu": 1400,
    "holepunch": true,
    "pingInterval": "5s"
  }'
```

### Check connection status
```bash
curl http://localhost:9452/status
```

### Switch organization
```bash
curl -X POST http://localhost:9452/switch-org \
  -H "Content-Type: application/json" \
  -d '{"orgId": "org_456"}'
```

### Disconnect from server
```bash
curl -X POST http://localhost:9452/disconnect
```

### Health check
```bash
curl http://localhost:9452/health
```

### Shutdown Olm
```bash
curl -X POST http://localhost:9452/exit
```

### Using Unix socket (Linux/macOS)
```bash
curl --unix-socket /var/run/olm/olm.sock http://localhost/status
curl --unix-socket /var/run/olm/olm.sock -X POST http://localhost/disconnect
```
## Makefile (63 lines changed)

```diff
@@ -1,26 +1,67 @@
 .PHONY: all local docker-build-release

-all: go-build-release
+all: local
+
+local:
+	CGO_ENABLED=0 go build -o ./bin/olm
+
+docker-build:
+	docker build -t fosrl/olm:latest .

 docker-build-release:
 	@if [ -z "$(tag)" ]; then \
 		echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
 		exit 1; \
 	fi
-	docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:latest -f Dockerfile --push .
-	docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push .
+	docker buildx build . \
+		--platform linux/arm/v7,linux/arm64,linux/amd64 \
+		-t fosrl/olm:latest \
+		-t fosrl/olm:$(tag) \
+		-f Dockerfile \
+		--push

-local:
-	CGO_ENABLED=0 go build -o bin/olm
+docker-build-dev:
+	docker buildx build . \
+		--platform linux/arm/v7,linux/arm64,linux/amd64 \
+		-t fosrl/olm:latest \
+		-f Dockerfile

-build:
-	docker build -t fosrl/olm:latest .
+.PHONY: go-build-release \
+	go-build-release-linux-arm64 go-build-release-linux-arm32-v7 \
+	go-build-release-linux-arm32-v6 go-build-release-linux-amd64 \
+	go-build-release-linux-riscv64 go-build-release-darwin-arm64 \
+	go-build-release-darwin-amd64 go-build-release-windows-amd64

-go-build-release:
+go-build-release: \
+	go-build-release-linux-arm64 \
+	go-build-release-linux-arm32-v7 \
+	go-build-release-linux-arm32-v6 \
+	go-build-release-linux-amd64 \
+	go-build-release-linux-riscv64 \
+	go-build-release-darwin-arm64 \
+	go-build-release-darwin-amd64 \
+	go-build-release-windows-amd64

 go-build-release-linux-arm64:
 	CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/olm_linux_arm64

 go-build-release-linux-arm32-v7:
 	CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/olm_linux_arm32

 go-build-release-linux-arm32-v6:
 	CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/olm_linux_arm32v6

 go-build-release-linux-amd64:
 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/olm_linux_amd64

 go-build-release-linux-riscv64:
 	CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/olm_linux_riscv64

 go-build-release-darwin-arm64:
 	CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/olm_darwin_arm64

 go-build-release-darwin-amd64:
 	CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/olm_darwin_amd64

 go-build-release-windows-amd64:
 	CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/olm_windows_amd64.exe

 clean:
 	rm olm
```
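With the aggregate `go-build-release` target in place, a full local release build is two make invocations, mirroring what the CI workflow runs (the tag value below is illustrative):

```bash
# Cross-compile all release binaries into ./bin, then build and push the multi-arch images
make go-build-release tag=1.2.3
make docker-build-release tag=1.2.3
```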
503
README.md
503
README.md
@@ -6,7 +6,7 @@ Olm is a [WireGuard](https://www.wireguard.com/) tunnel client designed to secur
|
||||
|
||||
Olm is used with Pangolin and Newt as part of the larger system. See documentation below:
|
||||
|
||||
- [Full Documentation](https://docs.pangolin.net)
|
||||
- [Full Documentation](https://docs.pangolin.net/manage/clients/understanding-clients)
|
||||
|
||||
## Key Functions
|
||||
|
||||
@@ -18,513 +18,18 @@ Using the Olm ID and a secret, the olm will make HTTP requests to Pangolin to re
|
||||
|
||||
When Olm receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel on your computer to a remote Newt. It will ping over the tunnel to ensure the peer is brought up.
|
||||
|
||||
## CLI Args
|
||||
|
||||
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
|
||||
- `id`: Olm ID generated by Pangolin to identify the olm.
|
||||
- `secret`: A unique secret (not shared and kept private) used to authenticate the olm ID with the websocket in order to receive commands.
|
||||
- `org` (optional): Organization ID to connect to.
|
||||
- `user-token` (optional): User authentication token.
|
||||
- `mtu` (optional): MTU for the internal WG interface. Default: 1280
|
||||
- `dns` (optional): DNS server to use to resolve the endpoint. Default: 8.8.8.8
|
||||
- `upstream-dns` (optional): Upstream DNS server(s), comma-separated. Default: 8.8.8.8:53
|
||||
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO
|
||||
- `ping-interval` (optional): Interval for pinging the server. Default: 3s
|
||||
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
|
||||
- `interface` (optional): Name of the WireGuard interface. Default: olm
|
||||
- `enable-api` (optional): Enable API server for receiving connection requests. Default: false
|
||||
- `http-addr` (optional): HTTP server address (e.g., ':9452'). Default: :9452
|
||||
- `socket-path` (optional): Unix socket path (or named pipe on Windows). Default: /var/run/olm.sock (Linux/macOS) or olm (Windows)
|
||||
- `disable-holepunch` (optional): Disable hole punching. Default: false
|
||||
- `override-dns` (optional): Override system DNS settings. Default: false
|
||||
- `disable-relay` (optional): Disable relay connections. Default: false
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All CLI arguments can also be set via environment variables:
|
||||
|
||||
- `PANGOLIN_ENDPOINT`: Equivalent to `--endpoint`
|
||||
- `OLM_ID`: Equivalent to `--id`
|
||||
- `OLM_SECRET`: Equivalent to `--secret`
|
||||
- `ORG`: Equivalent to `--org`
|
||||
- `USER_TOKEN`: Equivalent to `--user-token`
|
||||
- `MTU`: Equivalent to `--mtu`
|
||||
- `DNS`: Equivalent to `--dns`
|
||||
- `UPSTREAM_DNS`: Equivalent to `--upstream-dns`
|
||||
- `LOG_LEVEL`: Equivalent to `--log-level`
|
||||
- `INTERFACE`: Equivalent to `--interface`
|
||||
- `ENABLE_API`: Set to "true" to enable API server (equivalent to `--enable-api`)
|
||||
- `HTTP_ADDR`: Equivalent to `--http-addr`
|
||||
- `SOCKET_PATH`: Equivalent to `--socket-path`
|
||||
- `PING_INTERVAL`: Equivalent to `--ping-interval`
|
||||
- `PING_TIMEOUT`: Equivalent to `--ping-timeout`
|
||||
- `DISABLE_HOLEPUNCH`: Set to "true" to disable hole punching (equivalent to `--disable-holepunch`)
|
||||
- `OVERRIDE_DNS`: Set to "true" to override system DNS settings (equivalent to `--override-dns`)
|
||||
- `DISABLE_RELAY`: Set to "true" to disable relay connections (equivalent to `--disable-relay`)
|
||||
- `CONFIG_FILE`: Set to the location of a JSON file to load secret values
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
olm \
|
||||
--id 31frd0uzbjvp721 \
|
||||
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||
--endpoint https://example.com
|
||||
```
|
||||
|
||||
You can also run it with Docker compose. For example, a service in your `docker-compose.yml` might look like this using environment vars (recommended):
|
||||
|
||||
```yaml
|
||||
services:
|
||||
olm:
|
||||
image: fosrl/olm
|
||||
container_name: olm
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
devices:
|
||||
- /dev/net/tun:/dev/net/tun
|
||||
environment:
|
||||
- PANGOLIN_ENDPOINT=https://example.com
|
||||
- OLM_ID=31frd0uzbjvp721
|
||||
- OLM_SECRET=h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
|
||||
```
|
||||
|
||||
You can also pass the CLI args to the container:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
olm:
|
||||
image: fosrl/olm
|
||||
container_name: olm
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
devices:
|
||||
- /dev/net/tun:/dev/net/tun
|
||||
command:
|
||||
- --id 31frd0uzbjvp721
|
||||
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
|
||||
- --endpoint https://example.com
|
||||
```
|
||||
|
||||
**Docker Configuration Notes:**
|
||||
|
||||
- `network_mode: host` brings the olm network interface to the host system, allowing the WireGuard tunnel to function properly
|
||||
- `devices: - /dev/net/tun:/dev/net/tun` is required to give the container access to the TUN device for creating WireGuard interfaces
|
||||
|
||||
## Loading secrets from files
|
||||
|
||||
You can use `CONFIG_FILE` to define a location of a config file to store the credentials between runs.
|
||||
|
||||
```
|
||||
$ cat ~/.config/olm-client/config.json
|
||||
{
|
||||
"id": "spmzu8rbpzj1qq6",
|
||||
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
|
||||
"endpoint": "https://app.pangolin.net",
|
||||
"org": "",
|
||||
"userToken": "",
|
||||
"mtu": 1280,
|
||||
"dns": "8.8.8.8",
|
||||
"upstreamDNS": ["8.8.8.8:53"],
|
||||
"interface": "olm",
|
||||
"logLevel": "INFO",
|
||||
"enableApi": false,
|
||||
"httpAddr": "",
|
||||
"socketPath": "/var/run/olm.sock",
|
||||
"pingInterval": "3s",
|
||||
"pingTimeout": "5s",
|
||||
"disableHolepunch": false,
|
||||
"overrideDNS": false,
|
||||
"disableRelay": false,
|
||||
"tlsClientCert": ""
|
||||
}
|
||||
```
|
||||
|
||||
This file is also written to when olm first starts up. So you do not need to run every time with --id and secret if you have run it once!
|
||||
|
||||
Default locations:
|
||||
|
||||
- **macOS**: `~/Library/Application Support/olm-client/config.json`
|
||||
- **Windows**: `%PROGRAMDATA%\olm\olm-client\config.json`
|
||||
- **Linux/Others**: `~/.config/olm-client/config.json`
|
||||
|
||||
## Hole Punching
|
||||
|
||||
In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to newt. If you want to disable hole punching, use the `--disable-holepunch` flag. Hole punching attempts to orchestrate a NAT hole punch between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil.
|
||||
|
||||
Right now, basic NAT hole punching is supported. We plan to add:
|
||||
|
||||
- [ ] Birthday paradox
|
||||
- [ ] UPnP
|
||||
- [ ] LAN detection
|
||||
|
||||
## Windows Service
|
||||
|
||||
On Windows, olm has to be installed and run as a Windows service. When running it with the cli args live above it will attempt to install and run the service to function like a cli tool. You can also run the following:
|
||||
|
||||
### Service Management Commands
|
||||
|
||||
```
|
||||
# Install the service
|
||||
olm.exe install
|
||||
|
||||
# Start the service
|
||||
olm.exe start
|
||||
|
||||
# Stop the service
|
||||
olm.exe stop
|
||||
|
||||
# Check service status
|
||||
olm.exe status
|
||||
|
||||
# Remove the service
|
||||
olm.exe remove
|
||||
|
||||
# Run in debug mode (console output) with our without id & secret
|
||||
olm.exe debug
|
||||
|
||||
# Show help
|
||||
olm.exe help
|
||||
```
|
||||
|
||||
Note running the service requires credentials in `%PROGRAMDATA%\olm\olm-client\config.json`.
|
||||
|
||||
### Service Configuration
|
||||
|
||||
When running as a service, Olm will read configuration from environment variables or you can modify the service to include command-line arguments:
|
||||
|
||||
1. Install the service: `olm.exe install`
|
||||
2. Set the credentials in `%PROGRAMDATA%\olm\olm-client\config.json`. Hint: if you run olm once with --id and --secret this file will be populated!
|
||||
3. Start the service: `olm.exe start`
|
||||
|
||||
### Service Logs
|
||||
|
||||
When running as a service, logs are written to:
|
||||
|
||||
- Windows Event Log (Application log, source: "OlmWireguardService")
|
||||
- Log files in: `%PROGRAMDATA%\olm\logs\olm.log`
|
||||
|
||||
You can view the Windows Event Log using Event Viewer or PowerShell:
|
||||
|
||||
```powershell
|
||||
Get-EventLog -LogName Application -Source "OlmWireguardService" -Newest 10
|
||||
```
|
||||
|
||||
## HTTP API
|
||||
|
||||
Olm can be controlled with an embedded HTTP server when using `--enable-http`. This allows you to start it as a daemon and trigger it with the following endpoints. The API can listen on either a TCP address or a Unix socket/Windows named pipe.
|
||||
|
||||
### Socket vs TCP
|
||||
|
||||
By default, when `--enable-http` is used, Olm listens on a TCP address (configured via `--http-addr`, default `:9452`). Alternatively, Olm can listen on a Unix socket (Linux/macOS) or Windows named pipe for local-only communication with better security.
|
||||
|
||||
**Unix Socket (Linux/macOS):**
|
||||
- Socket path example: `/var/run/olm/olm.sock`
|
||||
- The directory is created automatically if it doesn't exist
|
||||
- Socket permissions are set to `0666` to allow access
|
||||
- Existing socket files are automatically removed on startup
|
||||
- Socket file is cleaned up when Olm stops
|
||||
|
||||
**Windows Named Pipe:**
|
||||
- Pipe path example: `\\.\pipe\olm`
|
||||
- If the path doesn't start with `\`, it's automatically prefixed with `\\.\pipe\`
|
||||
- Security descriptor grants full access to Everyone and the current owner
|
||||
- Named pipes are automatically cleaned up by Windows
|
||||
|
||||
**Connecting to the Socket:**
|
||||
|
||||
```bash
|
||||
# Linux/macOS - using curl with Unix socket
|
||||
curl --unix-socket /var/run/olm/olm.sock http://localhost/status
|
||||
|
||||
---
|
||||

### POST /connect

Initiates a new connection request to a Pangolin server.

**Request Body:**

```json
{
  "id": "string",
  "secret": "string",
  "endpoint": "string",
  "userToken": "string",
  "mtu": 1280,
  "dns": "8.8.8.8",
  "dnsProxyIP": "string",
  "upstreamDNS": ["8.8.8.8:53", "1.1.1.1:53"],
  "interfaceName": "olm",
  "holepunch": false,
  "tlsClientCert": "string",
  "pingInterval": "3s",
  "pingTimeout": "5s",
  "orgId": "string"
}
```

**Required Fields:**
- `id`: Olm ID generated by Pangolin
- `secret`: Authentication secret for the Olm ID
- `endpoint`: Target Pangolin endpoint URL

**Optional Fields:**
- `userToken`: User authentication token
- `mtu`: MTU for the internal WireGuard interface (default: 1280)
- `dns`: DNS server to use for resolving the endpoint
- `dnsProxyIP`: DNS proxy IP address
- `upstreamDNS`: Array of upstream DNS servers
- `interfaceName`: Name of the WireGuard interface (default: olm)
- `holepunch`: Enable NAT hole punching (default: false)
- `tlsClientCert`: TLS client certificate
- `pingInterval`: Interval for pinging the server (default: 3s)
- `pingTimeout`: Timeout for each ping (default: 5s)
- `orgId`: Organization ID to connect to

**Response:**
- **Status Code:** `202 Accepted`
- **Content-Type:** `application/json`

```json
{
  "status": "connection request accepted"
}
```

**Error Responses:**
- `405 Method Not Allowed` - Non-POST requests
- `400 Bad Request` - Invalid JSON or missing required fields
- `409 Conflict` - Already connected to a server (disconnect first)
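
Because `/connect` returns `409 Conflict` while a session is already active, a reconnect script should disconnect first. A sketch using only the documented endpoints, assuming the default TCP address:

```bash
# Request a connection; if Olm is already connected (409), disconnect and retry.
body='{"id": "31frd0uzbjvp721", "secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6", "endpoint": "https://example.com"}'
code=$(curl -s -o /dev/null -w '%{http_code}' -X POST http://localhost:9452/connect \
  -H "Content-Type: application/json" -d "$body")
if [ "$code" = "409" ]; then
  curl -s -X POST http://localhost:9452/disconnect
  sleep 1
  curl -s -X POST http://localhost:9452/connect \
    -H "Content-Type: application/json" -d "$body"
fi
```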

---

### GET /status

Returns the current connection status, registration state, and peer information.

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "connected": true,
  "registered": true,
  "terminated": false,
  "version": "1.0.0",
  "agent": "olm",
  "orgId": "org_123",
  "peers": {
    "10": {
      "siteId": 10,
      "name": "Site A",
      "connected": true,
      "rtt": 145338339,
      "lastSeen": "2025-08-13T14:39:17.208334428-07:00",
      "endpoint": "p.fosrl.io:21820",
      "isRelay": true,
      "peerAddress": "100.89.128.5",
      "holepunchConnected": false
    },
    "8": {
      "siteId": 8,
      "name": "Site B",
      "connected": false,
      "rtt": 0,
      "lastSeen": "2025-08-13T14:39:19.663823645-07:00",
      "endpoint": "p.fosrl.io:21820",
      "isRelay": true,
      "peerAddress": "100.89.128.10",
      "holepunchConnected": false
    }
  },
  "networkSettings": {
    "tunnelIP": "100.89.128.3/20"
  }
}
```

**Fields:**
- `connected`: Boolean indicating if connected to Pangolin
- `registered`: Boolean indicating if registered with the server
- `terminated`: Boolean indicating if the connection was terminated
- `version`: Olm version string
- `agent`: Agent identifier
- `orgId`: Current organization ID
- `peers`: Map of peer statuses by site ID
  - `siteId`: Peer site identifier
  - `name`: Site name
  - `connected`: Boolean peer connection state
  - `rtt`: Peer round-trip time (integer, nanoseconds)
  - `lastSeen`: Last time the peer was seen (RFC3339 timestamp)
  - `endpoint`: Peer endpoint address
  - `isRelay`: Whether the peer is relayed (true) or direct (false)
  - `peerAddress`: Peer's IP address in the tunnel
  - `holepunchConnected`: Whether a holepunch connection is established
- `networkSettings`: Current network configuration, including the tunnel IP
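
One convenient way to watch peer health from a shell is to pull these fields out of `/status`. A sketch, assuming `jq` is installed and the API is on the default TCP address (`rtt` is reported in nanoseconds, so divide by 1e6 for milliseconds):

```bash
curl -s http://localhost:9452/status | jq -r '
  .peers[] | "\(.name)\tconnected=\(.connected)\trtt_ms=\(.rtt / 1000000)"'
```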

**Error Responses:**
- `405 Method Not Allowed` - Non-GET requests

---

### POST /disconnect

Disconnects from the current Pangolin server and tears down the WireGuard tunnel.

**Request Body:** None required

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "status": "disconnect initiated"
}
```

**Error Responses:**
- `405 Method Not Allowed` - Non-POST requests
- `409 Conflict` - Not currently connected to a server

---

### POST /switch-org

Switches to a different organization while maintaining the connection.

**Request Body:**

```json
{
  "orgId": "string"
}
```

**Required Fields:**
- `orgId`: The organization ID to switch to

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "status": "org switch request accepted"
}
```

**Error Responses:**
- `405 Method Not Allowed` - Non-POST requests
- `400 Bad Request` - Invalid JSON or missing orgId field
- `500 Internal Server Error` - Org switch failed

---

### POST /exit

Initiates a graceful shutdown of the Olm process.

**Request Body:** None required

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "status": "shutdown initiated"
}
```

**Note:** The response is sent before shutdown begins. There is a 100ms delay before the actual shutdown to ensure the response is delivered.
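
Because the process exits shortly after replying, a script that needs to wait for shutdown can poll `/health` until it stops answering. A small sketch, assuming the default TCP address:

```bash
curl -s -X POST http://localhost:9452/exit
# Poll until the API stops responding, i.e. Olm has actually exited.
while curl -sf -m 1 http://localhost:9452/health > /dev/null; do
  sleep 0.5
done
echo "olm has shut down"
```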

**Error Responses:**
- `405 Method Not Allowed` - Non-POST requests

---

### GET /health

Simple health check endpoint to verify the API server is running.

**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`

```json
{
  "status": "ok"
}
```

**Error Responses:**
- `405 Method Not Allowed` - Non-GET requests

---

## Usage Examples

### Connect to a peer

```bash
curl -X POST http://localhost:9452/connect \
  -H "Content-Type: application/json" \
  -d '{
    "id": "31frd0uzbjvp721",
    "secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
    "endpoint": "https://example.com"
  }'
```

### Connect with additional options

```bash
curl -X POST http://localhost:9452/connect \
  -H "Content-Type: application/json" \
  -d '{
    "id": "31frd0uzbjvp721",
    "secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
    "endpoint": "https://example.com",
    "mtu": 1400,
    "holepunch": true,
    "pingInterval": "5s"
  }'
```

### Check connection status

```bash
curl http://localhost:9452/status
```

### Switch organization

```bash
curl -X POST http://localhost:9452/switch-org \
  -H "Content-Type: application/json" \
  -d '{"orgId": "org_456"}'
```

### Disconnect from server

```bash
curl -X POST http://localhost:9452/disconnect
```

### Health check

```bash
curl http://localhost:9452/health
```

### Shutdown Olm

```bash
curl -X POST http://localhost:9452/exit
```

### Using Unix socket (Linux/macOS)

```bash
curl --unix-socket /var/run/olm/olm.sock http://localhost/status
curl --unix-socket /var/run/olm/olm.sock -X POST http://localhost/disconnect
```

In the default mode, Olm uses both relaying through Gerbil and NAT hole punching to connect to Newt. Hole punching attempts to orchestrate a NAT traversal between the two sites so that traffic flows directly, which can reduce data costs and improve speed. If hole punching fails, traffic falls back to relaying through Gerbil.
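
To see whether a given peer is currently direct or relayed, read the `isRelay` flag from the `/status` endpoint. A sketch, assuming `jq` and the default TCP address:

```bash
curl -s http://localhost:9452/status | jq -r '
  .peers[] | "\(.name): \(if .isRelay then "relayed via Gerbil" else "direct (hole punched)" end)"'
```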

## Build

### Binary

Make sure to have Go 1.25 installed.

```bash
make
```

## Licensing

218
api/api.go
@@ -5,6 +5,7 @@ import (
    "fmt"
    "net"
    "net/http"
    "strconv"
    "sync"
    "time"

@@ -32,7 +33,12 @@ type ConnectionRequest struct {

// SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct {
    OrgID string `json:"orgId"`
    OrgID string `json:"org_id"`
}

// PowerModeRequest represents a request to change power mode
type PowerModeRequest struct {
    Mode string `json:"mode"` // "normal" or "low"
}

// PeerStatus represents the status of a peer connection
@@ -48,11 +54,18 @@ type PeerStatus struct {
    HolepunchConnected bool `json:"holepunchConnected"`
}

// OlmError holds error information from registration failures
type OlmError struct {
    Code    string `json:"code"`
    Message string `json:"message"`
}

// StatusResponse is returned by the status endpoint
type StatusResponse struct {
    Connected  bool `json:"connected"`
    Registered bool `json:"registered"`
    Terminated bool `json:"terminated"`
    OlmError   *OlmError `json:"error,omitempty"`
    Version    string `json:"version,omitempty"`
    Agent      string `json:"agent,omitempty"`
    OrgID      string `json:"orgId,omitempty"`
@@ -60,25 +73,37 @@ type StatusResponse struct {
    NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
}

type MetadataChangeRequest struct {
    Fingerprint map[string]any `json:"fingerprint"`
    Postures    map[string]any `json:"postures"`
}

// API represents the HTTP server and its state
type API struct {
    addr         string
    socketPath   string
    listener     net.Listener
    server       *http.Server
    onConnect    func(ConnectionRequest) error
    onSwitchOrg  func(SwitchOrgRequest) error
    onDisconnect func() error
    onExit       func() error
    addr       string
    socketPath string
    listener   net.Listener
    server     *http.Server

    onConnect        func(ConnectionRequest) error
    onSwitchOrg      func(SwitchOrgRequest) error
    onMetadataChange func(MetadataChangeRequest) error
    onDisconnect     func() error
    onExit           func() error
    onRebind         func() error
    onPowerMode      func(PowerModeRequest) error

    statusMu     sync.RWMutex
    peerStatuses map[int]*PeerStatus
    connectedAt  time.Time
    isConnected  bool
    isRegistered bool
    isTerminated bool
    version      string
    agent        string
    orgID        string
    olmError     *OlmError

    version string
    agent   string
    orgID   string
}

// NewAPI creates a new HTTP server that listens on a TCP address
@@ -101,28 +126,49 @@ func NewAPISocket(socketPath string) *API {
    return s
}

func NewAPIStub() *API {
    s := &API{
        peerStatuses: make(map[int]*PeerStatus),
    }

    return s
}

// SetHandlers sets the callback functions for handling API requests
func (s *API) SetHandlers(
    onConnect func(ConnectionRequest) error,
    onSwitchOrg func(SwitchOrgRequest) error,
    onMetadataChange func(MetadataChangeRequest) error,
    onDisconnect func() error,
    onExit func() error,
    onRebind func() error,
    onPowerMode func(PowerModeRequest) error,
) {
    s.onConnect = onConnect
    s.onSwitchOrg = onSwitchOrg
    s.onMetadataChange = onMetadataChange
    s.onDisconnect = onDisconnect
    s.onExit = onExit
    s.onRebind = onRebind
    s.onPowerMode = onPowerMode
}

// Start starts the HTTP server
func (s *API) Start() error {
    if s.socketPath == "" && s.addr == "" {
        return fmt.Errorf("either socketPath or addr must be provided to start the API server")
    }

    mux := http.NewServeMux()
    mux.HandleFunc("/connect", s.handleConnect)
    mux.HandleFunc("/status", s.handleStatus)
    mux.HandleFunc("/switch-org", s.handleSwitchOrg)
    mux.HandleFunc("/metadata", s.handleMetadataChange)
    mux.HandleFunc("/disconnect", s.handleDisconnect)
    mux.HandleFunc("/exit", s.handleExit)
    mux.HandleFunc("/health", s.handleHealth)
    mux.HandleFunc("/rebind", s.handleRebind)
    mux.HandleFunc("/power-mode", s.handlePowerMode)

    s.server = &http.Server{
        Handler: mux,
@@ -160,7 +206,7 @@ func (s *API) Stop() error {

    // Close the server first, which will also close the listener gracefully
    if s.server != nil {
        s.server.Close()
        _ = s.server.Close()
    }

    // Clean up socket file if using Unix socket
@@ -226,9 +272,6 @@ func (s *API) SetConnectionStatus(isConnected bool) {

    if isConnected {
        s.connectedAt = time.Now()
    } else {
        // Clear peer statuses when disconnected
        s.peerStatuses = make(map[int]*PeerStatus)
    }
}

@@ -236,6 +279,27 @@ func (s *API) SetRegistered(registered bool) {
    s.statusMu.Lock()
    defer s.statusMu.Unlock()
    s.isRegistered = registered
    // Clear any registration error when successfully registered
    if registered {
        s.olmError = nil
    }
}

// SetOlmError sets the registration error
func (s *API) SetOlmError(code string, message string) {
    s.statusMu.Lock()
    defer s.statusMu.Unlock()
    s.olmError = &OlmError{
        Code:    code,
        Message: message,
    }
}

// ClearOlmError clears any registration error
func (s *API) ClearOlmError() {
    s.statusMu.Lock()
    defer s.statusMu.Unlock()
    s.olmError = nil
}

func (s *API) SetTerminated(terminated bool) {
@@ -345,7 +409,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
    // Return a success response
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusAccepted)
    json.NewEncoder(w).Encode(map[string]string{
    _ = json.NewEncoder(w).Encode(map[string]string{
        "status": "connection request accepted",
    })
}
@@ -358,12 +422,12 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
    }

    s.statusMu.RLock()
    defer s.statusMu.RUnlock()

    resp := StatusResponse{
        Connected:  s.isConnected,
        Registered: s.isRegistered,
        Terminated: s.isTerminated,
        OlmError:   s.olmError,
        Version:    s.version,
        Agent:      s.agent,
        OrgID:      s.orgID,
@@ -371,8 +435,18 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
        NetworkSettings: network.GetSettings(),
    }

    s.statusMu.RUnlock()

    data, err := json.Marshal(resp)
    if err != nil {
        http.Error(w, err.Error(), http.StatusInternalServerError)
        return
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(resp)
    w.Header().Set("Content-Length", strconv.Itoa(len(data)))
    w.WriteHeader(http.StatusOK)
    _, _ = w.Write(data)
}

// handleHealth handles the /health endpoint
@@ -384,7 +458,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {

    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    json.NewEncoder(w).Encode(map[string]string{
    _ = json.NewEncoder(w).Encode(map[string]string{
        "status": "ok",
    })
}
@@ -401,7 +475,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
    // Return a success response first
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    json.NewEncoder(w).Encode(map[string]string{
    _ = json.NewEncoder(w).Encode(map[string]string{
        "status": "shutdown initiated",
    })

@@ -450,7 +524,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
    // Return a success response
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    json.NewEncoder(w).Encode(map[string]string{
    _ = json.NewEncoder(w).Encode(map[string]string{
        "status": "org switch request accepted",
    })
}
@@ -484,16 +558,43 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
    // Return a success response
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    json.NewEncoder(w).Encode(map[string]string{
    _ = json.NewEncoder(w).Encode(map[string]string{
        "status": "disconnect initiated",
    })
}

// handleMetadataChange handles the /metadata endpoint
func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPut {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
        return
    }

    var req MetadataChangeRequest
    decoder := json.NewDecoder(r.Body)
    if err := decoder.Decode(&req); err != nil {
        http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
        return
    }

    logger.Info("Received metadata change request via API: %v", req)

    _ = s.onMetadataChange(req)

    // Return a success response
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    _ = json.NewEncoder(w).Encode(map[string]string{
        "status": "metadata updated",
    })
}

func (s *API) GetStatus() StatusResponse {
    return StatusResponse{
        Connected:  s.isConnected,
        Registered: s.isRegistered,
        Terminated: s.isTerminated,
        OlmError:   s.olmError,
        Version:    s.version,
        Agent:      s.agent,
        OrgID:      s.orgID,
@@ -501,3 +602,74 @@ func (s *API) GetStatus() StatusResponse {
        NetworkSettings: network.GetSettings(),
    }
}

// handleRebind handles the /rebind endpoint
// This triggers a socket rebind, which is necessary when network connectivity changes
// (e.g., WiFi to cellular transition on macOS/iOS) and the old socket becomes stale.
func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPost {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
        return
    }

    logger.Info("Received rebind request via API")

    // Call the rebind handler if set
    if s.onRebind != nil {
        if err := s.onRebind(); err != nil {
            http.Error(w, fmt.Sprintf("Rebind failed: %v", err), http.StatusInternalServerError)
            return
        }
    } else {
        http.Error(w, "Rebind handler not configured", http.StatusNotImplemented)
        return
    }

    // Return a success response
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    _ = json.NewEncoder(w).Encode(map[string]string{
        "status": "socket rebound successfully",
    })
}

// handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPost {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
        return
    }

    var req PowerModeRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
        return
    }

    // Validate power mode
    if req.Mode != "normal" && req.Mode != "low" {
        http.Error(w, "Invalid power mode: must be 'normal' or 'low'", http.StatusBadRequest)
        return
    }

    logger.Info("Received power mode change request via API: mode=%s", req.Mode)

    // Call the power mode handler if set
    if s.onPowerMode != nil {
        if err := s.onPowerMode(req); err != nil {
            http.Error(w, fmt.Sprintf("Power mode change failed: %v", err), http.StatusInternalServerError)
            return
        }
    } else {
        http.Error(w, "Power mode handler not configured", http.StatusNotImplemented)
        return
    }

    // Return a success response
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    _ = json.NewEncoder(w).Encode(map[string]string{
        "status": fmt.Sprintf("power mode changed to %s successfully", req.Mode),
    })
}

16
config.go
@@ -43,6 +43,7 @@ type OlmConfig struct {
    DisableHolepunch bool   `json:"disableHolepunch"`
    TlsClientCert    string `json:"tlsClientCert"`
    OverrideDNS      bool   `json:"overrideDNS"`
    TunnelDNS        bool   `json:"tunnelDNS"`
    DisableRelay     bool   `json:"disableRelay"`
    // DoNotCreateNewClient bool `json:"doNotCreateNewClient"`

@@ -88,6 +89,8 @@ func DefaultConfig() *OlmConfig {
    PingInterval:     "3s",
    PingTimeout:      "5s",
    DisableHolepunch: false,
    OverrideDNS:      true,
    TunnelDNS:        false,
    // DoNotCreateNewClient: false,
    sources: make(map[string]string),
}
@@ -105,6 +108,7 @@ func DefaultConfig() *OlmConfig {
    config.sources["pingTimeout"] = string(SourceDefault)
    config.sources["disableHolepunch"] = string(SourceDefault)
    config.sources["overrideDNS"] = string(SourceDefault)
    config.sources["tunnelDNS"] = string(SourceDefault)
    config.sources["disableRelay"] = string(SourceDefault)
    // config.sources["doNotCreateNewClient"] = string(SourceDefault)

@@ -265,6 +269,10 @@ func loadConfigFromEnv(config *OlmConfig) {
        config.DisableRelay = true
        config.sources["disableRelay"] = string(SourceEnv)
    }
    if val := os.Getenv("TUNNEL_DNS"); val == "true" {
        config.TunnelDNS = true
        config.sources["tunnelDNS"] = string(SourceEnv)
    }
    // if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
    //     config.DoNotCreateNewClient = true
    //     config.sources["doNotCreateNewClient"] = string(SourceEnv)
@@ -295,6 +303,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
    "disableHolepunch": config.DisableHolepunch,
    "overrideDNS":      config.OverrideDNS,
    "disableRelay":     config.DisableRelay,
    "tunnelDNS":        config.TunnelDNS,
    // "doNotCreateNewClient": config.DoNotCreateNewClient,
}

@@ -316,8 +325,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
    serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
    serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
    serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
    serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
    serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "When enabled, the client uses custom DNS servers to resolve internal resources and aliases. This overrides your system's default DNS settings. Queries that cannot be resolved as a Pangolin resource will be forwarded to your configured Upstream DNS Server. (default false)")
    serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
    serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "When enabled, DNS queries are routed through the tunnel for remote resolution. To ensure queries are tunneled correctly, you must define the DNS server as a Pangolin resource and enter its address as an Upstream DNS Server. (default false)")
    // serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")

    version := serviceFlags.Bool("version", false, "Print the version")
@@ -393,6 +403,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
    if config.DisableRelay != origValues["disableRelay"].(bool) {
        config.sources["disableRelay"] = string(SourceCLI)
    }
    if config.TunnelDNS != origValues["tunnelDNS"].(bool) {
        config.sources["tunnelDNS"] = string(SourceCLI)
    }
    // if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
    //     config.sources["doNotCreateNewClient"] = string(SourceCLI)
    // }
@@ -606,6 +619,7 @@ func (c *OlmConfig) ShowConfig() {
    fmt.Println("\nAdvanced:")
    fmt.Printf(" disable-holepunch = %v [%s]\n", c.DisableHolepunch, getSource("disableHolepunch"))
    fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS"))
    fmt.Printf(" tunnel-dns = %v [%s]\n", c.TunnelDNS, getSource("tunnelDNS"))
    fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay"))
    // fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
    if c.TlsClientCert != "" {

@@ -1,9 +1,12 @@
package device

import (
    "io"
    "net/netip"
    "os"
    "sync"
    "sync/atomic"
    "time"

    "github.com/fosrl/newt/logger"
    "golang.zx2c4.com/wireguard/tun"
@@ -18,14 +21,68 @@ type FilterRule struct {
    Handler PacketHandler
}

// MiddleDevice wraps a TUN device with packet filtering capabilities
type MiddleDevice struct {
// closeAwareDevice wraps a tun.Device along with a flag
// indicating whether its Close method was called.
type closeAwareDevice struct {
    isClosed atomic.Bool
    tun.Device
    rules        []FilterRule
    mutex        sync.RWMutex
    readCh       chan readResult
    injectCh     chan []byte
    closed       chan struct{}
    closeEventCh chan struct{}
    wg           sync.WaitGroup
    closeOnce    sync.Once
}

func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice {
    return &closeAwareDevice{
        Device:       tunDevice,
        isClosed:     atomic.Bool{},
        closeEventCh: make(chan struct{}),
    }
}

// redirectEvents redirects the Events() method of the underlying tun.Device
// to the given channel.
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
    c.wg.Add(1)
    go func() {
        defer c.wg.Done()
        for {
            select {
            case ev, ok := <-c.Device.Events():
                if !ok {
                    return
                }

                if ev == tun.EventDown {
                    continue
                }

                select {
                case out <- ev:
                case <-c.closeEventCh:
                    return
                }
            case <-c.closeEventCh:
                return
            }
        }
    }()
}

// Close calls the underlying Device's Close method
// after setting isClosed to true.
func (c *closeAwareDevice) Close() (err error) {
    c.closeOnce.Do(func() {
        c.isClosed.Store(true)
        close(c.closeEventCh)
        err = c.Device.Close()
        c.wg.Wait()
    })

    return err
}

func (c *closeAwareDevice) IsClosed() bool {
    return c.isClosed.Load()
}

type readResult struct {
@@ -36,58 +93,136 @@ type readResult struct {
    err error
}

// MiddleDevice wraps a TUN device with packet filtering capabilities
// and supports swapping the underlying device.
type MiddleDevice struct {
    devices    []*closeAwareDevice
    mu         sync.Mutex
    cond       *sync.Cond
    rules      []FilterRule
    rulesMutex sync.RWMutex
    readCh     chan readResult
    injectCh   chan []byte
    closed     atomic.Bool
    events     chan tun.Event
}

// NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice {
    d := &MiddleDevice{
        Device:   device,
        devices:  make([]*closeAwareDevice, 0),
        rules:    make([]FilterRule, 0),
        readCh:   make(chan readResult),
        readCh:   make(chan readResult, 16),
        injectCh: make(chan []byte, 100),
        closed:   make(chan struct{}),
        events:   make(chan tun.Event, 16),
    }
    go d.pump()
    d.cond = sync.NewCond(&d.mu)

    if device != nil {
        d.AddDevice(device)
    }

    return d
}

func (d *MiddleDevice) pump() {
// AddDevice adds a new underlying TUN device, closing any previous one
func (d *MiddleDevice) AddDevice(device tun.Device) {
    d.mu.Lock()
    if d.closed.Load() {
        d.mu.Unlock()
        _ = device.Close()
        return
    }

    var toClose *closeAwareDevice
    if len(d.devices) > 0 {
        toClose = d.devices[len(d.devices)-1]
    }

    cad := newCloseAwareDevice(device)
    cad.redirectEvents(d.events)

    d.devices = []*closeAwareDevice{cad}

    // Start pump for the new device
    go d.pump(cad)

    d.cond.Broadcast()
    d.mu.Unlock()

    if toClose != nil {
        logger.Debug("MiddleDevice: Closing previous device")
        if err := toClose.Close(); err != nil {
            logger.Debug("MiddleDevice: Error closing previous device: %v", err)
        }
    }
}

func (d *MiddleDevice) pump(dev *closeAwareDevice) {
    const defaultOffset = 16
    batchSize := d.Device.BatchSize()
    logger.Debug("MiddleDevice: pump started")
    batchSize := dev.BatchSize()
    logger.Debug("MiddleDevice: pump started for device")

    // Recover from panic if readCh is closed while we're trying to send
    defer func() {
        if r := recover(); r != nil {
            logger.Debug("MiddleDevice: pump recovered from panic (channel closed)")
        }
    }()

    for {
        // Check closed first with priority
        select {
        case <-d.closed:
            logger.Debug("MiddleDevice: pump exiting due to closed channel")
        // Check if this device is closed
        if dev.IsClosed() {
            logger.Debug("MiddleDevice: pump exiting, device is closed")
            return
        }

        // Check if MiddleDevice itself is closed
        if d.closed.Load() {
            logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed")
            return
        default:
        }

        // Allocate buffers for reading
        // We allocate new buffers for each read to avoid race conditions
        // since we pass them to the channel
        bufs := make([][]byte, batchSize)
        sizes := make([]int, batchSize)
        for i := range bufs {
            bufs[i] = make([]byte, 2048) // Standard MTU + headroom
        }

        n, err := d.Device.Read(bufs, sizes, defaultOffset)
        n, err := dev.Read(bufs, sizes, defaultOffset)

        // Check closed again after read returns
        select {
        case <-d.closed:
            logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
        // Check if device was closed during read
        if dev.IsClosed() {
            logger.Debug("MiddleDevice: pump exiting, device closed during read")
            return
        }

        // Check if MiddleDevice was closed during read
        if d.closed.Load() {
            logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read")
            return
        }

        // Try to send the result - check closed state first to avoid sending on closed channel
        if d.closed.Load() {
            logger.Debug("MiddleDevice: pump exiting, device closed before send")
            return
        default:
        }

        // Now try to send the result
        select {
        case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
        case <-d.closed:
            logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
            return
        default:
            // Channel full, check if we should exit
            if dev.IsClosed() || d.closed.Load() {
                return
            }
            // Try again with blocking
            select {
            case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
            case <-dev.closeEventCh:
                return
            }
        }

        if err != nil {
@@ -99,16 +234,28 @@ func (d *MiddleDevice) pump() {

// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
func (d *MiddleDevice) InjectOutbound(packet []byte) {
    if d.closed.Load() {
        return
    }
    // Use defer/recover to handle panic from sending on closed channel
    // This can happen during shutdown race conditions
    defer func() {
        if r := recover(); r != nil {
            logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)")
        }
    }()
    select {
    case d.injectCh <- packet:
    case <-d.closed:
    default:
        // Channel full, drop packet
        logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full")
    }
}

// AddRule adds a packet filtering rule
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
    d.mutex.Lock()
    defer d.mutex.Unlock()
    d.rulesMutex.Lock()
    defer d.rulesMutex.Unlock()
    d.rules = append(d.rules, FilterRule{
        DestIP:  destIP,
        Handler: handler,
@@ -117,8 +264,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {

// RemoveRule removes all rules for a given destination IP
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
    d.mutex.Lock()
    defer d.mutex.Unlock()
    d.rulesMutex.Lock()
    defer d.rulesMutex.Unlock()
    newRules := make([]FilterRule, 0, len(d.rules))
    for _, rule := range d.rules {
        if rule.DestIP != destIP {
@@ -130,18 +277,120 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {

// Close stops the device
func (d *MiddleDevice) Close() error {
    select {
    case <-d.closed:
        // Already closed
        return nil
    default:
        logger.Debug("MiddleDevice: Closing, signaling closed channel")
        close(d.closed)
    if !d.closed.CompareAndSwap(false, true) {
        return nil // already closed
    }
    logger.Debug("MiddleDevice: Closing underlying TUN device")
    err := d.Device.Close()
    logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
    return err

    d.mu.Lock()
    devices := d.devices
    d.devices = nil
    d.cond.Broadcast()
    d.mu.Unlock()

    // Close underlying devices first - this causes the pump goroutines to exit
    // when their read operations return errors
    var lastErr error
    logger.Debug("MiddleDevice: Closing %d devices", len(devices))
    for _, device := range devices {
        if err := device.Close(); err != nil {
            logger.Debug("MiddleDevice: Error closing device: %v", err)
            lastErr = err
        }
    }

    // Now close channels to unblock any remaining readers
    // The pump should have exited by now, but close channels to be safe
    close(d.readCh)
    close(d.injectCh)
    close(d.events)

    return lastErr
}

// Events returns the events channel
func (d *MiddleDevice) Events() <-chan tun.Event {
    return d.events
}

// File returns the underlying file descriptor
func (d *MiddleDevice) File() *os.File {
    for {
        dev := d.peekLast()
        if dev == nil {
            if !d.waitForDevice() {
                return nil
            }
            continue
        }

        file := dev.File()

        if dev.IsClosed() {
            time.Sleep(1 * time.Millisecond)
            continue
        }

        return file
    }
}

// MTU returns the MTU of the underlying device
func (d *MiddleDevice) MTU() (int, error) {
    for {
        dev := d.peekLast()
        if dev == nil {
            if !d.waitForDevice() {
                return 0, io.EOF
            }
            continue
        }

        mtu, err := dev.MTU()
        if err == nil {
            return mtu, nil
        }

        if dev.IsClosed() {
            time.Sleep(1 * time.Millisecond)
            continue
        }

        return 0, err
    }
}

// Name returns the name of the underlying device
func (d *MiddleDevice) Name() (string, error) {
    for {
        dev := d.peekLast()
        if dev == nil {
            if !d.waitForDevice() {
                return "", io.EOF
            }
            continue
        }

        name, err := dev.Name()
        if err == nil {
            return name, nil
        }

        if dev.IsClosed() {
            time.Sleep(1 * time.Millisecond)
            continue
        }

        return "", err
    }
}

// BatchSize returns the batch size
func (d *MiddleDevice) BatchSize() int {
    dev := d.peekLast()
    if dev == nil {
        return 1
    }
    return dev.BatchSize()
}

// extractDestIP extracts destination IP from packet (fast path)
@@ -176,156 +425,239 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {

// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
    // Check if already closed first (non-blocking)
    select {
    case <-d.closed:
        logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
        return 0, os.ErrClosed
    default:
    }

    // Now block waiting for data
    select {
    case res := <-d.readCh:
        if res.err != nil {
            logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
            return 0, res.err
    for {
        if d.closed.Load() {
            logger.Debug("MiddleDevice: Read returning io.EOF, device closed")
            return 0, io.EOF
        }

        // Copy packets from result to provided buffers
        count := 0
        for i := 0; i < res.n && i < len(bufs); i++ {
            // Handle offset mismatch if necessary
            // We assume the pump used defaultOffset (16)
            // If caller asks for different offset, we need to shift
            src := res.bufs[i]
            srcOffset := res.offset
            srcSize := res.sizes[i]

            // Calculate where the packet data starts and ends in src
            pktData := src[srcOffset : srcOffset+srcSize]

            // Ensure dest buffer is large enough
            if len(bufs[i]) < offset+len(pktData) {
                continue // Skip if buffer too small
        // Wait for a device to be available
        dev := d.peekLast()
        if dev == nil {
            if !d.waitForDevice() {
                return 0, io.EOF
            }

            copy(bufs[i][offset:], pktData)
            sizes[i] = len(pktData)
            count++
        }
        n = count

    case pkt := <-d.injectCh:
        if len(bufs) == 0 {
            return 0, nil
        }
        if len(bufs[0]) < offset+len(pkt) {
            return 0, nil // Buffer too small
        }
        copy(bufs[0][offset:], pkt)
        sizes[0] = len(pkt)
        n = 1

    case <-d.closed:
        logger.Debug("MiddleDevice: Read returning os.ErrClosed")
        return 0, os.ErrClosed // Signal that device is closed
    }

    d.mutex.RLock()
    rules := d.rules
    d.mutex.RUnlock()

    if len(rules) == 0 {
        return n, nil
    }

    // Process packets and filter out handled ones
    writeIdx := 0
    for readIdx := 0; readIdx < n; readIdx++ {
        packet := bufs[readIdx][offset : offset+sizes[readIdx]]

        destIP, ok := extractDestIP(packet)
        if !ok {
            // Can't parse, keep packet
            if writeIdx != readIdx {
                bufs[writeIdx] = bufs[readIdx]
                sizes[writeIdx] = sizes[readIdx]
            }
            writeIdx++
            continue
        }

        // Check if packet matches any rule
        handled := false
        for _, rule := range rules {
            if rule.DestIP == destIP {
                if rule.Handler(packet) {
                    // Packet was handled and should be dropped
                    handled = true
                    break
        // Now block waiting for data from readCh or injectCh
        select {
        case res, ok := <-d.readCh:
            if !ok {
                // Channel closed, device is shutting down
                return 0, io.EOF
            }
            if res.err != nil {
                // Check if device was swapped
                if dev.IsClosed() {
                    time.Sleep(1 * time.Millisecond)
                    continue
                }
                logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
                return 0, res.err
            }

            // Copy packets from result to provided buffers
            count := 0
            for i := 0; i < res.n && i < len(bufs); i++ {
                src := res.bufs[i]
                srcOffset := res.offset
                srcSize := res.sizes[i]

                pktData := src[srcOffset : srcOffset+srcSize]

                if len(bufs[i]) < offset+len(pktData) {
                    continue
                }

                copy(bufs[i][offset:], pktData)
                sizes[i] = len(pktData)
                count++
            }
            n = count

        case pkt, ok := <-d.injectCh:
            if !ok {
                // Channel closed, device is shutting down
                return 0, io.EOF
            }
            if len(bufs) == 0 {
                return 0, nil
            }
            if len(bufs[0]) < offset+len(pkt) {
                return 0, nil
            }
            copy(bufs[0][offset:], pkt)
            sizes[0] = len(pkt)
            n = 1
        }

        // Apply filtering rules
        d.rulesMutex.RLock()
        rules := d.rules
        d.rulesMutex.RUnlock()

        if len(rules) == 0 {
            return n, nil
        }

        // Process packets and filter out handled ones
        writeIdx := 0
        for readIdx := 0; readIdx < n; readIdx++ {
            packet := bufs[readIdx][offset : offset+sizes[readIdx]]

            destIP, ok := extractDestIP(packet)
            if !ok {
                if writeIdx != readIdx {
                    bufs[writeIdx] = bufs[readIdx]
                    sizes[writeIdx] = sizes[readIdx]
                }
                writeIdx++
                continue
            }

            handled := false
            for _, rule := range rules {
                if rule.DestIP == destIP {
                    if rule.Handler(packet) {
                        handled = true
                        break
                    }
                }
            }
        }

        if !handled {
            // Keep packet
            if writeIdx != readIdx {
                bufs[writeIdx] = bufs[readIdx]
                sizes[writeIdx] = sizes[readIdx]
            if !handled {
                if writeIdx != readIdx {
                    bufs[writeIdx] = bufs[readIdx]
                    sizes[writeIdx] = sizes[readIdx]
                }
                writeIdx++
            }
            writeIdx++
        }
    }

    return writeIdx, err
        return writeIdx, nil
    }
}

// Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
    d.mutex.RLock()
    rules := d.rules
    d.mutex.RUnlock()
    for {
        if d.closed.Load() {
            return 0, io.EOF
        }

    if len(rules) == 0 {
        return d.Device.Write(bufs, offset)
    }

    // Filter packets going down
    filteredBufs := make([][]byte, 0, len(bufs))
    for _, buf := range bufs {
        if len(buf) <= offset {
        dev := d.peekLast()
        if dev == nil {
            if !d.waitForDevice() {
                return 0, io.EOF
            }
            continue
        }

        packet := buf[offset:]
        destIP, ok := extractDestIP(packet)
        if !ok {
            // Can't parse, keep packet
            filteredBufs = append(filteredBufs, buf)
            continue
        }
        d.rulesMutex.RLock()
        rules := d.rules
        d.rulesMutex.RUnlock()

        // Check if packet matches any rule
        handled := false
        for _, rule := range rules {
            if rule.DestIP == destIP {
                if rule.Handler(packet) {
                    // Packet was handled and should be dropped
                    handled = true
                    break
        var filteredBufs [][]byte
        if len(rules) == 0 {
            filteredBufs = bufs
        } else {
            filteredBufs = make([][]byte, 0, len(bufs))
            for _, buf := range bufs {
                if len(buf) <= offset {
                    continue
                }

                packet := buf[offset:]
                destIP, ok := extractDestIP(packet)
                if !ok {
                    filteredBufs = append(filteredBufs, buf)
                    continue
                }

                handled := false
                for _, rule := range rules {
                    if rule.DestIP == destIP {
                        if rule.Handler(packet) {
                            handled = true
                            break
                        }
                    }
                }

                if !handled {
                    filteredBufs = append(filteredBufs, buf)
                }
            }
        }

        if !handled {
            filteredBufs = append(filteredBufs, buf)
            if len(filteredBufs) == 0 {
                return len(bufs), nil
            }
        }

    if len(filteredBufs) == 0 {
        return len(bufs), nil // All packets were handled
    }
        n, err := dev.Write(filteredBufs, offset)
        if err == nil {
            return n, nil
        }

    return d.Device.Write(filteredBufs, offset)
        if dev.IsClosed() {
            time.Sleep(1 * time.Millisecond)
            continue
        }

        return n, err
    }
}

func (d *MiddleDevice) waitForDevice() bool {
    d.mu.Lock()
    defer d.mu.Unlock()

    for len(d.devices) == 0 && !d.closed.Load() {
        d.cond.Wait()
    }
    return !d.closed.Load()
}

func (d *MiddleDevice) peekLast() *closeAwareDevice {
    d.mu.Lock()
    defer d.mu.Unlock()

    if len(d.devices) == 0 {
        return nil
    }

    return d.devices[len(d.devices)-1]
}

// WriteToTun writes packets directly to the underlying TUN device,
// bypassing WireGuard. This is useful for sending packets that should
// appear to come from the TUN interface (e.g., DNS responses from a proxy).
// Unlike Write(), this does not go through packet filtering rules.
func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) {
    for {
        if d.closed.Load() {
            return 0, io.EOF
        }

        dev := d.peekLast()
        if dev == nil {
            if !d.waitForDevice() {
                return 0, io.EOF
            }
            continue
        }

        n, err := dev.Write(bufs, offset)
        if err == nil {
            return n, nil
        }

        if dev.IsClosed() {
            time.Sleep(1 * time.Millisecond)
            continue
        }

        return n, err
    }
}
@@ -1,4 +1,4 @@
//go:build !windows
//go:build darwin

package device

@@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
    }

    file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
    device, err := tun.CreateTUNFromFile(file, mtuInt)
    device, err := tun.CreateTUNFromFile(file, 0)
    if err != nil {
        file.Close()
        return nil, err
50
device/tun_linux.go
Normal file
@@ -0,0 +1,50 @@
//go:build linux

package device

import (
    "net"
    "os"
    "runtime"

    "github.com/fosrl/newt/logger"
    "golang.org/x/sys/unix"
    "golang.zx2c4.com/wireguard/ipc"
    "golang.zx2c4.com/wireguard/tun"
)

func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
    if runtime.GOOS == "android" { // otherwise we get a permission denied
        theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
        return theTun, err
    }

    dupTunFd, err := unix.Dup(int(tunFd))
    if err != nil {
        logger.Error("Unable to dup tun fd: %v", err)
        return nil, err
    }

    err = unix.SetNonblock(dupTunFd, true)
    if err != nil {
        unix.Close(dupTunFd)
        return nil, err
    }

    file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
    device, err := tun.CreateTUNFromFile(file, mtuInt)
    if err != nil {
        file.Close()
        return nil, err
    }

    return device, nil
}

func UapiOpen(interfaceName string) (*os.File, error) {
    return ipc.UAPIOpen(interfaceName)
}

func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
    return ipc.UAPIListen(interfaceName, fileUAPI)
}
358
dns/dns_proxy.go
358
dns/dns_proxy.go
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/device"
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
@@ -34,18 +33,25 @@ type DNSProxy struct {
|
||||
ep *channel.Endpoint
|
||||
proxyIP netip.Addr
|
||||
upstreamDNS []string
|
||||
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
|
||||
mtu int
|
||||
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
||||
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
|
||||
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
|
||||
recordStore *DNSRecordStore // Local DNS records
|
||||
|
||||
// Tunnel DNS fields - for sending queries over WireGuard
|
||||
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
|
||||
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
|
||||
tunnelEp *channel.Endpoint
|
||||
tunnelActivePorts map[uint16]bool
|
||||
tunnelPortsLock sync.Mutex
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewDNSProxy creates a new DNS proxy
|
||||
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string) (*DNSProxy, error) {
|
||||
func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
|
||||
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
||||
@@ -58,17 +64,27 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
proxy := &DNSProxy{
|
||||
proxyIP: proxyIP,
|
||||
mtu: mtu,
|
||||
tunDevice: tunDevice,
|
||||
middleDevice: middleDevice,
|
||||
upstreamDNS: upstreamDns,
|
||||
recordStore: NewDNSRecordStore(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
proxyIP: proxyIP,
|
||||
mtu: mtu,
|
||||
middleDevice: middleDevice,
|
||||
upstreamDNS: upstreamDns,
|
||||
tunnelDNS: tunnelDns,
|
||||
recordStore: NewDNSRecordStore(),
|
||||
tunnelActivePorts: make(map[uint16]bool),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create gvisor netstack
|
||||
// Parse tunnel IP if provided (needed for tunneled DNS)
|
||||
if tunnelIP != "" {
|
||||
addr, err := netip.ParseAddr(tunnelIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tunnel IP: %v", err)
|
||||
}
|
||||
proxy.tunnelIP = addr
|
||||
}
|
||||
|
||||
// Create gvisor netstack for receiving DNS queries
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
@@ -101,9 +117,104 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
// Initialize tunnel netstack if tunnel DNS is enabled
|
||||
if tunnelDns {
|
||||
if !proxy.tunnelIP.IsValid() {
|
||||
return nil, fmt.Errorf("tunnel IP is required when tunnelDNS is enabled")
|
||||
}
|
||||
|
||||
// TODO: DO WE NEED TO ESTABLISH ANOTHER NETSTACK HERE OR CAN WE COMBINE WITH WGTESTER?
|
||||
if err := proxy.initTunnelNetstack(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize tunnel netstack: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// initTunnelNetstack creates a separate netstack for outbound DNS queries through the tunnel
|
||||
func (p *DNSProxy) initTunnelNetstack() error {
|
||||
// Create gvisor netstack for outbound tunnel queries
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
|
||||
p.tunnelEp = channel.New(256, uint32(p.mtu), "")
|
||||
p.tunnelStack = stack.New(stackOpts)
|
||||
|
||||
// Create NIC
|
||||
if err := p.tunnelStack.CreateNIC(1, p.tunnelEp); err != nil {
|
||||
return fmt.Errorf("failed to create tunnel NIC: %v", err)
|
||||
}
|
||||
|
||||
// Add tunnel IP address (WireGuard interface IP)
|
||||
ipBytes := p.tunnelIP.As4()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||
}
|
||||
|
||||
if err := p.tunnelStack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return fmt.Errorf("failed to add tunnel protocol address: %v", err)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
p.tunnelStack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
// Register filter rule on MiddleDevice to intercept responses
|
||||
p.middleDevice.AddRule(p.tunnelIP, p.handleTunnelResponse)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleTunnelResponse handles packets coming back from the tunnel destined for the tunnel IP
|
||||
func (p *DNSProxy) handleTunnelResponse(packet []byte) bool {
|
||||
// Check if it's UDP
|
||||
proto, ok := util.GetProtocol(packet)
|
||||
if !ok || proto != 17 { // UDP
|
||||
return false
|
||||
}
|
||||
|
||||
// Check destination port - should be one of our active outbound ports
|
||||
port, ok := util.GetDestPort(packet)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if we are expecting a response on this port
|
||||
p.tunnelPortsLock.Lock()
|
||||
active := p.tunnelActivePorts[uint16(port)]
|
||||
p.tunnelPortsLock.Unlock()
|
||||
|
||||
if !active {
|
||||
return false
|
||||
}
|
||||
|
||||
// Inject into tunnel netstack
|
||||
version := packet[0] >> 4
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
p.tunnelEp.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||
case 6:
|
||||
p.tunnelEp.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||
default:
|
||||
pkb.DecRef()
|
||||
return false
|
||||
}
|
||||
|
||||
pkb.DecRef()
|
||||
return true // Handled
|
||||
}
// Start starts the DNS proxy and registers with the filter
func (p *DNSProxy) Start() error {
	// Install packet filter rule
@@ -114,7 +225,13 @@ func (p *DNSProxy) Start() error {
	go p.runDNSListener()
	go p.runPacketSender()

-	logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort)
	// Start tunnel packet sender if tunnel DNS is enabled
	if p.tunnelDNS {
		p.wg.Add(1)
		go p.runTunnelPacketSender()
	}

	logger.Info("DNS proxy started on %s:%d (tunnelDNS=%v)", p.proxyIP.String(), DNSPort, p.tunnelDNS)
	return nil
}

@@ -122,6 +239,9 @@ func (p *DNSProxy) Start() error {
func (p *DNSProxy) Stop() {
	if p.middleDevice != nil {
		p.middleDevice.RemoveRule(p.proxyIP)
		if p.tunnelDNS && p.tunnelIP.IsValid() {
			p.middleDevice.RemoveRule(p.tunnelIP)
		}
	}
	p.cancel()

@@ -130,12 +250,21 @@ func (p *DNSProxy) Stop() {
		p.ep.Close()
	}

	// Close tunnel endpoint if it exists
	if p.tunnelEp != nil {
		p.tunnelEp.Close()
	}

	p.wg.Wait()

	if p.stack != nil {
		p.stack.Close()
	}

	if p.tunnelStack != nil {
		p.tunnelStack.Close()
	}

	logger.Info("DNS proxy stopped")
}

@@ -251,7 +380,7 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie

	// Check if we have local records for this query
	var response *dns.Msg
-	if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
	if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypePTR {
		response = p.checkLocalRecords(msg, question)
	}

@@ -281,6 +410,34 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie

// checkLocalRecords checks if we have local records for the query
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
	// Handle PTR queries
	if question.Qtype == dns.TypePTR {
		if ptrDomain, ok := p.recordStore.GetPTRRecord(question.Name); ok {
			logger.Debug("Found local PTR record for %s -> %s", question.Name, ptrDomain)

			// Create response message
			response := new(dns.Msg)
			response.SetReply(query)
			response.Authoritative = true

			// Add PTR answer record
			rr := &dns.PTR{
				Hdr: dns.RR_Header{
					Name:   question.Name,
					Rrtype: dns.TypePTR,
					Class:  dns.ClassINET,
					Ttl:    300, // 5 minutes
				},
				Ptr: ptrDomain,
			}
			response.Answer = append(response.Answer, rr)

			return response
		}
		return nil
	}

	// Handle A and AAAA queries
	var recordType RecordType
	if question.Qtype == dns.TypeA {
		recordType = RecordTypeA

@@ -348,8 +505,16 @@ func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
	return response
}

-// queryUpstream sends a DNS query to upstream server using miekg/dns
// queryUpstream sends a DNS query to upstream server
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
	if p.tunnelDNS {
		return p.queryUpstreamTunnel(server, query, timeout)
	}
	return p.queryUpstreamDirect(server, query, timeout)
}
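// Usage sketch (assumed caller, not shown in this diff): queries are built with
// the miekg/dns API and handed to queryUpstream, which picks the transport.
// The upstream address here is illustrative:
//
//	m := new(dns.Msg)
//	m.SetQuestion(dns.Fqdn("host.example.com"), dns.TypeA)
//	resp, err := p.queryUpstream("1.1.1.1:53", m, 5*time.Second)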
// queryUpstreamDirect sends a DNS query to upstream server using miekg/dns directly (host networking)
func (p *DNSProxy) queryUpstreamDirect(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
	client := &dns.Client{
		Timeout: timeout,
	}
@@ -362,6 +527,147 @@ func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Dur
	return response, nil
}

// queryUpstreamTunnel sends a DNS query through the WireGuard tunnel
func (p *DNSProxy) queryUpstreamTunnel(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
	// Dial through the tunnel netstack
	conn, port, err := p.dialTunnel("udp", server)
	if err != nil {
		return nil, fmt.Errorf("failed to dial tunnel: %v", err)
	}
	defer func() {
		conn.Close()
		p.removeTunnelPort(port)
	}()

	// Pack the query
	queryData, err := query.Pack()
	if err != nil {
		return nil, fmt.Errorf("failed to pack query: %v", err)
	}

	// Set deadline
	conn.SetDeadline(time.Now().Add(timeout))

	// Send the query
	_, err = conn.Write(queryData)
	if err != nil {
		return nil, fmt.Errorf("failed to send query: %v", err)
	}

	// Read the response
	buf := make([]byte, 4096)
	n, err := conn.Read(buf)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %v", err)
	}

	// Parse the response
	response := new(dns.Msg)
	if err := response.Unpack(buf[:n]); err != nil {
		return nil, fmt.Errorf("failed to unpack response: %v", err)
	}

	return response, nil
}
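// Note: a single 4096-byte read means larger UDP responses would be truncated.
// If that mattered, a possible extension (not in this diff) would be to
// advertise the buffer size to the upstream via EDNS0 using miekg/dns:
//
//	query.SetEdns0(4096, false) // attach an OPT record advertising a 4096-byte UDP payload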
// dialTunnel creates a UDP connection through the tunnel netstack
func (p *DNSProxy) dialTunnel(network, addr string) (net.Conn, uint16, error) {
	if p.tunnelStack == nil {
		return nil, 0, fmt.Errorf("tunnel netstack not initialized")
	}

	// Parse remote address
	raddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, 0, err
	}

	// Use tunnel IP as source
	ipBytes := p.tunnelIP.As4()

	// Create UDP connection with ephemeral port
	laddr := &tcpip.FullAddress{
		NIC:  1,
		Addr: tcpip.AddrFrom4(ipBytes),
		Port: 0,
	}

	raddrTcpip := &tcpip.FullAddress{
		NIC:  1,
		Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
		Port: uint16(raddr.Port),
	}

	conn, err := gonet.DialUDP(p.tunnelStack, laddr, raddrTcpip, ipv4.ProtocolNumber)
	if err != nil {
		return nil, 0, err
	}

	// Get local port
	localAddr := conn.LocalAddr().(*net.UDPAddr)
	port := uint16(localAddr.Port)

	// Register port so we can receive responses
	p.tunnelPortsLock.Lock()
	p.tunnelActivePorts[port] = true
	p.tunnelPortsLock.Unlock()

	return conn, port, nil
}

// removeTunnelPort removes a port from the active ports map
func (p *DNSProxy) removeTunnelPort(port uint16) {
	p.tunnelPortsLock.Lock()
	delete(p.tunnelActivePorts, port)
	p.tunnelPortsLock.Unlock()
}

// runTunnelPacketSender reads packets from tunnel netstack and injects them into WireGuard
func (p *DNSProxy) runTunnelPacketSender() {
	defer p.wg.Done()
	logger.Debug("DNS tunnel packet sender goroutine started")

	for {
		// Use blocking ReadContext instead of polling - much more CPU efficient
		// This will block until a packet is available or context is cancelled
		pkt := p.tunnelEp.ReadContext(p.ctx)
		if pkt == nil {
			// Context was cancelled or endpoint closed
			logger.Debug("DNS tunnel packet sender exiting")
			// Drain any remaining packets
			for {
				pkt := p.tunnelEp.Read()
				if pkt == nil {
					break
				}
				pkt.DecRef()
			}
			return
		}

		// Extract packet data
		slices := pkt.AsSlices()
		if len(slices) > 0 {
			var totalSize int
			for _, slice := range slices {
				totalSize += len(slice)
			}

			buf := make([]byte, totalSize)
			pos := 0
			for _, slice := range slices {
				copy(buf[pos:], slice)
				pos += len(slice)
			}

			// Inject into MiddleDevice (outbound to WG)
			p.middleDevice.InjectOutbound(buf)
		}

		pkt.DecRef()
	}
}

// runPacketSender sends packets from netstack back to TUN
func (p *DNSProxy) runPacketSender() {
	defer p.wg.Done()
@@ -371,18 +677,12 @@ func (p *DNSProxy) runPacketSender() {
	const offset = 16

	for {
-		select {
-		case <-p.ctx.Done():
-			return
-		default:
-		}
-
-		// Read packets from netstack endpoint
-		pkt := p.ep.Read()
		// Use blocking ReadContext instead of polling - much more CPU efficient
		// This will block until a packet is available or context is cancelled
		pkt := p.ep.ReadContext(p.ctx)
		if pkt == nil {
-			// No packet available, small sleep to avoid busy loop
-			time.Sleep(1 * time.Millisecond)
-			continue
			// Context was cancelled or endpoint closed
			return
		}

		// Extract packet data as slices
@@ -405,9 +705,9 @@ func (p *DNSProxy) runPacketSender() {
			pos += len(slice)
		}

-		// Write packet to TUN device
		// Write packet to TUN device via MiddleDevice
		// offset=16 indicates packet data starts at position 16 in the buffer
-		_, err := p.tunDevice.Write([][]byte{buf}, offset)
		_, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
		if err != nil {
			logger.Error("Failed to write DNS response to TUN: %v", err)
		}

@@ -1,7 +1,9 @@
package dns

import (
	"fmt"
	"net"
	"strings"
	"sync"

	"github.com/miekg/dns"
@@ -13,26 +15,35 @@ type RecordType uint16
const (
	RecordTypeA    RecordType = RecordType(dns.TypeA)
	RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
	RecordTypePTR  RecordType = RecordType(dns.TypePTR)
)

-// DNSRecordStore manages local DNS records for A and AAAA queries
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
type DNSRecordStore struct {
-	mu          sync.RWMutex
-	aRecords    map[string][]net.IP // domain -> list of IPv4 addresses
-	aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
	mu            sync.RWMutex
	aRecords      map[string][]net.IP // domain -> list of IPv4 addresses
	aaaaRecords   map[string][]net.IP // domain -> list of IPv6 addresses
	aWildcards    map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
	aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
	ptrRecords    map[string]string   // IP address string -> domain name
}

// NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore {
	return &DNSRecordStore{
-		aRecords:    make(map[string][]net.IP),
-		aaaaRecords: make(map[string][]net.IP),
		aRecords:      make(map[string][]net.IP),
		aaaaRecords:   make(map[string][]net.IP),
		aWildcards:    make(map[string][]net.IP),
		aaaaWildcards: make(map[string][]net.IP),
		ptrRecords:    make(map[string]string),
	}
}
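// Usage sketch (values are illustrative; the behavior matches the tests below):
//
//	store := NewDNSRecordStore()
//	_ = store.AddRecord("*.autoco.internal", net.ParseIP("10.0.0.1")) // wildcard -> aWildcards, no PTR
//	_ = store.AddRecord("host.example.com", net.ParseIP("10.0.0.2"))  // exact -> aRecords, PTR added automatically
//	ips := store.GetRecords("web.autoco.internal.", RecordTypeA)      // resolved via the wildcard pattern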
// AddRecord adds a DNS record mapping (A or AAAA)
// domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address
// Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
	s.mu.Lock()
	defer s.mu.Unlock()
@@ -42,15 +53,30 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
		domain = domain + "."
	}

-	// Normalize domain to lowercase
-	domain = dns.Fqdn(domain)
	// Normalize domain to lowercase FQDN
	domain = strings.ToLower(dns.Fqdn(domain))

	// Check if domain contains wildcards
	isWildcard := strings.ContainsAny(domain, "*?")

	if ip.To4() != nil {
		// IPv4 address
-		s.aRecords[domain] = append(s.aRecords[domain], ip)
		if isWildcard {
			s.aWildcards[domain] = append(s.aWildcards[domain], ip)
		} else {
			s.aRecords[domain] = append(s.aRecords[domain], ip)
			// Automatically add PTR record for non-wildcard domains
			s.ptrRecords[ip.String()] = domain
		}
	} else if ip.To16() != nil {
		// IPv6 address
-		s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
		if isWildcard {
			s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
		} else {
			s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
			// Automatically add PTR record for non-wildcard domains
			s.ptrRecords[ip.String()] = domain
		}
	} else {
		return &net.ParseError{Type: "IP address", Text: ip.String()}
	}
@@ -58,8 +84,30 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
	return nil
}

// AddPTRRecord adds a PTR record mapping an IP address to a domain name
// ip should be a valid IPv4 or IPv6 address
// domain should be in FQDN format (e.g., "example.com.")
func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Ensure domain ends with a dot (FQDN format)
	if len(domain) == 0 || domain[len(domain)-1] != '.' {
		domain = domain + "."
	}

	// Normalize domain to lowercase FQDN
	domain = strings.ToLower(dns.Fqdn(domain))

	// Store PTR record using IP string as key
	s.ptrRecords[ip.String()] = domain

	return nil
}

// RemoveRecord removes a specific DNS record mapping
-// If ip is nil, removes all records for the domain
// If ip is nil, removes all records for the domain (including wildcards)
// Automatically removes corresponding PTR records for non-wildcard domains
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
	s.mu.Lock()
	defer s.mu.Unlock()
@@ -69,82 +117,223 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
		domain = domain + "."
	}

-	// Normalize domain to lowercase
-	domain = dns.Fqdn(domain)
	// Normalize domain to lowercase FQDN
	domain = strings.ToLower(dns.Fqdn(domain))

	// Check if domain contains wildcards
	isWildcard := strings.ContainsAny(domain, "*?")

	if ip == nil {
		// Remove all records for this domain
-		delete(s.aRecords, domain)
-		delete(s.aaaaRecords, domain)
		if isWildcard {
			delete(s.aWildcards, domain)
			delete(s.aaaaWildcards, domain)
		} else {
			// For non-wildcard domains, remove PTR records for all IPs
			if ips, ok := s.aRecords[domain]; ok {
				for _, ipAddr := range ips {
					// Only remove PTR if it points to this domain
					if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
						delete(s.ptrRecords, ipAddr.String())
					}
				}
			}
			if ips, ok := s.aaaaRecords[domain]; ok {
				for _, ipAddr := range ips {
					// Only remove PTR if it points to this domain
					if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
						delete(s.ptrRecords, ipAddr.String())
					}
				}
			}
			delete(s.aRecords, domain)
			delete(s.aaaaRecords, domain)
		}
		return
	}

	if ip.To4() != nil {
		// Remove specific IPv4 address
-		if ips, ok := s.aRecords[domain]; ok {
-			s.aRecords[domain] = removeIP(ips, ip)
-			if len(s.aRecords[domain]) == 0 {
-				delete(s.aRecords, domain)
		if isWildcard {
			if ips, ok := s.aWildcards[domain]; ok {
				s.aWildcards[domain] = removeIP(ips, ip)
				if len(s.aWildcards[domain]) == 0 {
					delete(s.aWildcards, domain)
				}
			}
		} else {
			if ips, ok := s.aRecords[domain]; ok {
				s.aRecords[domain] = removeIP(ips, ip)
				if len(s.aRecords[domain]) == 0 {
					delete(s.aRecords, domain)
				}
				// Automatically remove PTR record if it points to this domain
				if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
					delete(s.ptrRecords, ip.String())
				}
			}
		}
	} else if ip.To16() != nil {
		// Remove specific IPv6 address
-		if ips, ok := s.aaaaRecords[domain]; ok {
-			s.aaaaRecords[domain] = removeIP(ips, ip)
-			if len(s.aaaaRecords[domain]) == 0 {
-				delete(s.aaaaRecords, domain)
		if isWildcard {
			if ips, ok := s.aaaaWildcards[domain]; ok {
				s.aaaaWildcards[domain] = removeIP(ips, ip)
				if len(s.aaaaWildcards[domain]) == 0 {
					delete(s.aaaaWildcards, domain)
				}
			}
		} else {
			if ips, ok := s.aaaaRecords[domain]; ok {
				s.aaaaRecords[domain] = removeIP(ips, ip)
				if len(s.aaaaRecords[domain]) == 0 {
					delete(s.aaaaRecords, domain)
				}
				// Automatically remove PTR record if it points to this domain
				if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
					delete(s.ptrRecords, ip.String())
				}
			}
		}
	}
}

// RemovePTRRecord removes a PTR record for an IP address
func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
	s.mu.Lock()
	defer s.mu.Unlock()

	delete(s.ptrRecords, ip.String())
}

// GetRecords returns all IP addresses for a domain and record type
// First checks for exact matches, then checks wildcard patterns
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Normalize domain to lowercase FQDN
-	domain = dns.Fqdn(domain)
	domain = strings.ToLower(dns.Fqdn(domain))

	var records []net.IP
	switch recordType {
	case RecordTypeA:
		// Check exact match first
		if ips, ok := s.aRecords[domain]; ok {
			// Return a copy to prevent external modifications
			records = make([]net.IP, len(ips))
			copy(records, ips)
			return records
		}
		// Check wildcard patterns
		for pattern, ips := range s.aWildcards {
			if matchWildcard(pattern, domain) {
				records = append(records, ips...)
			}
		}
		if len(records) > 0 {
			// Return a copy
			result := make([]net.IP, len(records))
			copy(result, records)
			return result
		}

	case RecordTypeAAAA:
		// Check exact match first
		if ips, ok := s.aaaaRecords[domain]; ok {
			// Return a copy to prevent external modifications
			records = make([]net.IP, len(ips))
			copy(records, ips)
			return records
		}
		// Check wildcard patterns
		for pattern, ips := range s.aaaaWildcards {
			if matchWildcard(pattern, domain) {
				records = append(records, ips...)
			}
		}
		if len(records) > 0 {
			// Return a copy
			result := make([]net.IP, len(records))
			copy(result, records)
			return result
		}
	}

	return records
}

// GetPTRRecord returns the domain name for a PTR record query
// domain should be in reverse DNS format (e.g., "1.0.0.127.in-addr.arpa.")
func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Convert reverse DNS format to IP address
	ip := reverseDNSToIP(domain)
	if ip == nil {
		return "", false
	}

	// Look up the PTR record
	if ptrDomain, ok := s.ptrRecords[ip.String()]; ok {
		return ptrDomain, true
	}

	return "", false
}

// HasRecord checks if a domain has any records of the specified type
// Checks both exact matches and wildcard patterns
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Normalize domain to lowercase FQDN
-	domain = dns.Fqdn(domain)
	domain = strings.ToLower(dns.Fqdn(domain))

	switch recordType {
	case RecordTypeA:
-		_, ok := s.aRecords[domain]
-		return ok
		// Check exact match
		if _, ok := s.aRecords[domain]; ok {
			return true
		}
		// Check wildcard patterns
		for pattern := range s.aWildcards {
			if matchWildcard(pattern, domain) {
				return true
			}
		}
	case RecordTypeAAAA:
-		_, ok := s.aaaaRecords[domain]
-		return ok
		// Check exact match
		if _, ok := s.aaaaRecords[domain]; ok {
			return true
		}
		// Check wildcard patterns
		for pattern := range s.aaaaWildcards {
			if matchWildcard(pattern, domain) {
				return true
			}
		}
	}

	return false
}

// HasPTRRecord checks if a PTR record exists for the given reverse DNS domain
func (s *DNSRecordStore) HasPTRRecord(domain string) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Convert reverse DNS format to IP address
	ip := reverseDNSToIP(domain)
	if ip == nil {
		return false
	}

	_, ok := s.ptrRecords[ip.String()]
	return ok
}

// Clear removes all records from the store
func (s *DNSRecordStore) Clear() {
	s.mu.Lock()
@@ -152,6 +341,9 @@ func (s *DNSRecordStore) Clear() {

	s.aRecords = make(map[string][]net.IP)
	s.aaaaRecords = make(map[string][]net.IP)
	s.aWildcards = make(map[string][]net.IP)
	s.aaaaWildcards = make(map[string][]net.IP)
	s.ptrRecords = make(map[string]string)
}

// removeIP is a helper function to remove a specific IP from a slice
@@ -164,3 +356,142 @@ func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
	}
	return result
}

// matchWildcard checks if a domain matches a wildcard pattern
// Pattern supports * (0+ chars) and ? (exactly 1 char)
// Special case: *.domain.com does not match domain.com itself
func matchWildcard(pattern, domain string) bool {
	return matchWildcardInternal(pattern, domain, 0, 0)
}

// matchWildcardInternal performs the actual wildcard matching recursively
func matchWildcardInternal(pattern, domain string, pi, di int) bool {
	plen := len(pattern)
	dlen := len(domain)

	// Base cases
	if pi == plen && di == dlen {
		return true
	}
	if pi == plen {
		return false
	}

	// Handle wildcard characters
	if pattern[pi] == '*' {
		// Special case: if pattern starts with "*." and we're at the beginning,
		// ensure we don't match the domain without a prefix
		// e.g., *.autoco.internal should not match autoco.internal
		if pi == 0 && pi+1 < plen && pattern[pi+1] == '.' {
			// The * must match at least one character
			if di == dlen {
				return false
			}
			// Try matching 1 or more characters before the dot
			for i := di + 1; i <= dlen; i++ {
				if matchWildcardInternal(pattern, domain, pi+1, i) {
					return true
				}
			}
			return false
		}

		// Normal * matching (0 or more characters)
		// Try matching 0 characters (skip the *)
		if matchWildcardInternal(pattern, domain, pi+1, di) {
			return true
		}
		// Try matching 1+ characters
		if di < dlen {
			return matchWildcardInternal(pattern, domain, pi, di+1)
		}
		return false
	}

	if pattern[pi] == '?' {
		// ? matches exactly one character
		if di >= dlen {
			return false
		}
		return matchWildcardInternal(pattern, domain, pi+1, di+1)
	}

	// Regular character - must match exactly
	if di >= dlen || pattern[pi] != domain[di] {
		return false
	}

	return matchWildcardInternal(pattern, domain, pi+1, di+1)
}
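// Worked examples, mirroring the test cases below:
//
//	matchWildcard("*.autoco.internal.", "host.autoco.internal.")          // true: leading * consumes "host"
//	matchWildcard("*.autoco.internal.", "autoco.internal.")               // false: a leading "*." must match at least one char
//	matchWildcard("host-0?.autoco.internal.", "host-01.autoco.internal.") // true: ? matches exactly "1"
//
// Note the * handling is recursive backtracking, so patterns with many
// asterisks can backtrack heavily in the worst case.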
// reverseDNSToIP converts a reverse DNS query name to an IP address
// Supports both IPv4 (in-addr.arpa) and IPv6 (ip6.arpa) formats
func reverseDNSToIP(domain string) net.IP {
	// Normalize to lowercase and ensure FQDN
	domain = strings.ToLower(dns.Fqdn(domain))

	// Check for IPv4 reverse DNS (in-addr.arpa)
	if strings.HasSuffix(domain, ".in-addr.arpa.") {
		// Remove the suffix
		ipPart := strings.TrimSuffix(domain, ".in-addr.arpa.")
		// Split by dots and reverse
		parts := strings.Split(ipPart, ".")
		if len(parts) != 4 {
			return nil
		}
		// Reverse the octets
		reversed := make([]string, 4)
		for i := 0; i < 4; i++ {
			reversed[i] = parts[3-i]
		}
		// Parse as IP
		return net.ParseIP(strings.Join(reversed, "."))
	}

	// Check for IPv6 reverse DNS (ip6.arpa)
	if strings.HasSuffix(domain, ".ip6.arpa.") {
		// Remove the suffix
		ipPart := strings.TrimSuffix(domain, ".ip6.arpa.")
		// Split by dots and reverse
		parts := strings.Split(ipPart, ".")
		if len(parts) != 32 {
			return nil
		}
		// Reverse the nibbles and group into 16-bit hex values
		reversed := make([]string, 32)
		for i := 0; i < 32; i++ {
			reversed[i] = parts[31-i]
		}
		// Join into IPv6 format (groups of 4 nibbles separated by colons)
		var ipv6Parts []string
		for i := 0; i < 32; i += 4 {
			ipv6Parts = append(ipv6Parts, reversed[i]+reversed[i+1]+reversed[i+2]+reversed[i+3])
		}
		// Parse as IP
		return net.ParseIP(strings.Join(ipv6Parts, ":"))
	}

	return nil
}

// IPToReverseDNS converts an IP address to reverse DNS format
// Returns the domain name for PTR queries (e.g., "1.0.0.127.in-addr.arpa.")
func IPToReverseDNS(ip net.IP) string {
	if ip4 := ip.To4(); ip4 != nil {
		// IPv4: reverse octets and append .in-addr.arpa.
		return dns.Fqdn(fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa",
			ip4[3], ip4[2], ip4[1], ip4[0]))
	}

	if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
		// IPv6: expand to 32 nibbles, reverse, and append .ip6.arpa.
		var nibbles []string
		for i := 15; i >= 0; i-- {
			nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]&0x0f))
			nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]>>4))
		}
		return dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
	}

	return ""
}
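// Round-trip example, matching the table-driven tests below: the two helpers
// are inverses for well-formed inputs.
//
//	IPToReverseDNS(net.ParseIP("192.168.1.1"))  // "1.1.168.192.in-addr.arpa."
//	reverseDNSToIP("1.1.168.192.in-addr.arpa.") // 192.168.1.1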
864 dns/dns_records_test.go Normal file
@@ -0,0 +1,864 @@
package dns

import (
	"net"
	"testing"
)

func TestWildcardMatching(t *testing.T) {
	tests := []struct {
		name     string
		pattern  string
		domain   string
		expected bool
	}{
		// Basic wildcard tests
		{
			name:     "*.autoco.internal matches host.autoco.internal",
			pattern:  "*.autoco.internal.",
			domain:   "host.autoco.internal.",
			expected: true,
		},
		{
			name:     "*.autoco.internal matches longerhost.autoco.internal",
			pattern:  "*.autoco.internal.",
			domain:   "longerhost.autoco.internal.",
			expected: true,
		},
		{
			name:     "*.autoco.internal matches sub.host.autoco.internal",
			pattern:  "*.autoco.internal.",
			domain:   "sub.host.autoco.internal.",
			expected: true,
		},
		{
			name:     "*.autoco.internal does NOT match autoco.internal",
			pattern:  "*.autoco.internal.",
			domain:   "autoco.internal.",
			expected: false,
		},

		// Question mark wildcard tests
		{
			name:     "host-0?.autoco.internal matches host-01.autoco.internal",
			pattern:  "host-0?.autoco.internal.",
			domain:   "host-01.autoco.internal.",
			expected: true,
		},
		{
			name:     "host-0?.autoco.internal matches host-0a.autoco.internal",
			pattern:  "host-0?.autoco.internal.",
			domain:   "host-0a.autoco.internal.",
			expected: true,
		},
		{
			name:     "host-0?.autoco.internal does NOT match host-0.autoco.internal",
			pattern:  "host-0?.autoco.internal.",
			domain:   "host-0.autoco.internal.",
			expected: false,
		},
		{
			name:     "host-0?.autoco.internal does NOT match host-012.autoco.internal",
			pattern:  "host-0?.autoco.internal.",
			domain:   "host-012.autoco.internal.",
			expected: false,
		},

		// Combined wildcard tests
		{
			name:     "*.host-0?.autoco.internal matches sub.host-01.autoco.internal",
			pattern:  "*.host-0?.autoco.internal.",
			domain:   "sub.host-01.autoco.internal.",
			expected: true,
		},
		{
			name:     "*.host-0?.autoco.internal matches prefix.host-0a.autoco.internal",
			pattern:  "*.host-0?.autoco.internal.",
			domain:   "prefix.host-0a.autoco.internal.",
			expected: true,
		},
		{
			name:     "*.host-0?.autoco.internal does NOT match host-01.autoco.internal",
			pattern:  "*.host-0?.autoco.internal.",
			domain:   "host-01.autoco.internal.",
			expected: false,
		},

		// Multiple asterisks
		{
			name:     "*.*.autoco.internal matches any.thing.autoco.internal",
			pattern:  "*.*.autoco.internal.",
			domain:   "any.thing.autoco.internal.",
			expected: true,
		},
		{
			name:     "*.*.autoco.internal does NOT match single.autoco.internal",
			pattern:  "*.*.autoco.internal.",
			domain:   "single.autoco.internal.",
			expected: false,
		},

		// Asterisk in middle
		{
			name:     "host-*.autoco.internal matches host-anything.autoco.internal",
			pattern:  "host-*.autoco.internal.",
			domain:   "host-anything.autoco.internal.",
			expected: true,
		},
		{
			name:     "host-*.autoco.internal matches host-.autoco.internal (empty match)",
			pattern:  "host-*.autoco.internal.",
			domain:   "host-.autoco.internal.",
			expected: true,
		},

		// Multiple question marks
		{
			name:     "host-??.autoco.internal matches host-01.autoco.internal",
			pattern:  "host-??.autoco.internal.",
			domain:   "host-01.autoco.internal.",
			expected: true,
		},
		{
			name:     "host-??.autoco.internal does NOT match host-1.autoco.internal",
			pattern:  "host-??.autoco.internal.",
			domain:   "host-1.autoco.internal.",
			expected: false,
		},

		// Exact match (no wildcards)
		{
			name:     "exact.autoco.internal matches exact.autoco.internal",
			pattern:  "exact.autoco.internal.",
			domain:   "exact.autoco.internal.",
			expected: true,
		},
		{
			name:     "exact.autoco.internal does NOT match other.autoco.internal",
			pattern:  "exact.autoco.internal.",
			domain:   "other.autoco.internal.",
			expected: false,
		},

		// Edge cases
		{
			name:     "* matches anything",
			pattern:  "*",
			domain:   "anything.at.all.",
			expected: true,
		},
		{
			name:     "*.* matches multi.level.",
			pattern:  "*.*",
			domain:   "multi.level.",
			expected: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := matchWildcard(tt.pattern, tt.domain)
			if result != tt.expected {
				t.Errorf("matchWildcard(%q, %q) = %v, want %v", tt.pattern, tt.domain, result, tt.expected)
			}
		})
	}
}

func TestDNSRecordStoreWildcard(t *testing.T) {
	store := NewDNSRecordStore()

	// Add wildcard records
	wildcardIP := net.ParseIP("10.0.0.1")
	err := store.AddRecord("*.autoco.internal", wildcardIP)
	if err != nil {
		t.Fatalf("Failed to add wildcard record: %v", err)
	}

	// Add exact record
	exactIP := net.ParseIP("10.0.0.2")
	err = store.AddRecord("exact.autoco.internal", exactIP)
	if err != nil {
		t.Fatalf("Failed to add exact record: %v", err)
	}

	// Test exact match takes precedence
	ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
	if len(ips) != 1 {
		t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
	}
	if !ips[0].Equal(exactIP) {
		t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
	}

	// Test wildcard match
	ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
	if len(ips) != 1 {
		t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
	}
	if !ips[0].Equal(wildcardIP) {
		t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
	}

	// Test non-match (base domain)
	ips = store.GetRecords("autoco.internal.", RecordTypeA)
	if len(ips) != 0 {
		t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
	}
}

func TestDNSRecordStoreComplexWildcard(t *testing.T) {
	store := NewDNSRecordStore()

	// Add complex wildcard pattern
	ip1 := net.ParseIP("10.0.0.1")
	err := store.AddRecord("*.host-0?.autoco.internal", ip1)
	if err != nil {
		t.Fatalf("Failed to add wildcard record: %v", err)
	}

	// Test matching domain
	ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
	if len(ips) != 1 {
		t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
	}
	if len(ips) > 0 && !ips[0].Equal(ip1) {
		t.Errorf("Expected IP %v, got %v", ip1, ips[0])
	}

	// Test non-matching domain (missing prefix)
	ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
	if len(ips) != 0 {
		t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
	}

	// Test non-matching domain (wrong ? position)
	ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
	if len(ips) != 0 {
		t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
	}
}

func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
	store := NewDNSRecordStore()

	// Add wildcard record
	ip := net.ParseIP("10.0.0.1")
	err := store.AddRecord("*.autoco.internal", ip)
	if err != nil {
		t.Fatalf("Failed to add wildcard record: %v", err)
	}

	// Verify it exists
	ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
	if len(ips) != 1 {
		t.Errorf("Expected 1 IP before removal, got %d", len(ips))
	}

	// Remove wildcard record
	store.RemoveRecord("*.autoco.internal", nil)

	// Verify it's gone
	ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
	if len(ips) != 0 {
		t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
	}
}

func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
	store := NewDNSRecordStore()

	// Add multiple wildcard patterns that don't overlap
	ip1 := net.ParseIP("10.0.0.1")
	ip2 := net.ParseIP("10.0.0.2")
	ip3 := net.ParseIP("10.0.0.3")

	err := store.AddRecord("*.prod.autoco.internal", ip1)
	if err != nil {
		t.Fatalf("Failed to add first wildcard: %v", err)
	}

	err = store.AddRecord("*.dev.autoco.internal", ip2)
	if err != nil {
		t.Fatalf("Failed to add second wildcard: %v", err)
	}

	// Add a broader wildcard that matches both
	err = store.AddRecord("*.autoco.internal", ip3)
	if err != nil {
		t.Fatalf("Failed to add third wildcard: %v", err)
	}

	// Test domain matching only the prod pattern and the broad pattern
	ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
	if len(ips) != 2 {
		t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
	}

	// Test domain matching only the dev pattern and the broad pattern
	ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
	if len(ips) != 2 {
		t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
	}

	// Test domain matching only the broad pattern
	ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
	if len(ips) != 1 {
		t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
	}
}

func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
	store := NewDNSRecordStore()

	// Add IPv6 wildcard record
	ip := net.ParseIP("2001:db8::1")
	err := store.AddRecord("*.autoco.internal", ip)
	if err != nil {
		t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
	}

	// Test wildcard match for IPv6
	ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
	if len(ips) != 1 {
		t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
	}
	if len(ips) > 0 && !ips[0].Equal(ip) {
		t.Errorf("Expected IPv6 %v, got %v", ip, ips[0])
	}
}

func TestHasRecordWildcard(t *testing.T) {
	store := NewDNSRecordStore()

	// Add wildcard record
	ip := net.ParseIP("10.0.0.1")
	err := store.AddRecord("*.autoco.internal", ip)
	if err != nil {
		t.Fatalf("Failed to add wildcard record: %v", err)
	}

	// Test HasRecord with wildcard match
	if !store.HasRecord("host.autoco.internal.", RecordTypeA) {
		t.Error("Expected HasRecord to return true for wildcard match")
	}

	// Test HasRecord with non-match
	if store.HasRecord("autoco.internal.", RecordTypeA) {
		t.Error("Expected HasRecord to return false for base domain")
	}
}
func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
	store := NewDNSRecordStore()

	// Add record with mixed case
	ip := net.ParseIP("10.0.0.1")
	err := store.AddRecord("MyHost.AutoCo.Internal", ip)
	if err != nil {
		t.Fatalf("Failed to add mixed case record: %v", err)
	}

	// Test lookup with different cases
	testCases := []string{
		"myhost.autoco.internal.",
		"MYHOST.AUTOCO.INTERNAL.",
		"MyHost.AutoCo.Internal.",
		"mYhOsT.aUtOcO.iNtErNaL.",
	}

	for _, domain := range testCases {
		ips := store.GetRecords(domain, RecordTypeA)
		if len(ips) != 1 {
			t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
		}
		if len(ips) > 0 && !ips[0].Equal(ip) {
			t.Errorf("Expected IP %v for domain %q, got %v", ip, domain, ips[0])
		}
	}

	// Test wildcard with mixed case
	wildcardIP := net.ParseIP("10.0.0.2")
	err = store.AddRecord("*.Example.Com", wildcardIP)
	if err != nil {
		t.Fatalf("Failed to add mixed case wildcard: %v", err)
	}

	wildcardTestCases := []string{
		"host.example.com.",
		"HOST.EXAMPLE.COM.",
		"Host.Example.Com.",
		"HoSt.ExAmPlE.CoM.",
	}

	for _, domain := range wildcardTestCases {
		ips := store.GetRecords(domain, RecordTypeA)
		if len(ips) != 1 {
			t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
		}
		if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
			t.Errorf("Expected IP %v for wildcard domain %q, got %v", wildcardIP, domain, ips[0])
		}
	}

	// Test removal with different case
	store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
	ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
	if len(ips) != 0 {
		t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
	}

	// Test HasRecord with different case
	if !store.HasRecord("HOST.EXAMPLE.COM.", RecordTypeA) {
		t.Error("Expected HasRecord to return true for mixed case wildcard match")
	}
}

func TestPTRRecordIPv4(t *testing.T) {
	store := NewDNSRecordStore()

	// Add PTR record for IPv4
	ip := net.ParseIP("192.168.1.1")
	domain := "host.example.com."
	err := store.AddPTRRecord(ip, domain)
	if err != nil {
		t.Fatalf("Failed to add PTR record: %v", err)
	}

	// Test reverse DNS lookup
	reverseDomain := "1.1.168.192.in-addr.arpa."
	result, ok := store.GetPTRRecord(reverseDomain)
	if !ok {
		t.Error("Expected PTR record to be found")
	}
	if result != domain {
		t.Errorf("Expected domain %q, got %q", domain, result)
	}

	// Test HasPTRRecord
	if !store.HasPTRRecord(reverseDomain) {
		t.Error("Expected HasPTRRecord to return true")
	}

	// Test non-existent PTR record
	_, ok = store.GetPTRRecord("2.1.168.192.in-addr.arpa.")
	if ok {
		t.Error("Expected PTR record not to be found for different IP")
	}
}

func TestPTRRecordIPv6(t *testing.T) {
	store := NewDNSRecordStore()

	// Add PTR record for IPv6
	ip := net.ParseIP("2001:db8::1")
	domain := "ipv6host.example.com."
	err := store.AddPTRRecord(ip, domain)
	if err != nil {
		t.Fatalf("Failed to add PTR record: %v", err)
	}

	// Test reverse DNS lookup
	// 2001:db8::1 = 2001:0db8:0000:0000:0000:0000:0000:0001
	// Reverse: 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
	reverseDomain := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
	result, ok := store.GetPTRRecord(reverseDomain)
	if !ok {
		t.Error("Expected IPv6 PTR record to be found")
	}
	if result != domain {
		t.Errorf("Expected domain %q, got %q", domain, result)
	}

	// Test HasPTRRecord
	if !store.HasPTRRecord(reverseDomain) {
		t.Error("Expected HasPTRRecord to return true for IPv6")
	}
}

func TestRemovePTRRecord(t *testing.T) {
	store := NewDNSRecordStore()

	// Add PTR record
	ip := net.ParseIP("10.0.0.1")
	domain := "test.example.com."
	err := store.AddPTRRecord(ip, domain)
	if err != nil {
		t.Fatalf("Failed to add PTR record: %v", err)
	}

	// Verify it exists
	reverseDomain := "1.0.0.10.in-addr.arpa."
	_, ok := store.GetPTRRecord(reverseDomain)
	if !ok {
		t.Error("Expected PTR record to exist before removal")
	}

	// Remove PTR record
	store.RemovePTRRecord(ip)

	// Verify it's gone
	_, ok = store.GetPTRRecord(reverseDomain)
	if ok {
		t.Error("Expected PTR record to be removed")
	}

	// Test HasPTRRecord after removal
	if store.HasPTRRecord(reverseDomain) {
		t.Error("Expected HasPTRRecord to return false after removal")
	}
}

func TestIPToReverseDNS(t *testing.T) {
	tests := []struct {
		name     string
		ip       string
		expected string
	}{
		{
			name:     "IPv4 simple",
			ip:       "192.168.1.1",
			expected: "1.1.168.192.in-addr.arpa.",
		},
		{
			name:     "IPv4 localhost",
			ip:       "127.0.0.1",
			expected: "1.0.0.127.in-addr.arpa.",
		},
		{
			name:     "IPv4 with zeros",
			ip:       "10.0.0.1",
			expected: "1.0.0.10.in-addr.arpa.",
		},
		{
			name:     "IPv6 simple",
			ip:       "2001:db8::1",
			expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
		},
		{
			name:     "IPv6 localhost",
			ip:       "::1",
			expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ip := net.ParseIP(tt.ip)
			if ip == nil {
				t.Fatalf("Failed to parse IP: %s", tt.ip)
			}
			result := IPToReverseDNS(ip)
			if result != tt.expected {
				t.Errorf("IPToReverseDNS(%s) = %q, want %q", tt.ip, result, tt.expected)
			}
		})
	}
}

func TestReverseDNSToIP(t *testing.T) {
	tests := []struct {
		name        string
		reverseDNS  string
		expectedIP  string
		shouldMatch bool
	}{
		{
			name:        "IPv4 simple",
			reverseDNS:  "1.1.168.192.in-addr.arpa.",
			expectedIP:  "192.168.1.1",
			shouldMatch: true,
		},
		{
			name:        "IPv4 localhost",
			reverseDNS:  "1.0.0.127.in-addr.arpa.",
			expectedIP:  "127.0.0.1",
			shouldMatch: true,
		},
		{
			name:        "IPv6 simple",
			reverseDNS:  "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
			expectedIP:  "2001:db8::1",
			shouldMatch: true,
		},
		{
			name:        "Invalid IPv4 format",
			reverseDNS:  "1.1.168.in-addr.arpa.",
			expectedIP:  "",
			shouldMatch: false,
		},
		{
			name:        "Invalid IPv6 format",
			reverseDNS:  "1.0.0.0.ip6.arpa.",
			expectedIP:  "",
			shouldMatch: false,
		},
		{
			name:        "Not a reverse DNS domain",
			reverseDNS:  "example.com.",
			expectedIP:  "",
			shouldMatch: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := reverseDNSToIP(tt.reverseDNS)
			if tt.shouldMatch {
				if result == nil {
					t.Errorf("reverseDNSToIP(%q) returned nil, expected IP", tt.reverseDNS)
					return
				}
				expectedIP := net.ParseIP(tt.expectedIP)
				if !result.Equal(expectedIP) {
					t.Errorf("reverseDNSToIP(%q) = %v, want %v", tt.reverseDNS, result, expectedIP)
				}
			} else {
				if result != nil {
					t.Errorf("reverseDNSToIP(%q) = %v, expected nil", tt.reverseDNS, result)
				}
			}
		})
	}
}

func TestPTRRecordCaseInsensitive(t *testing.T) {
	store := NewDNSRecordStore()

	// Add PTR record with mixed case domain
	ip := net.ParseIP("192.168.1.1")
	domain := "MyHost.Example.Com"
	err := store.AddPTRRecord(ip, domain)
	if err != nil {
		t.Fatalf("Failed to add PTR record: %v", err)
	}

	// Test lookup with different cases in reverse DNS format
	reverseDomain := "1.1.168.192.in-addr.arpa."
	result, ok := store.GetPTRRecord(reverseDomain)
	if !ok {
		t.Error("Expected PTR record to be found")
	}
	// Domain should be normalized to lowercase
	if result != "myhost.example.com." {
		t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
	}

	// Test with uppercase reverse DNS
	reverseDomainUpper := "1.1.168.192.IN-ADDR.ARPA."
	result, ok = store.GetPTRRecord(reverseDomainUpper)
	if !ok {
		t.Error("Expected PTR record to be found with uppercase reverse DNS")
	}
	if result != "myhost.example.com." {
		t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
	}
}

func TestClearPTRRecords(t *testing.T) {
	store := NewDNSRecordStore()

	// Add some PTR records
	ip1 := net.ParseIP("192.168.1.1")
	ip2 := net.ParseIP("192.168.1.2")
	store.AddPTRRecord(ip1, "host1.example.com.")
	store.AddPTRRecord(ip2, "host2.example.com.")

	// Add some A records too
	store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))

	// Verify PTR records exist
	if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
		t.Error("Expected PTR record to exist before clear")
	}

	// Clear all records
	store.Clear()

	// Verify PTR records are gone
	if store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
		t.Error("Expected PTR record to be cleared")
	}
	if store.HasPTRRecord("2.1.168.192.in-addr.arpa.") {
		t.Error("Expected PTR record to be cleared")
	}

	// Verify A records are also gone
	if store.HasRecord("test.example.com.", RecordTypeA) {
		t.Error("Expected A record to be cleared")
	}
}

func TestAutomaticPTRRecordOnAdd(t *testing.T) {
	store := NewDNSRecordStore()

	// Add an A record - should automatically add PTR record
	domain := "host.example.com."
	ip := net.ParseIP("192.168.1.100")
	err := store.AddRecord(domain, ip)
	if err != nil {
		t.Fatalf("Failed to add A record: %v", err)
	}

	// Verify PTR record was automatically created
	reverseDomain := "100.1.168.192.in-addr.arpa."
	result, ok := store.GetPTRRecord(reverseDomain)
	if !ok {
		t.Error("Expected PTR record to be automatically created")
	}
	if result != domain {
		t.Errorf("Expected PTR to point to %q, got %q", domain, result)
	}

	// Add AAAA record - should also automatically add PTR record
	domain6 := "ipv6host.example.com."
	ip6 := net.ParseIP("2001:db8::1")
	err = store.AddRecord(domain6, ip6)
	if err != nil {
		t.Fatalf("Failed to add AAAA record: %v", err)
	}

	// Verify IPv6 PTR record was automatically created
	reverseDomain6 := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
	result6, ok := store.GetPTRRecord(reverseDomain6)
	if !ok {
		t.Error("Expected IPv6 PTR record to be automatically created")
	}
	if result6 != domain6 {
		t.Errorf("Expected PTR to point to %q, got %q", domain6, result6)
	}
}

func TestAutomaticPTRRecordOnRemove(t *testing.T) {
	store := NewDNSRecordStore()

	// Add an A record (with automatic PTR)
	domain := "host.example.com."
	ip := net.ParseIP("192.168.1.100")
	store.AddRecord(domain, ip)

	// Verify PTR exists
	reverseDomain := "100.1.168.192.in-addr.arpa."
	if !store.HasPTRRecord(reverseDomain) {
		t.Error("Expected PTR record to exist after adding A record")
	}

	// Remove the A record
	store.RemoveRecord(domain, ip)

	// Verify PTR was automatically removed
	if store.HasPTRRecord(reverseDomain) {
		t.Error("Expected PTR record to be automatically removed")
	}

	// Verify A record is also gone
	ips := store.GetRecords(domain, RecordTypeA)
	if len(ips) != 0 {
		t.Errorf("Expected A record to be removed, got %d records", len(ips))
	}
}

func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
	store := NewDNSRecordStore()

	// Add multiple IPs for the same domain
	domain := "host.example.com."
	ip1 := net.ParseIP("192.168.1.100")
	ip2 := net.ParseIP("192.168.1.101")
	store.AddRecord(domain, ip1)
	store.AddRecord(domain, ip2)

	// Verify both PTR records exist
	reverseDomain1 := "100.1.168.192.in-addr.arpa."
	reverseDomain2 := "101.1.168.192.in-addr.arpa."
	if !store.HasPTRRecord(reverseDomain1) {
		t.Error("Expected first PTR record to exist")
	}
	if !store.HasPTRRecord(reverseDomain2) {
		t.Error("Expected second PTR record to exist")
	}

	// Remove all records for the domain
	store.RemoveRecord(domain, nil)

	// Verify both PTR records were removed
	if store.HasPTRRecord(reverseDomain1) {
		t.Error("Expected first PTR record to be removed")
	}
	if store.HasPTRRecord(reverseDomain2) {
		t.Error("Expected second PTR record to be removed")
	}
}

func TestNoPTRForWildcardRecords(t *testing.T) {
	store := NewDNSRecordStore()

	// Add wildcard record - should NOT create PTR record
	domain := "*.example.com."
	ip := net.ParseIP("192.168.1.100")
	err := store.AddRecord(domain, ip)
	if err != nil {
		t.Fatalf("Failed to add wildcard record: %v", err)
	}

	// Verify no PTR record was created
	reverseDomain := "100.1.168.192.in-addr.arpa."
	_, ok := store.GetPTRRecord(reverseDomain)
	if ok {
		t.Error("Expected no PTR record for wildcard domain")
	}

	// Verify wildcard A record exists
	if !store.HasRecord("host.example.com.", RecordTypeA) {
		t.Error("Expected wildcard A record to exist")
	}
}

func TestPTRRecordOverwrite(t *testing.T) {
	store := NewDNSRecordStore()

	// Add first domain with IP
	domain1 := "host1.example.com."
	ip := net.ParseIP("192.168.1.100")
	store.AddRecord(domain1, ip)

	// Verify PTR points to first domain
	reverseDomain := "100.1.168.192.in-addr.arpa."
	result, ok := store.GetPTRRecord(reverseDomain)
	if !ok {
		t.Fatal("Expected PTR record to exist")
	}
	if result != domain1 {
		t.Errorf("Expected PTR to point to %q, got %q", domain1, result)
	}

	// Add second domain with same IP - should overwrite PTR
	domain2 := "host2.example.com."
	store.AddRecord(domain2, ip)

	// Verify PTR now points to second domain (last one added)
	result, ok = store.GetPTRRecord(reverseDomain)
	if !ok {
		t.Fatal("Expected PTR record to still exist")
	}
	if result != domain2 {
		t.Errorf("Expected PTR to point to %q (overwritten), got %q", domain2, result)
	}

	// Remove first domain - PTR should remain pointing to second domain
	store.RemoveRecord(domain1, ip)
	result, ok = store.GetPTRRecord(reverseDomain)
	if !ok {
		t.Error("Expected PTR record to still exist after removing first domain")
	}
	if result != domain2 {
		t.Errorf("Expected PTR to still point to %q, got %q", domain2, result)
	}

	// Remove second domain - PTR should now be gone
	store.RemoveRecord(domain2, ip)
	_, ok = store.GetPTRRecord(reverseDomain)
	if ok {
		t.Error("Expected PTR record to be removed after removing second domain")
	}
}
16 dns/override/dns_override_android.go Normal file
@@ -0,0 +1,16 @@
//go:build android

package olm

import "net/netip"

// SetupDNSOverride is a no-op on Android
// Android handles DNS through the VpnService API at the Java/Kotlin layer
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
	return nil
}

// RestoreDNSOverride is a no-op on Android
func RestoreDNSOverride() error {
	return nil
}
@@ -7,7 +7,6 @@ import (
	"net/netip"

	"github.com/fosrl/newt/logger"
-	"github.com/fosrl/olm/dns"
	platform "github.com/fosrl/olm/dns/platform"
)

@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator

// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
// Uses scutil for DNS configuration
-func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
-	if dnsProxy == nil {
-		return fmt.Errorf("DNS proxy is nil")
-	}
-
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
	var err error
	configurator, err = platform.NewDarwinDNSConfigurator()
	if err != nil {
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {

	// Set new DNS servers to point to our proxy
	newDNS := []netip.Addr{
-		dnsProxy.GetProxyIP(),
		proxyIp,
	}

	logger.Info("Setting DNS servers to: %v", newDNS)
15 dns/override/dns_override_ios.go Normal file
@@ -0,0 +1,15 @@
//go:build ios

package olm

import "net/netip"

// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
	return nil
}

// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func RestoreDNSOverride() error {
	return nil
}
@@ -7,7 +7,6 @@ import (
 	"net/netip"
 
 	"github.com/fosrl/newt/logger"
-	"github.com/fosrl/olm/dns"
 	platform "github.com/fosrl/olm/dns/platform"
 )
 
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
 
 // SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
 // Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
-func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
-	if dnsProxy == nil {
-		return fmt.Errorf("DNS proxy is nil")
-	}
-
+func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
 	var err error
 
 	// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
@@ -32,7 +27,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
 	configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
 	if err == nil {
 		logger.Info("Using systemd-resolved DNS configurator")
-		return setDNS(dnsProxy, configurator)
+		return setDNS(proxyIp, configurator)
 	}
 	logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
 
@@ -40,7 +35,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
 	configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
 	if err == nil {
 		logger.Info("Using NetworkManager DNS configurator")
-		return setDNS(dnsProxy, configurator)
+		return setDNS(proxyIp, configurator)
 	}
 	logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
 
@@ -48,7 +43,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
 	configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
 	if err == nil {
 		logger.Info("Using resolvconf DNS configurator")
-		return setDNS(dnsProxy, configurator)
+		return setDNS(proxyIp, configurator)
 	}
 	logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
 }
@@ -60,11 +55,11 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
 	}
 
 	logger.Info("Using file-based DNS configurator")
-	return setDNS(dnsProxy, configurator)
+	return setDNS(proxyIp, configurator)
 }
 
 // setDNS is a helper function to set DNS and log the results
-func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
+func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error {
 	// Get current DNS servers before changing
 	currentDNS, err := conf.GetCurrentDNS()
 	if err != nil {
@@ -75,7 +70,7 @@ func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
 
 	// Set new DNS servers to point to our proxy
 	newDNS := []netip.Addr{
-		dnsProxy.GetProxyIP(),
+		proxyIp,
 	}
 
 	logger.Info("Setting DNS servers to: %v", newDNS)
@@ -7,7 +7,6 @@ import (
 	"net/netip"
 
 	"github.com/fosrl/newt/logger"
-	"github.com/fosrl/olm/dns"
 	platform "github.com/fosrl/olm/dns/platform"
 )
 
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
 
 // SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
 // Uses registry-based configuration (automatically extracts interface GUID)
-func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
-	if dnsProxy == nil {
-		return fmt.Errorf("DNS proxy is nil")
-	}
-
+func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
 	var err error
 	configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
 	if err != nil {
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
 
 	// Set new DNS servers to point to our proxy
 	newDNS := []netip.Addr{
-		dnsProxy.GetProxyIP(),
+		proxyIp,
 	}
 
 	logger.Info("Setting DNS servers to: %v", newDNS)
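Taken together, these hunks change the override entry point on every platform from taking a *dns.DNSProxy to taking the proxy address itself. A minimal sketch of a caller under the new signature (the address and interface name are hypothetical examples, not from the repo):

package main

import (
	"net/netip"

	dnsOverride "github.com/fosrl/olm/dns/override"
)

func main() {
	proxyIP := netip.MustParseAddr("100.89.128.2") // hypothetical proxy address
	if err := dnsOverride.SetupDNSOverride("utun4", proxyIP); err != nil {
		// RestoreDNSOverride undoes the change on shutdown
		panic(err)
	}
}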
@@ -5,9 +5,13 @@ package dns
 import (
 	"bufio"
 	"bytes"
+	"encoding/json"
 	"fmt"
 	"net/netip"
+	"os"
 	"os/exec"
+	"path/filepath"
+	"runtime"
 	"strconv"
 	"strings"
 
@@ -28,19 +32,38 @@ const (
 	keyServerPort = "ServerPort"
 	arraySymbol   = "* "
 	digitSymbol   = "# "
+
+	// State file name for crash recovery
+	dnsStateFileName = "dns_state.json"
 )
+
+// DNSPersistentState represents the state saved to disk for crash recovery
+type DNSPersistentState struct {
+	CreatedKeys []string `json:"created_keys"`
+}
 
 // DarwinDNSConfigurator manages DNS settings on macOS using scutil
 type DarwinDNSConfigurator struct {
 	createdKeys   map[string]struct{}
 	originalState *DNSState
+	stateFilePath string
 }
 
 // NewDarwinDNSConfigurator creates a new macOS DNS configurator
 func NewDarwinDNSConfigurator() (*DarwinDNSConfigurator, error) {
-	return &DarwinDNSConfigurator{
-		createdKeys: make(map[string]struct{}),
-	}, nil
+	stateFilePath := getDNSStateFilePath()
+
+	configurator := &DarwinDNSConfigurator{
+		createdKeys:   make(map[string]struct{}),
+		stateFilePath: stateFilePath,
+	}
+
+	// Clean up any leftover state from a previous crash
+	if err := configurator.CleanupUncleanShutdown(); err != nil {
+		logger.Warn("Failed to cleanup previous DNS state: %v", err)
+	}
+
+	return configurator, nil
 }
 
 // Name returns the configurator name
@@ -67,6 +90,11 @@ func (d *DarwinDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, erro
 		return nil, fmt.Errorf("apply DNS servers: %w", err)
 	}
 
+	// Persist state to disk for crash recovery
+	if err := d.saveState(); err != nil {
+		logger.Warn("Failed to save DNS state for crash recovery: %v", err)
+	}
+
 	// Flush DNS cache
 	if err := d.flushDNSCache(); err != nil {
 		// Non-fatal, just log
@@ -85,6 +113,11 @@ func (d *DarwinDNSConfigurator) RestoreDNS() error {
 		}
 	}
 
+	// Clear state file after successful restoration
+	if err := d.clearState(); err != nil {
+		logger.Warn("Failed to clear DNS state file: %v", err)
+	}
+
 	// Flush DNS cache
 	if err := d.flushDNSCache(); err != nil {
 		fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
@@ -112,6 +145,47 @@ func (d *DarwinDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
 	return servers, nil
 }
 
+// CleanupUncleanShutdown removes any DNS keys left over from a previous crash
+func (d *DarwinDNSConfigurator) CleanupUncleanShutdown() error {
+	state, err := d.loadState()
+	if err != nil {
+		if os.IsNotExist(err) {
+			// No state file, nothing to clean up
+			return nil
+		}
+		return fmt.Errorf("load state: %w", err)
+	}
+
+	if len(state.CreatedKeys) == 0 {
+		// No keys to clean up
+		return nil
+	}
+
+	logger.Info("Found DNS state from previous session, cleaning up %d keys", len(state.CreatedKeys))
+
+	// Remove all keys from previous session
+	var lastErr error
+	for _, key := range state.CreatedKeys {
+		logger.Debug("Removing leftover DNS key: %s", key)
+		if err := d.removeKeyDirect(key); err != nil {
+			logger.Warn("Failed to remove DNS key %s: %v", key, err)
+			lastErr = err
+		}
+	}
+
+	// Clear state file
+	if err := d.clearState(); err != nil {
+		logger.Warn("Failed to clear DNS state file: %v", err)
+	}
+
+	// Flush DNS cache after cleanup
+	if err := d.flushDNSCache(); err != nil {
+		logger.Warn("Failed to flush DNS cache after cleanup: %v", err)
+	}
+
+	return lastErr
+}
+
 // applyDNSServers applies the DNS server configuration
 func (d *DarwinDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
 	if len(servers) == 0 {
@@ -156,15 +230,25 @@ func (d *DarwinDNSConfigurator) addDNSState(state, domains string, dnsServer net
 	return nil
 }
 
-// removeKey removes a DNS configuration key
+// removeKey removes a DNS configuration key and updates internal state
 func (d *DarwinDNSConfigurator) removeKey(key string) error {
+	if err := d.removeKeyDirect(key); err != nil {
+		return err
+	}
+
+	delete(d.createdKeys, key)
+	return nil
+}
+
+// removeKeyDirect removes a DNS configuration key without updating internal state
+// Used for cleanup operations
+func (d *DarwinDNSConfigurator) removeKeyDirect(key string) error {
 	cmd := fmt.Sprintf("remove %s\n", key)
 
 	if _, err := d.runScutil(cmd); err != nil {
 		return fmt.Errorf("remove key: %w", err)
 	}
 
 	delete(d.createdKeys, key)
 	return nil
 }
 
@@ -266,3 +350,70 @@ func (d *DarwinDNSConfigurator) runScutil(commands string) ([]byte, error) {
 
 	return output, nil
 }
+
+// getDNSStateFilePath returns the path to the DNS state file
+func getDNSStateFilePath() string {
+	var stateDir string
+	switch runtime.GOOS {
+	case "darwin":
+		stateDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
+	default:
+		stateDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
+	}
+
+	if err := os.MkdirAll(stateDir, 0755); err != nil {
+		logger.Warn("Failed to create state directory: %v", err)
+	}
+
+	return filepath.Join(stateDir, dnsStateFileName)
+}
+
+// saveState persists the current DNS state to disk
+func (d *DarwinDNSConfigurator) saveState() error {
+	keys := make([]string, 0, len(d.createdKeys))
+	for key := range d.createdKeys {
+		keys = append(keys, key)
+	}
+
+	state := DNSPersistentState{
+		CreatedKeys: keys,
+	}
+
+	data, err := json.MarshalIndent(state, "", " ")
+	if err != nil {
+		return fmt.Errorf("marshal state: %w", err)
+	}
+
+	if err := os.WriteFile(d.stateFilePath, data, 0644); err != nil {
+		return fmt.Errorf("write state file: %w", err)
+	}
+
+	logger.Debug("Saved DNS state to %s", d.stateFilePath)
+	return nil
+}
+
+// loadState loads the DNS state from disk
+func (d *DarwinDNSConfigurator) loadState() (*DNSPersistentState, error) {
+	data, err := os.ReadFile(d.stateFilePath)
+	if err != nil {
+		return nil, err
+	}
+
+	var state DNSPersistentState
+	if err := json.Unmarshal(data, &state); err != nil {
+		return nil, fmt.Errorf("unmarshal state: %w", err)
+	}
+
+	return &state, nil
+}
+
+// clearState removes the DNS state file
+func (d *DarwinDNSConfigurator) clearState() error {
	err := os.Remove(d.stateFilePath)
+	if err != nil && !os.IsNotExist(err) {
+		return fmt.Errorf("remove state file: %w", err)
+	}
+
+	logger.Debug("Cleared DNS state file")
+	return nil
+}
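The crash-recovery scheme above reduces to a small JSON file listing the scutil keys the process created. A sketch of the on-disk shape it reads and writes (the key value below is illustrative, not from the repo):

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the DNSPersistentState struct added in the hunk above.
type DNSPersistentState struct {
	CreatedKeys []string `json:"created_keys"`
}

func main() {
	state := DNSPersistentState{CreatedKeys: []string{"State:/Network/Service/olm/DNS"}} // illustrative key
	data, _ := json.MarshalIndent(state, "", "  ")
	fmt.Println(string(data))
}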
@@ -22,7 +22,11 @@ type FileDNSConfigurator struct {
 
 // NewFileDNSConfigurator creates a new file-based DNS configurator
 func NewFileDNSConfigurator() (*FileDNSConfigurator, error) {
-	return &FileDNSConfigurator{}, nil
+	f := &FileDNSConfigurator{}
+	if err := f.CleanupUncleanShutdown(); err != nil {
+		return nil, fmt.Errorf("cleanup unclean shutdown: %w", err)
+	}
+	return f, nil
 }
 
 // Name returns the configurator name
@@ -78,6 +82,30 @@ func (f *FileDNSConfigurator) RestoreDNS() error {
 	return nil
 }
 
+// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
+// For the file-based configurator, we check if a backup file exists (indicating a crash
+// happened while DNS was configured) and restore from it if so.
+func (f *FileDNSConfigurator) CleanupUncleanShutdown() error {
+	// Check if backup file exists from a previous session
+	if !f.isBackupExists() {
+		// No backup file, nothing to clean up
+		return nil
+	}
+
+	// A backup exists, which means we crashed while DNS was configured
+	// Restore the original resolv.conf
+	if err := copyFile(resolvConfBackupPath, resolvConfPath); err != nil {
+		return fmt.Errorf("restore from backup during cleanup: %w", err)
+	}
+
+	// Remove backup file
+	if err := os.Remove(resolvConfBackupPath); err != nil {
+		return fmt.Errorf("remove backup file during cleanup: %w", err)
+	}
+
+	return nil
+}
+
 // GetCurrentDNS returns the currently configured DNS servers
 func (f *FileDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
 	content, err := os.ReadFile(resolvConfPath)
@@ -50,11 +50,18 @@ func NewNetworkManagerDNSConfigurator(ifaceName string) (*NetworkManagerDNSConfi
 		return nil, fmt.Errorf("NetworkManager conf.d directory not found: %s", networkManagerConfDir)
 	}
 
-	return &NetworkManagerDNSConfigurator{
+	configurator := &NetworkManagerDNSConfigurator{
 		ifaceName:    ifaceName,
 		confPath:     networkManagerConfDir + "/" + networkManagerDNSConfFile,
 		dispatchPath: networkManagerDispatcherDir + "/" + networkManagerDispatcherFile,
-	}, nil
-}
+	}
+
+	// Clean up any stale configuration from a previous unclean shutdown
+	if err := configurator.CleanupUncleanShutdown(); err != nil {
+		return nil, fmt.Errorf("cleanup unclean shutdown: %w", err)
+	}
+
+	return configurator, nil
+}
 
 // Name returns the configurator name
@@ -100,6 +107,30 @@ func (n *NetworkManagerDNSConfigurator) RestoreDNS() error {
 	return nil
 }
 
+// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
+// For NetworkManager, we check if our config file exists and remove it if so.
+// This ensures that if the process crashed while DNS was configured, the stale
+// configuration is removed on the next startup.
+func (n *NetworkManagerDNSConfigurator) CleanupUncleanShutdown() error {
+	// Check if our config file exists from a previous session
+	if _, err := os.Stat(n.confPath); os.IsNotExist(err) {
+		// No config file, nothing to clean up
+		return nil
+	}
+
+	// Remove the stale configuration file
+	if err := os.Remove(n.confPath); err != nil && !os.IsNotExist(err) {
+		return fmt.Errorf("remove stale DNS config file: %w", err)
+	}
+
+	// Reload NetworkManager to apply the change
+	if err := n.reloadNetworkManager(); err != nil {
+		return fmt.Errorf("reload NetworkManager after cleanup: %w", err)
+	}
+
+	return nil
+}
+
 // GetCurrentDNS returns the currently configured DNS servers by reading /etc/resolv.conf
 func (n *NetworkManagerDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
 	content, err := os.ReadFile("/etc/resolv.conf")
@@ -31,10 +31,17 @@ func NewResolvconfDNSConfigurator(ifaceName string) (*ResolvconfDNSConfigurator,
 		return nil, fmt.Errorf("detect resolvconf type: %w", err)
 	}
 
-	return &ResolvconfDNSConfigurator{
+	configurator := &ResolvconfDNSConfigurator{
 		ifaceName: ifaceName,
 		implType:  implType,
-	}, nil
-}
+	}
+
+	// Call cleanup function to remove any stale DNS config for this interface
+	if err := configurator.CleanupUncleanShutdown(); err != nil {
+		return nil, fmt.Errorf("cleanup unclean shutdown: %w", err)
+	}
+
+	return configurator, nil
+}
 
 // Name returns the configurator name
@@ -84,6 +91,28 @@ func (r *ResolvconfDNSConfigurator) RestoreDNS() error {
 	return nil
 }
 
+// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
+// For resolvconf, we attempt to delete any entry for the interface name.
+// This ensures that if the process crashed while DNS was configured, the stale
+// entry is removed on the next startup.
+func (r *ResolvconfDNSConfigurator) CleanupUncleanShutdown() error {
+	// Try to delete any existing entry for this interface
+	// This is idempotent - if no entry exists, resolvconf will just return success
+	var cmd *exec.Cmd
+
+	switch r.implType {
+	case "openresolv":
+		cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
+	default:
+		cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
+	}
+
+	// Ignore errors - the entry may not exist, which is fine
+	_ = cmd.Run()
+
+	return nil
+}
+
 // GetCurrentDNS returns the currently configured DNS servers
 func (r *ResolvconfDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
 	// resolvconf doesn't provide a direct way to query per-interface DNS
@@ -73,10 +73,17 @@ func NewSystemdResolvedDNSConfigurator(ifaceName string) (*SystemdResolvedDNSCon
 		return nil, fmt.Errorf("get link: %w", err)
 	}
 
-	return &SystemdResolvedDNSConfigurator{
+	config := &SystemdResolvedDNSConfigurator{
 		ifaceName:      ifaceName,
 		dbusLinkObject: dbus.ObjectPath(linkPath),
-	}, nil
-}
+	}
+
+	// Call cleanup function here
+	if err := config.CleanupUncleanShutdown(); err != nil {
+		fmt.Printf("warning: cleanup unclean shutdown failed: %v\n", err)
+	}
+
+	return config, nil
+}
 
 // Name returns the configurator name
@@ -133,6 +140,17 @@ func (s *SystemdResolvedDNSConfigurator) RestoreDNS() error {
 	return nil
 }
 
+// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
+// For systemd-resolved, the DNS configuration is tied to the network interface.
+// When the interface is destroyed and recreated, systemd-resolved automatically
+// clears the per-link DNS settings, so there's nothing to clean up.
+func (s *SystemdResolvedDNSConfigurator) CleanupUncleanShutdown() error {
+	// systemd-resolved DNS configuration is per-link and automatically cleared
+	// when the link (interface) is destroyed. Since the WireGuard interface is
+	// recreated on restart, there's no leftover state to clean up.
+	return nil
+}
+
 // GetCurrentDNS returns the currently configured DNS servers
 // Note: systemd-resolved doesn't easily expose current per-link DNS servers via D-Bus
 // This is a placeholder that returns an empty list
@@ -17,6 +17,10 @@ type DNSConfigurator interface {
 
 	// Name returns the name of this configurator implementation
 	Name() string
+
+	// CleanupUncleanShutdown removes any DNS configuration left over from
+	// a previous crash or unclean shutdown. This should be called on startup.
+	CleanupUncleanShutdown() error
 }
 
 // DNSConfig contains the configuration for DNS override
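Every platform backend now has to satisfy the extended interface. An illustrative no-op skeleton, assuming the interface carries the same methods the concrete configurators implement in the hunks above (SetDNS, RestoreDNS, GetCurrentDNS); this is a sketch, not repo code:

package platform

import "net/netip"

// noopConfigurator is illustrative only.
type noopConfigurator struct{}

func (noopConfigurator) Name() string                                      { return "noop" }
func (noopConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) { return nil, nil }
func (noopConfigurator) RestoreDNS() error                                 { return nil }
func (noopConfigurator) GetCurrentDNS() ([]netip.Addr, error)              { return nil, nil }
func (noopConfigurator) CleanupUncleanShutdown() error                     { return nil }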
@@ -113,6 +113,18 @@ func (w *WindowsDNSConfigurator) RestoreDNS() error {
 	return nil
 }
 
+// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
+// On Windows, we rely on the registry-based approach which doesn't leave orphaned state
+// in the same way as macOS scutil. The DNS settings are tied to the interface which
+// gets recreated on restart.
+func (w *WindowsDNSConfigurator) CleanupUncleanShutdown() error {
+	// Windows DNS configuration via registry is interface-specific.
+	// When the WireGuard interface is recreated, it gets a new GUID,
+	// so there's no leftover state to clean up from previous sessions.
+	// The old interface's registry keys are effectively orphaned but harmless.
+	return nil
+}
+
 // GetCurrentDNS returns the currently configured DNS servers
 func (w *WindowsDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
 	regKey, err := w.getInterfaceRegistryKey(registry.QUERY_VALUE)
23 go.mod
@@ -4,15 +4,15 @@ go 1.25
 
 require (
 	github.com/Microsoft/go-winio v0.6.2
-	github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552
-	github.com/godbus/dbus/v5 v5.2.0
+	github.com/fosrl/newt v1.9.0
+	github.com/godbus/dbus/v5 v5.2.2
 	github.com/gorilla/websocket v1.5.3
-	github.com/miekg/dns v1.1.68
-	golang.org/x/sys v0.38.0
+	github.com/miekg/dns v1.1.70
+	golang.org/x/sys v0.40.0
 	golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
 	golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
 	gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
-	software.sslmate.com/src/go-pkcs12 v0.6.0
+	software.sslmate.com/src/go-pkcs12 v0.7.0
 )
 
 require (
@@ -20,13 +20,16 @@ require (
 	github.com/google/go-cmp v0.7.0 // indirect
 	github.com/vishvananda/netlink v1.3.1 // indirect
 	github.com/vishvananda/netns v0.0.5 // indirect
-	golang.org/x/crypto v0.45.0 // indirect
+	golang.org/x/crypto v0.46.0 // indirect
 	golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
-	golang.org/x/mod v0.30.0 // indirect
-	golang.org/x/net v0.47.0 // indirect
-	golang.org/x/sync v0.18.0 // indirect
+	golang.org/x/mod v0.31.0 // indirect
+	golang.org/x/net v0.48.0 // indirect
+	golang.org/x/sync v0.19.0 // indirect
 	golang.org/x/time v0.12.0 // indirect
-	golang.org/x/tools v0.39.0 // indirect
+	golang.org/x/tools v0.40.0 // indirect
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
 	golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
 )
 
 // To be used ONLY for local development
 // replace github.com/fosrl/newt => ../newt
40 go.sum
@@ -1,39 +1,39 @@
 github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
 github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
-github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552 h1:51pHUtoqQhYPS9OiBDHLgYV44X/CBzR5J7GuWO3izhU=
-github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
-github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
-github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
+github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
+github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
+github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
+github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
 github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
 github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
 github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
 github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
 github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
 github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
-github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
+github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA=
+github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
 github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
 github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
 github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
 github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
-golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
-golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
+golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
+golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
 golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
 golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
-golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
-golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
-golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
-golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
-golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
-golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
+golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
+golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
+golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
+golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
+golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
 golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
-golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
+golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
 golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
-golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
-golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
+golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
+golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
 golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
@@ -44,5 +44,5 @@ golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus
 golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
 gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
 gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
-software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
-software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
+software.sslmate.com/src/go-pkcs12 v0.7.0 h1:Db8W44cB54TWD7stUFFSWxdfpdn6fZVcDl0w3R4RVM0=
+software.sslmate.com/src/go-pkcs12 v0.7.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
13 main.go
@@ -10,7 +10,7 @@ import (
 
 	"github.com/fosrl/newt/logger"
 	"github.com/fosrl/newt/updates"
-	"github.com/fosrl/olm/olm"
+	olmpkg "github.com/fosrl/olm/olm"
 )
 
 func main() {
@@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
 	}
 
 	// Create a new olm.Config struct and copy values from the main config
-	olmConfig := olm.GlobalConfig{
+	olmConfig := olmpkg.OlmConfig{
 		LogLevel:  config.LogLevel,
 		EnableAPI: config.EnableAPI,
 		HTTPAddr:  config.HTTPAddr,
@@ -219,15 +219,20 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
 		Agent:        "Olm CLI",
 		OnExit:       cancel, // Pass cancel function directly to trigger shutdown
 		OnTerminated: cancel,
+		PprofAddr:    ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
 	}
 
-	olm.Init(ctx, olmConfig)
+	olm, err := olmpkg.Init(ctx, olmConfig)
+	if err != nil {
+		logger.Fatal("Failed to initialize olm: %v", err)
+	}
+
 	if err := olm.StartApi(); err != nil {
 		logger.Fatal("Failed to start API server: %v", err)
 	}
 
 	if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
-		tunnelConfig := olm.TunnelConfig{
+		tunnelConfig := olmpkg.TunnelConfig{
 			Endpoint: config.Endpoint,
 			ID:       config.ID,
 			Secret:   config.Secret,
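The import alias is not cosmetic: the hunk introduces a local variable named olm, which shadows an unaliased olm package identifier from that point on, so later references must go through olmpkg. A self-contained analogy using the standard library (illustrative only, not repo code):

package main

import (
	"fmt"

	strs "strings" // alias, analogous to olmpkg above
)

func main() {
	strings := strs.ToUpper("shadowed") // local variable named like the package
	// strings.ToUpper(...) would no longer compile here; the alias still works:
	fmt.Println(strs.ToLower(strings))
}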
16 olm.iss
@@ -44,8 +44,8 @@ Name: "english"; MessagesFile: "compiler:Default.isl"
 
 [Files]
 ; The 'DestName' flag ensures that 'olm_windows_amd64.exe' is installed as 'olm.exe'
-Source: "C:\Users\Administrator\Downloads\olm_windows_amd64.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}"; Flags: ignoreversion
-Source: "C:\Users\Administrator\Downloads\wintun.dll"; DestDir: "{app}"; Flags: ignoreversion
+Source: "Z:\olm_windows_amd64.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}"; Flags: ignoreversion
+Source: "Z:\wintun.dll"; DestDir: "{app}"; Flags: ignoreversion
 ; NOTE: Don't use "Flags: ignoreversion" on any shared system files
 
 [Icons]
@@ -78,7 +78,7 @@ begin
     Result := True;
     exit;
   end;
 
   // Perform a case-insensitive check to see if the path is already present.
   // We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
   if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
@@ -109,7 +109,7 @@ begin
   PathList.Delimiter := ';';
   PathList.StrictDelimiter := True;
   PathList.DelimitedText := OrigPath;
 
   // Find and remove the matching entry (case-insensitive)
   for I := PathList.Count - 1 downto 0 do
   begin
@@ -119,10 +119,10 @@ begin
       PathList.Delete(I);
     end;
   end;
 
   // Reconstruct the PATH
   NewPath := PathList.DelimitedText;
 
   // Write the new PATH back to the registry
   if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE,
     'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
@@ -145,8 +145,8 @@ begin
     // Get the application installation path
     AppPath := ExpandConstant('{app}');
     Log('Removing PATH entry for: ' + AppPath);
 
     // Remove only our path entry from the system PATH
     RemovePathEntry(AppPath);
   end;
 end;
end;
299 olm/connect.go Normal file
@@ -0,0 +1,299 @@
package olm

import (
	"encoding/json"
	"fmt"
	"os"
	"runtime"
	"strconv"
	"strings"

	"github.com/fosrl/newt/logger"
	"github.com/fosrl/newt/network"
	olmDevice "github.com/fosrl/olm/device"
	"github.com/fosrl/olm/dns"
	dnsOverride "github.com/fosrl/olm/dns/override"
	"github.com/fosrl/olm/peers"
	"github.com/fosrl/olm/websocket"
	"golang.zx2c4.com/wireguard/device"
	"golang.zx2c4.com/wireguard/tun"
)

// OlmErrorData represents the error data sent from the server
type OlmErrorData struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

func (o *Olm) handleConnect(msg websocket.WSMessage) {
	logger.Debug("Received message: %v", msg.Data)

	// Check if tunnel is still running
	if !o.tunnelRunning {
		logger.Debug("Tunnel stopped, ignoring connect message")
		return
	}

	var wgData WgData

	if o.registered {
		logger.Info("Already connected. Ignoring new connection request.")
		return
	}

	if o.stopRegister != nil {
		o.stopRegister()
		o.stopRegister = nil
	}

	if o.updateRegister != nil {
		o.updateRegister = nil
	}

	// if there is an existing tunnel then close it
	if o.dev != nil {
		logger.Info("Got new message. Closing existing tunnel!")
		o.dev.Close()
	}

	jsonData, err := json.Marshal(msg.Data)
	if err != nil {
		logger.Info("Error marshaling data: %v", err)
		return
	}

	if err := json.Unmarshal(jsonData, &wgData); err != nil {
		logger.Info("Error unmarshaling target data: %v", err)
		return
	}

	o.tdev, err = func() (tun.Device, error) {
		if o.tunnelConfig.FileDescriptorTun != 0 {
			return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU)
		}
		ifName := o.tunnelConfig.InterfaceName
		if runtime.GOOS == "darwin" { // this is if we don't pass a fd
			ifName, err = network.FindUnusedUTUN()
			if err != nil {
				return nil, err
			}
		}
		return tun.CreateTUN(ifName, o.tunnelConfig.MTU)
	}()
	if err != nil {
		logger.Error("Failed to create TUN device: %v", err)
		return
	}

	// if config.FileDescriptorTun == 0 {
	if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
		o.tunnelConfig.InterfaceName = realInterfaceName
	}
	// }

	// Wrap TUN device with packet filter for DNS proxy
	o.middleDev = olmDevice.NewMiddleDevice(o.tdev)

	wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
	// Use filtered device instead of raw TUN device
	o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger))

	if o.tunnelConfig.EnableUAPI {
		fileUAPI, err := func() (*os.File, error) {
			if o.tunnelConfig.FileDescriptorUAPI != 0 {
				fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32)
				if err != nil {
					return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
				}
				return os.NewFile(uintptr(fd), ""), nil
			}
			return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName)
		}()
		if err != nil {
			logger.Error("UAPI listen error: %v", err)
			os.Exit(1)
			return
		}

		o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI)
		if err != nil {
			logger.Error("Failed to listen on uapi socket: %v", err)
			os.Exit(1)
		}

		go func() {
			for {
				conn, err := o.uapiListener.Accept()
				if err != nil {
					return
				}
				go o.dev.IpcHandle(conn)
			}
		}()
		logger.Info("UAPI listener started")
	}

	if err = o.dev.Up(); err != nil {
		logger.Error("Failed to bring up WireGuard device: %v", err)
	}

	// Extract interface IP (strip CIDR notation if present)
	interfaceIP := wgData.TunnelIP
	if strings.Contains(interfaceIP, "/") {
		interfaceIP = strings.Split(interfaceIP, "/")[0]
	}

	// Create and start DNS proxy
	o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP)
	if err != nil {
		logger.Error("Failed to create DNS proxy: %v", err)
	}

	if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil {
		logger.Error("Failed to configure interface: %v", err)
	}

	if err := network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet
		logger.Error("Failed to add route for utility subnet: %v", err)
	}

	// Create peer manager with integrated peer monitoring
	o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
		Device:        o.dev,
		DNSProxy:      o.dnsProxy,
		InterfaceName: o.tunnelConfig.InterfaceName,
		PrivateKey:    o.privateKey,
		MiddleDev:     o.middleDev,
		LocalIP:       interfaceIP,
		SharedBind:    o.sharedBind,
		WSClient:      o.websocket,
		APIServer:     o.apiServer,
	})

	for i := range wgData.Sites {
		site := wgData.Sites[i]
		var siteEndpoint string
		// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
		if site.RelayEndpoint != "" {
			siteEndpoint = site.RelayEndpoint
		} else {
			siteEndpoint = site.Endpoint
		}

		o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)

		if err := o.peerManager.AddPeer(site); err != nil {
			logger.Error("Failed to add peer: %v", err)
			return
		}

		logger.Info("Configured peer %s", site.PublicKey)
	}

	o.peerManager.Start()

	if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime
		logger.Error("Failed to start DNS proxy: %v", err)
	}

	if o.tunnelConfig.OverrideDNS {
		// Set up DNS override to use our DNS proxy
		if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
			logger.Error("Failed to setup DNS override: %v", err)
			return
		}

		network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()})
	}

	o.apiServer.SetRegistered(true)

	o.registered = true

	// Start ping monitor now that we are registered and connected
	o.websocket.StartPingMonitor()

	// Invoke onConnected callback if configured
	if o.olmConfig.OnConnected != nil {
		go o.olmConfig.OnConnected()
	}

	logger.Info("WireGuard device created.")
}

func (o *Olm) handleOlmError(msg websocket.WSMessage) {
	logger.Debug("Received olm error message: %v", msg.Data)

	// Check if tunnel is still running
	if !o.tunnelRunning {
		logger.Debug("Tunnel stopped, ignoring olm error message")
		return
	}

	var errorData OlmErrorData

	jsonData, err := json.Marshal(msg.Data)
	if err != nil {
		logger.Error("Error marshaling olm error data: %v", err)
		return
	}

	if err := json.Unmarshal(jsonData, &errorData); err != nil {
		logger.Error("Error unmarshaling olm error data: %v", err)
		return
	}

	logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message)

	// Set the olm error in the API server so it can be exposed via status
	o.apiServer.SetOlmError(errorData.Code, errorData.Message)

	// Invoke onOlmError callback if configured
	if o.olmConfig.OnOlmError != nil {
		go o.olmConfig.OnOlmError(errorData.Code, errorData.Message)
	}
}

func (o *Olm) handleTerminate(msg websocket.WSMessage) {
	logger.Info("Received terminate message")

	// Check if tunnel is still running
	if !o.tunnelRunning {
		logger.Debug("Tunnel stopped, ignoring terminate message")
		return
	}

	var errorData OlmErrorData

	jsonData, err := json.Marshal(msg.Data)
	if err != nil {
		logger.Error("Error marshaling terminate error data: %v", err)
	} else {
		if err := json.Unmarshal(jsonData, &errorData); err != nil {
			logger.Error("Error unmarshaling terminate error data: %v", err)
		} else {
			logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)

			if errorData.Code == "TERMINATED_INACTIVITY" {
				logger.Info("Ignoring...")
				return
			}

			// Set the olm error in the API server so it can be exposed via status
			o.apiServer.SetOlmError(errorData.Code, errorData.Message)
		}
	}

	o.apiServer.SetTerminated(true)
	o.apiServer.SetConnectionStatus(false)
	o.apiServer.SetRegistered(false)
	o.apiServer.ClearPeerStatuses()

	network.ClearNetworkSettings()

	o.Close()

	if o.olmConfig.OnTerminated != nil {
		go o.olmConfig.OnTerminated()
	}
}
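handleConnect strips CIDR notation from the tunnel IP with a plain string split; net/netip offers a validating alternative if that line ever needs hardening. A sketch, not repo code, with an illustrative address:

package main

import (
	"fmt"
	"net/netip"
)

func main() {
	// ParsePrefix validates the address and mask instead of blindly splitting on "/"
	prefix, err := netip.ParsePrefix("100.89.128.5/24") // illustrative tunnel IP
	if err != nil {
		panic(err)
	}
	fmt.Println(prefix.Addr()) // 100.89.128.5
}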
365 olm/data.go Normal file
@@ -0,0 +1,365 @@
package olm

import (
	"encoding/json"
	"time"

	"github.com/fosrl/newt/holepunch"
	"github.com/fosrl/newt/logger"
	"github.com/fosrl/olm/peers"
	"github.com/fosrl/olm/websocket"
)

func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
	logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)

	// Check if tunnel is still running
	if !o.tunnelRunning {
		logger.Debug("Tunnel stopped, ignoring add-remote-subnets-aliases message")
		return
	}

	jsonData, err := json.Marshal(msg.Data)
	if err != nil {
		logger.Error("Error marshaling data: %v", err)
		return
	}

	var addSubnetsData peers.PeerAdd
	if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil {
		logger.Error("Error unmarshaling add-remote-subnets data: %v", err)
		return
	}

	if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists {
		logger.Debug("Peer %d not found for adding remote subnets and aliases", addSubnetsData.SiteId)
		return
	}

	// Add new subnets
	for _, subnet := range addSubnetsData.RemoteSubnets {
		if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil {
			logger.Error("Failed to add allowed IP %s: %v", subnet, err)
		}
	}

	// Add new aliases
	for _, alias := range addSubnetsData.Aliases {
		if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil {
			logger.Error("Failed to add alias %s: %v", alias.Alias, err)
		}
	}
}

func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
	logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)

	// Check if tunnel is still running
	if !o.tunnelRunning {
		logger.Debug("Tunnel stopped, ignoring remove-remote-subnets-aliases message")
		return
	}

	jsonData, err := json.Marshal(msg.Data)
	if err != nil {
		logger.Error("Error marshaling data: %v", err)
		return
	}

	var removeSubnetsData peers.RemovePeerData
	if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil {
		logger.Error("Error unmarshaling remove-remote-subnets data: %v", err)
		return
	}

	if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists {
		logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId)
		return
	}

	// Remove subnets
	for _, subnet := range removeSubnetsData.RemoteSubnets {
		if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil {
			logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
		}
	}

	// Remove aliases
	for _, alias := range removeSubnetsData.Aliases {
		if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil {
			logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
		}
	}
}

func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
	logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)

	// Check if tunnel is still running
	if !o.tunnelRunning {
		logger.Debug("Tunnel stopped, ignoring update-remote-subnets-aliases message")
		return
	}

	jsonData, err := json.Marshal(msg.Data)
	if err != nil {
		logger.Error("Error marshaling data: %v", err)
		return
	}

	var updateSubnetsData peers.UpdatePeerData
	if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil {
		logger.Error("Error unmarshaling update-remote-subnets data: %v", err)
		return
	}

	if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists {
		logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId)
		return
	}

	// Add new subnets BEFORE removing old ones to preserve shared subnets
	// This ensures that if an old and new subnet are the same on different peers,
	// the route won't be temporarily removed
	for _, subnet := range updateSubnetsData.NewRemoteSubnets {
		if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
			logger.Error("Failed to add allowed IP %s: %v", subnet, err)
		}
	}

	// Remove old subnets after new ones are added
	for _, subnet := range updateSubnetsData.OldRemoteSubnets {
		if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
			logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
		}
	}

	// Add new aliases BEFORE removing old ones to preserve shared IP addresses
	// This ensures that if an old and new alias share the same IP, the IP won't be
	// temporarily removed from the allowed IPs list
	for _, alias := range updateSubnetsData.NewAliases {
		if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil {
			logger.Error("Failed to add alias %s: %v", alias.Alias, err)
		}
	}

	// Remove old aliases after new ones are added
	for _, alias := range updateSubnetsData.OldAliases {
		if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil {
			logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
		}
	}

	logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
}

// Handler for syncing peer configuration - reconciles expected state with actual state
func (o *Olm) handleSync(msg websocket.WSMessage) {
	logger.Debug("Received sync message: %v", msg.Data)

	if !o.registered {
		logger.Warn("Not connected, ignoring sync request")
		return
	}

	if o.peerManager == nil {
		logger.Warn("Peer manager not initialized, ignoring sync request")
		return
	}

	jsonData, err := json.Marshal(msg.Data)
	if err != nil {
		logger.Error("Error marshaling sync data: %v", err)
		return
	}

	var syncData SyncData
	if err := json.Unmarshal(jsonData, &syncData); err != nil {
		logger.Error("Error unmarshaling sync data: %v", err)
		return
	}

	// Sync exit nodes for hole punching
	o.syncExitNodes(syncData.ExitNodes)

	// Build a map of expected peers from the incoming data
	expectedPeers := make(map[int]peers.SiteConfig)
	for _, site := range syncData.Sites {
		expectedPeers[site.SiteId] = site
	}

	// Get all current peers
	currentPeers := o.peerManager.GetAllPeers()
	currentPeerMap := make(map[int]peers.SiteConfig)
	for _, peer := range currentPeers {
		currentPeerMap[peer.SiteId] = peer
	}

	// Find peers to remove (in current but not in expected)
	for siteId := range currentPeerMap {
		if _, exists := expectedPeers[siteId]; !exists {
			logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId)
			if err := o.peerManager.RemovePeer(siteId); err != nil {
				logger.Error("Sync: Failed to remove peer %d: %v", siteId, err)
			} else {
				// Remove any exit nodes associated with this peer from hole punching
				if o.holePunchManager != nil {
					removed := o.holePunchManager.RemoveExitNodesByPeer(siteId)
					if removed > 0 {
						logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId)
					}
				}
			}
		}
	}

	// Find peers to add (in expected but not in current) and peers to update
	for siteId, expectedSite := range expectedPeers {
		if _, exists := currentPeerMap[siteId]; !exists {
			// New peer - add it using the add flow (with holepunch)
			logger.Info("Sync: Adding new peer for site %d", siteId)

			o.holePunchManager.TriggerHolePunch()

			// // TODO: do we need to send the message to the cloud to add the peer that way?
			// if err := o.peerManager.AddPeer(expectedSite); err != nil {
			// 	logger.Error("Sync: Failed to add peer %d: %v", siteId, err)
			// } else {
			// 	logger.Info("Sync: Successfully added peer for site %d", siteId)
			// }

			// add the peer via the server
			// this is important because newt needs to get triggered as well to add the peer once the hp is complete
			o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
				"siteId": expectedSite.SiteId,
			}, 1*time.Second, 10)

		} else {
			// Existing peer - check if update is needed
			currentSite := currentPeerMap[siteId]
			needsUpdate := false

			// Check if any fields have changed
			if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
				needsUpdate = true
			}
			if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint {
				needsUpdate = true
			}
			if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey {
				needsUpdate = true
			}
			if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP {
				needsUpdate = true
			}
			if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort {
				needsUpdate = true
			}
			// Check remote subnets
			if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) {
				needsUpdate = true
			}
			// Check aliases
			if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) {
				needsUpdate = true
			}

			if needsUpdate {
				logger.Info("Sync: Updating peer for site %d", siteId)

				// Merge expected data with current data
				siteConfig := currentSite
				if expectedSite.Endpoint != "" {
					siteConfig.Endpoint = expectedSite.Endpoint
				}
				if expectedSite.RelayEndpoint != "" {
					siteConfig.RelayEndpoint = expectedSite.RelayEndpoint
				}
				if expectedSite.PublicKey != "" {
					siteConfig.PublicKey = expectedSite.PublicKey
				}
				if expectedSite.ServerIP != "" {
					siteConfig.ServerIP = expectedSite.ServerIP
				}
				if expectedSite.ServerPort != 0 {
					siteConfig.ServerPort = expectedSite.ServerPort
				}
				if expectedSite.RemoteSubnets != nil {
					siteConfig.RemoteSubnets = expectedSite.RemoteSubnets
				}
				if expectedSite.Aliases != nil {
					siteConfig.Aliases = expectedSite.Aliases
				}

				if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
					logger.Error("Sync: Failed to update peer %d: %v", siteId, err)
				} else {
					// If the endpoint changed, trigger holepunch to refresh NAT mappings
					if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
						logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId)
						o.holePunchManager.TriggerHolePunch()
						o.holePunchManager.ResetServerHolepunchInterval()
					}
					logger.Info("Sync: Successfully updated peer for site %d", siteId)
				}
			}
		}
	}

	logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers))
}

// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager
func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) {
	if o.holePunchManager == nil {
		logger.Warn("Hole punch manager not initialized, skipping exit node sync")
		return
	}

	// Build a map of expected exit nodes by endpoint
	expectedExitNodeMap := make(map[string]SyncExitNode)
	for _, exitNode := range expectedExitNodes {
		expectedExitNodeMap[exitNode.Endpoint] = exitNode
	}

	// Get current exit nodes from hole punch manager
	currentExitNodes := o.holePunchManager.GetExitNodes()
	currentExitNodeMap := make(map[string]holepunch.ExitNode)
	for _, exitNode := range currentExitNodes {
		currentExitNodeMap[exitNode.Endpoint] = exitNode
	}

	// Find exit nodes to remove (in current but not in expected)
	for endpoint := range currentExitNodeMap {
		if _, exists := expectedExitNodeMap[endpoint]; !exists {
			logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint)
			o.holePunchManager.RemoveExitNode(endpoint)
		}
	}

	// Find exit nodes to add (in expected but not in current)
	for endpoint, expectedExitNode := range expectedExitNodeMap {
		if _, exists := currentExitNodeMap[endpoint]; !exists {
			logger.Info("Sync: Adding new exit node %s", endpoint)

			relayPort := expectedExitNode.RelayPort
			if relayPort == 0 {
				relayPort = 21820 // default relay port
			}

			hpExitNode := holepunch.ExitNode{
				Endpoint:  expectedExitNode.Endpoint,
				RelayPort: relayPort,
				PublicKey: expectedExitNode.PublicKey,
				SiteIds:   expectedExitNode.SiteIds,
			}

			if o.holePunchManager.AddExitNode(hpExitNode) {
				logger.Info("Sync: Successfully added exit node %s", endpoint)
			}
			o.holePunchManager.TriggerHolePunch()
		}
	}

	logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap))
}
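Both handleSync and syncExitNodes follow the same map-reconcile pattern: index both sides by key, remove what is only local, add what is only upstream. A standalone sketch of the idea (hypothetical types, not repo code):

package main

import "fmt"

func reconcile(expected, current map[int]string, add, remove func(int)) {
	for id := range current {
		if _, ok := expected[id]; !ok {
			remove(id) // present locally, gone upstream
		}
	}
	for id := range expected {
		if _, ok := current[id]; !ok {
			add(id) // new upstream, missing locally
		}
	}
}

func main() {
	expected := map[int]string{1: "a", 3: "c"}
	current := map[int]string{1: "a", 2: "b"}
	reconcile(expected, current,
		func(id int) { fmt.Println("add", id) },
		func(id int) { fmt.Println("remove", id) })
}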
1370 olm/olm.go
File diff suppressed because it is too large
282 olm/peer.go Normal file
@@ -0,0 +1,282 @@
|
||||
package olm

import (
    "encoding/json"
    "time"

    "github.com/fosrl/newt/holepunch"
    "github.com/fosrl/newt/logger"
    "github.com/fosrl/newt/util"
    "github.com/fosrl/olm/peers"
    "github.com/fosrl/olm/websocket"
)

func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
    logger.Debug("Received add-peer message: %v", msg.Data)

    // Check if tunnel is still running
    if !o.tunnelRunning {
        logger.Debug("Tunnel stopped, ignoring add-peer message")
        return
    }

    if o.stopPeerSend != nil {
        o.stopPeerSend()
        o.stopPeerSend = nil
    }

    jsonData, err := json.Marshal(msg.Data)
    if err != nil {
        logger.Error("Error marshaling data: %v", err)
        return
    }

    var siteConfig peers.SiteConfig
    if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
        logger.Error("Error unmarshaling add data: %v", err)
        return
    }

    _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it

    if err := o.peerManager.AddPeer(siteConfig); err != nil {
        logger.Error("Failed to add peer: %v", err)
        return
    }

    logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
}
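
Every handler in this file converts the loosely typed msg.Data, which the WebSocket layer decodes into an interface{}, into a concrete struct with the same JSON round-trip: marshal the value back to JSON, then unmarshal into the target type. A standalone sketch of the idiom; the SiteConfig stand-in here is abridged and illustrative:

package main

import (
    "encoding/json"
    "fmt"
)

type SiteConfig struct {
    SiteId   int    `json:"siteId"`
    Endpoint string `json:"endpoint"`
}

// decodeInto re-marshals an already-decoded interface{} payload and
// unmarshals it into a typed destination, mirroring the handlers above.
func decodeInto(data interface{}, dst interface{}) error {
    raw, err := json.Marshal(data)
    if err != nil {
        return err
    }
    return json.Unmarshal(raw, dst)
}

func main() {
    // msg.Data typically arrives as a map[string]interface{} from the JSON decoder.
    payload := map[string]interface{}{"siteId": 7, "endpoint": "1.2.3.4:51820"}
    var cfg SiteConfig
    if err := decodeInto(payload, &cfg); err != nil {
        panic(err)
    }
    fmt.Printf("%+v\n", cfg) // {SiteId:7 Endpoint:1.2.3.4:51820}
}
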
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
    logger.Debug("Received remove-peer message: %v", msg.Data)

    // Check if tunnel is still running
    if !o.tunnelRunning {
        logger.Debug("Tunnel stopped, ignoring remove-peer message")
        return
    }

    jsonData, err := json.Marshal(msg.Data)
    if err != nil {
        logger.Error("Error marshaling data: %v", err)
        return
    }

    var removeData peers.PeerRemove
    if err := json.Unmarshal(jsonData, &removeData); err != nil {
        logger.Error("Error unmarshaling remove data: %v", err)
        return
    }

    if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil {
        logger.Error("Failed to remove peer: %v", err)
        return
    }

    // Remove any exit nodes associated with this peer from hole punching
    if o.holePunchManager != nil {
        removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId)
        if removed > 0 {
            logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId)
        }
    }

    logger.Info("Successfully removed peer for site %d", removeData.SiteId)
}

func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
    logger.Debug("Received update-peer message: %v", msg.Data)

    // Check if tunnel is still running
    if !o.tunnelRunning {
        logger.Debug("Tunnel stopped, ignoring update-peer message")
        return
    }

    jsonData, err := json.Marshal(msg.Data)
    if err != nil {
        logger.Error("Error marshaling data: %v", err)
        return
    }

    var updateData peers.SiteConfig
    if err := json.Unmarshal(jsonData, &updateData); err != nil {
        logger.Error("Error unmarshaling update data: %v", err)
        return
    }

    // Get existing peer from PeerManager
    existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId)
    if !exists {
        logger.Warn("Peer with site ID %d not found", updateData.SiteId)
        return
    }

    // Create updated site config by merging with existing data
    siteConfig := existingPeer

    if updateData.Endpoint != "" {
        siteConfig.Endpoint = updateData.Endpoint
    }
    if updateData.RelayEndpoint != "" {
        siteConfig.RelayEndpoint = updateData.RelayEndpoint
    }
    if updateData.PublicKey != "" {
        siteConfig.PublicKey = updateData.PublicKey
    }
    if updateData.ServerIP != "" {
        siteConfig.ServerIP = updateData.ServerIP
    }
    if updateData.ServerPort != 0 {
        siteConfig.ServerPort = updateData.ServerPort
    }
    if updateData.RemoteSubnets != nil {
        siteConfig.RemoteSubnets = updateData.RemoteSubnets
    }

    if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
        logger.Error("Failed to update peer: %v", err)
        return
    }

    // If the endpoint changed, trigger holepunch to refresh NAT mappings
    if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
        logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
        _ = o.holePunchManager.TriggerHolePunch()
        o.holePunchManager.ResetServerHolepunchInterval()
    }

    logger.Info("Successfully updated peer for site %d", updateData.SiteId)
}
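
handleWgPeerUpdate treats zero values in the incoming payload as "leave unchanged": it copies the stored config and overwrites only the fields that are non-empty or non-zero. A side effect of this convention is that a field can never be deliberately reset to its zero value through an update message. A condensed sketch of the merge, using an abridged stand-in for peers.SiteConfig:

// siteConfigPatch is an illustrative, abridged stand-in for peers.SiteConfig.
type siteConfigPatch struct {
    Endpoint      string
    PublicKey     string
    ServerPort    uint16
    RemoteSubnets []string
}

// merge applies non-zero fields of upd on top of base, mirroring the
// field-by-field merge in handleWgPeerUpdate above.
func merge(base, upd siteConfigPatch) siteConfigPatch {
    out := base
    if upd.Endpoint != "" {
        out.Endpoint = upd.Endpoint
    }
    if upd.PublicKey != "" {
        out.PublicKey = upd.PublicKey
    }
    if upd.ServerPort != 0 {
        out.ServerPort = upd.ServerPort
    }
    if upd.RemoteSubnets != nil {
        out.RemoteSubnets = upd.RemoteSubnets
    }
    return out
}
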
func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
    logger.Debug("Received relay-peer message: %v", msg.Data)

    // Check if peerManager is still valid (may be nil during shutdown)
    if o.peerManager == nil {
        logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
        return
    }

    jsonData, err := json.Marshal(msg.Data)
    if err != nil {
        logger.Error("Error marshaling data: %v", err)
        return
    }

    var relayData peers.RelayPeerData
    if err := json.Unmarshal(jsonData, &relayData); err != nil {
        logger.Error("Error unmarshaling relay data: %v", err)
        return
    }

    primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
    if err != nil {
        logger.Error("Failed to resolve primary relay endpoint: %v", err)
        return
    }

    // Update HTTP server to mark this peer as using relay
    o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)

    o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
}

func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
    logger.Debug("Received unrelay-peer message: %v", msg.Data)

    // Check if peerManager is still valid (may be nil during shutdown)
    if o.peerManager == nil {
        logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
        return
    }

    jsonData, err := json.Marshal(msg.Data)
    if err != nil {
        logger.Error("Error marshaling data: %v", err)
        return
    }

    var relayData peers.UnRelayPeerData
    if err := json.Unmarshal(jsonData, &relayData); err != nil {
        logger.Error("Error unmarshaling relay data: %v", err)
        return
    }

    primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
    if err != nil {
        logger.Warn("Failed to resolve primary relay endpoint: %v", err)
    }

    // Update HTTP server to mark this peer as no longer using relay
    o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)

    o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
}

func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
    logger.Debug("Received peer-handshake message: %v", msg.Data)

    // Check if tunnel is still running
    if !o.tunnelRunning {
        logger.Debug("Tunnel stopped, ignoring peer-handshake message")
        return
    }

    jsonData, err := json.Marshal(msg.Data)
    if err != nil {
        logger.Error("Error marshaling handshake data: %v", err)
        return
    }

    var handshakeData struct {
        SiteId   int `json:"siteId"`
        ExitNode struct {
            PublicKey string `json:"publicKey"`
            Endpoint  string `json:"endpoint"`
            RelayPort uint16 `json:"relayPort"`
        } `json:"exitNode"`
    }

    if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
        logger.Error("Error unmarshaling handshake data: %v", err)
        return
    }

    // Get existing peer from PeerManager
    _, exists := o.peerManager.GetPeer(handshakeData.SiteId)
    if exists {
        logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
        return
    }

    relayPort := handshakeData.ExitNode.RelayPort
    if relayPort == 0 {
        relayPort = 21820 // default relay port
    }

    siteId := handshakeData.SiteId
    exitNode := holepunch.ExitNode{
        Endpoint:  handshakeData.ExitNode.Endpoint,
        RelayPort: relayPort,
        PublicKey: handshakeData.ExitNode.PublicKey,
        SiteIds:   []int{siteId},
    }

    added := o.holePunchManager.AddExitNode(exitNode)
    if added {
        logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
    } else {
        logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
    }

    o.holePunchManager.TriggerHolePunch()                 // Trigger immediate hole punch attempt
    o.holePunchManager.ResetServerHolepunchInterval()     // start sending immediately again so we fill in the endpoint on the cloud

    // Send handshake acknowledgment back to server with retry
    o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
        "siteId": handshakeData.SiteId,
    }, 1*time.Second, 10)

    logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}
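
The handshake acknowledgment relies on the retrying sender added to the WebSocket client (see the SendMessageInterval change further down): the message goes out immediately and then once per interval until maxAttempts is reached or the returned stop function is called. Note how handleWgPeerAdd above invokes o.stopPeerSend() when the server finally answers with an add-peer message, closing the loop. A usage sketch, with wsClient and siteId assumed in scope:

// Re-send the ack every second, at most 10 times, until the server reacts;
// keep the stop function so a later handler can cancel the retries.
stop, update := wsClient.SendMessageInterval("olm/wg/server/peer/add",
    map[string]interface{}{"siteId": siteId},
    1*time.Second, 10)
_ = update // update(newData) would merge fresh fields into the payload

// ... later, when the server's add-peer message arrives:
stop()
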
27
olm/types.go
@@ -12,9 +12,22 @@ type WgData struct {
    UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
}

type GlobalConfig struct {
type SyncData struct {
    Sites     []peers.SiteConfig `json:"sites"`
    ExitNodes []SyncExitNode     `json:"exitNodes"`
}

type SyncExitNode struct {
    Endpoint  string `json:"endpoint"`
    RelayPort uint16 `json:"relayPort"`
    PublicKey string `json:"publicKey"`
    SiteIds   []int  `json:"siteIds"`
}

type OlmConfig struct {
    // Logging
    LogLevel string
    LogLevel    string
    LogFilePath string

    // HTTP server
    EnableAPI bool
@@ -23,11 +36,17 @@ type GlobalConfig struct {
    Version string
    Agent   string

    WakeUpDebounce time.Duration

    // Debugging
    PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")

    // Callbacks
    OnRegistered func()
    OnConnected  func()
    OnTerminated func()
    OnAuthError  func(statusCode int, message string) // Called when auth fails (401/403)
    OnOlmError   func(code string, message string)    // Called when registration fails
    OnExit       func() // Called when exit is requested via API
}

@@ -61,6 +80,10 @@ type TunnelConfig struct {
    EnableUAPI bool

    OverrideDNS bool
    TunnelDNS   bool

    InitialFingerprint map[string]any
    InitialPostures    map[string]any

    DisableRelay bool
}
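
The SyncData payload the server pushes therefore carries both the full site list and the exit node list in one message, which is what syncPeers and syncExitNodes reconcile against. A hedged sketch of decoding one; the keys match the struct tags above, but every value here is invented for illustration:

raw := []byte(`{
    "sites": [
        {"siteId": 7, "endpoint": "198.51.100.10:51820", "publicKey": "BASE64KEY="}
    ],
    "exitNodes": [
        {"endpoint": "203.0.113.5", "relayPort": 21820, "publicKey": "BASE64KEY=", "siteIds": [7]}
    ]
}`)

var syncData SyncData
if err := json.Unmarshal(raw, &syncData); err != nil {
    logger.Error("decode sync data: %v", err)
}
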
119
olm/util.go
@@ -1,96 +1,45 @@
package olm

import (
    "fmt"
    "net"
    "strings"
    "time"

    "github.com/fosrl/newt/logger"
    "github.com/fosrl/newt/network"
    "github.com/fosrl/olm/websocket"
    "github.com/fosrl/olm/peers"
)

// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
    if endpoint == "" {
        return ""
    }
    // Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
    _, _, err := net.SplitHostPort(endpoint)
    if err == nil {
        return endpoint // Already valid, no change needed
    }

    // If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
    lastColon := strings.LastIndex(endpoint, ":")
    if lastColon > 0 { // Ensure there is a colon and it's not the first character
        hostPart := endpoint[:lastColon]
        // Check if the host part is a literal IPv6 address
        if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
            // It is! Reformat it with brackets.
            portPart := endpoint[lastColon+1:]
            return fmt.Sprintf("[%s]:%s", hostPart, portPart)
        }
    }

    // If it's not the specific malformed case, return it as is.
    return endpoint
}
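
A quick illustration of what the helper normalizes; expected outputs, assuming the logic above, are shown as comments:

formatEndpoint("1.2.3.4:51820")       // "1.2.3.4:51820" (SplitHostPort accepts it, unchanged)
formatEndpoint("[2001:db8::1]:51820") // unchanged, already bracketed
formatEndpoint("2001:db8::1:51820")   // "[2001:db8::1]:51820" (brackets added around the IPv6 host)
formatEndpoint("example.com")         // "example.com" (no port, returned as-is)
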
func sendPing(olm *websocket.Client) error {
    err := olm.SendMessage("olm/ping", map[string]interface{}{
        "timestamp": time.Now().Unix(),
        "userToken": olm.GetConfig().UserToken,
    })
    if err != nil {
        logger.Error("Failed to send ping message: %v", err)
        return err
    }
    logger.Debug("Sent ping message")
    return nil
}

func keepSendingPing(olm *websocket.Client) {
    // Send ping immediately on startup
    if err := sendPing(olm); err != nil {
        logger.Error("Failed to send initial ping: %v", err)
    } else {
        logger.Info("Sent initial ping message")
    }

    // Set up ticker for one minute intervals
    ticker := time.NewTicker(1 * time.Minute)
    defer ticker.Stop()

    for {
        select {
        case <-stopPing:
            logger.Info("Stopping ping messages")
            return
        case <-ticker.C:
            if err := sendPing(olm); err != nil {
                logger.Error("Failed to send periodic ping: %v", err)
            }
        }
    }
}

func GetNetworkSettingsJSON() (string, error) {
    return network.GetJSON()
}

func GetNetworkSettingsIncrementor() int {
    return network.GetIncrementor()
}

// stringSlicesEqual compares two string slices for equality
func stringSlicesEqual(a, b []string) bool {
// slicesEqual compares two string slices for equality (order-independent)
func slicesEqual(a, b []string) bool {
    if len(a) != len(b) {
        return false
    }
    for i := range a {
        if a[i] != b[i] {
    // Create a map to count occurrences in slice a
    counts := make(map[string]int)
    for _, v := range a {
        counts[v]++
    }
    // Check if slice b has the same elements
    for _, v := range b {
        counts[v]--
        if counts[v] < 0 {
            return false
        }
    }
    return true
}

// aliasesEqual compares two Alias slices for equality (order-independent)
func aliasesEqual(a, b []peers.Alias) bool {
    if len(a) != len(b) {
        return false
    }
    // Create a map to count occurrences in slice a (using alias+address as key)
    counts := make(map[string]int)
    for _, v := range a {
        key := v.Alias + "|" + v.AliasAddress
        counts[key]++
    }
    // Check if slice b has the same elements
    for _, v := range b {
        key := v.Alias + "|" + v.AliasAddress
        counts[key]--
        if counts[key] < 0 {
            return false
        }
    }
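
Both equality helpers use the same counting technique: since the lengths are already known to match, one pass incrementing counts over a and one pass decrementing over b suffice, and any counter going negative proves b contains an element a lacks. The same idea written once over a generic comparable key; an illustrative sketch, assuming Go 1.18+:

// multisetEqual reports whether a and b contain the same elements with
// the same multiplicities, in any order.
func multisetEqual[K comparable](a, b []K) bool {
    if len(a) != len(b) {
        return false
    }
    counts := make(map[K]int, len(a))
    for _, v := range a {
        counts[v]++
    }
    for _, v := range b {
        counts[v]--
        if counts[v] < 0 {
            return false // b has an element a does not
        }
    }
    return true // lengths match, so no counter can remain positive either
}
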
@@ -50,6 +50,8 @@ type PeerManager struct {
    // key is the CIDR string, value is a set of siteIds that want this IP
    allowedIPClaims map[string]map[int]bool
    APIServer       *api.API

    PersistentKeepalive int
}

// NewPeerManager creates a new PeerManager with an internal PeerMonitor
@@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
    return peer, ok
}

// GetPeerMonitor returns the internal peer monitor instance
func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor {
    pm.mu.RLock()
    defer pm.mu.RUnlock()
    return pm.peerMonitor
}

func (pm *PeerManager) GetAllPeers() []SiteConfig {
    pm.mu.RLock()
    defer pm.mu.RUnlock()
@@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
    wgConfig := siteConfig
    wgConfig.AllowedIps = ownedIPs

    if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
    if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
        return err
    }

@@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
    return nil
}

// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once
// without recreating them. Returns a map of siteId to error for any peers that failed to update.
func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
    pm.mu.RLock()
    defer pm.mu.RUnlock()

    pm.PersistentKeepalive = interval

    errors := make(map[int]error)

    for siteId, peer := range pm.peers {
        err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
        if err != nil {
            errors[siteId] = err
        }
    }

    if len(errors) == 0 {
        return nil
    }
    return errors
}
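
Callers can thus retune keepalives across the whole interface in one pass instead of tearing peers down and re-adding them. A hedged usage sketch, with pm assumed to be a *PeerManager and failures merely logged:

// Drop to a 25-second keepalive for all current peers; per-peer failures
// are reported in the returned map rather than aborting the loop.
if errs := pm.UpdateAllPeersPersistentKeepalive(25); errs != nil {
    for siteId, err := range errs {
        logger.Warn("keepalive update failed for site %d: %v", siteId, err)
    }
}
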
func (pm *PeerManager) RemovePeer(siteId int) error {
    pm.mu.Lock()
    defer pm.mu.Unlock()
@@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
        ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
        wgConfig := promotedPeer
        wgConfig.AllowedIps = ownedIPs
        if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
        if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
            logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
        }
    }
@@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
    wgConfig := siteConfig
    wgConfig.AllowedIps = ownedIPs

    if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
    if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
        return err
    }

@@ -324,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
        promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
        promotedWgConfig := promotedPeer
        promotedWgConfig.AllowedIps = promotedOwnedIPs
        if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
        if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
            logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
        }
    }
@@ -743,7 +775,7 @@ func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error {
}

// RelayPeer handles failover to the relay server when a peer is disconnected
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string, relayPort uint16) {
    pm.mu.Lock()
    peer, exists := pm.peers[siteId]
    if exists {
@@ -764,10 +796,14 @@ func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
        formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
    }

    if relayPort == 0 {
        relayPort = 21820 // fall back to 21820 for backward compatibility
    }

    // Update only the endpoint for this peer (update_only preserves other settings)
    wgConfig := fmt.Sprintf(`public_key=%s
update_only=true
endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint)
endpoint=%s:%d`, util.FixKey(peer.PublicKey), formattedEndpoint, relayPort)

    err := pm.device.IpcSet(wgConfig)
    if err != nil {
@@ -31,8 +31,7 @@ type PeerMonitor struct {
    monitors map[int]*Client
    mutex    sync.Mutex
    running  bool
    interval time.Duration
    timeout time.Duration
    timeout     time.Duration
    maxAttempts int
    wsClient    *websocket.Client

@@ -42,7 +41,7 @@ type PeerMonitor struct {
    stack       *stack.Stack
    ep          *channel.Endpoint
    activePorts map[uint16]bool
    portsLock sync.Mutex
    portsLock   sync.RWMutex
    nsCtx       context.Context
    nsCancel    context.CancelFunc
    nsWg        sync.WaitGroup
@@ -50,17 +49,26 @@ type PeerMonitor struct {
    // Holepunch testing fields
    sharedBind         *bind.SharedBind
    holepunchTester    *holepunch.HolepunchTester
    holepunchInterval  time.Duration
    holepunchTimeout   time.Duration
    holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
    holepunchStatus    map[int]bool   // siteID -> connected status
    holepunchStopChan chan struct{}
    holepunchStopChan   chan struct{}
    holepunchUpdateChan chan struct{}

    // Relay tracking fields
    relayedPeers         map[int]bool // siteID -> whether the peer is currently relayed
    holepunchMaxAttempts int          // max consecutive failures before triggering relay
    holepunchFailures    map[int]int  // siteID -> consecutive failure count

    // Exponential backoff fields for holepunch monitor
    defaultHolepunchMinInterval time.Duration // Minimum interval (initial)
    defaultHolepunchMaxInterval time.Duration
    holepunchMinInterval        time.Duration // Minimum interval (initial)
    holepunchMaxInterval        time.Duration // Maximum interval (cap for backoff)
    holepunchBackoffMultiplier  float64       // Multiplier for each stable check
    holepunchStableCount        map[int]int   // siteID -> consecutive stable status count
    holepunchCurrentInterval    time.Duration // Current interval with backoff applied

    // Rapid initial test fields
    rapidTestInterval time.Duration // interval between rapid test attempts
    rapidTestTimeout  time.Duration // timeout for each rapid test attempt
@@ -78,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
    ctx, cancel := context.WithCancel(context.Background())
    pm := &PeerMonitor{
        monitors: make(map[int]*Client),
        interval: 2 * time.Second, // Default check interval (faster)
        timeout:     3 * time.Second,
        maxAttempts: 3,
        wsClient:    wsClient,
@@ -88,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
        nsCtx:      ctx,
        nsCancel:   cancel,
        sharedBind: sharedBind,
        holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
        holepunchTimeout:   2 * time.Second, // Faster timeout
        holepunchEndpoints: make(map[int]string),
        holepunchStatus:    make(map[int]bool),
@@ -101,6 +107,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
        rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
        apiServer:            apiServer,
        wgConnectionStatus:   make(map[int]bool),
        // Exponential backoff settings for holepunch monitor
        defaultHolepunchMinInterval: 2 * time.Second,
        defaultHolepunchMaxInterval: 30 * time.Second,
        holepunchMinInterval:        2 * time.Second,
        holepunchMaxInterval:        30 * time.Second,
        holepunchBackoffMultiplier:  1.5,
        holepunchStableCount:        make(map[int]int),
        holepunchCurrentInterval:    2 * time.Second,
        holepunchUpdateChan:         make(chan struct{}, 1),
    }

    if err := pm.initNetstack(); err != nil {
@@ -116,41 +131,75 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
}

// SetInterval changes how frequently peers are checked
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
    pm.mutex.Lock()
    defer pm.mutex.Unlock()

    pm.interval = interval

    // Update interval for all existing monitors
    for _, client := range pm.monitors {
        client.SetPacketInterval(interval)
        client.SetPacketInterval(minInterval, maxInterval)
    }

    logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
}

// SetTimeout changes the timeout for waiting for responses
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
func (pm *PeerMonitor) ResetPeerInterval() {
    pm.mutex.Lock()
    defer pm.mutex.Unlock()

    pm.timeout = timeout

    // Update timeout for all existing monitors
    // Update interval for all existing monitors
    for _, client := range pm.monitors {
        client.SetTimeout(timeout)
        client.ResetPacketInterval()
    }
}

// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
    pm.mutex.Lock()
    pm.holepunchMinInterval = minInterval
    pm.holepunchMaxInterval = maxInterval
    // Reset current interval to the new minimum
    pm.holepunchCurrentInterval = minInterval
    updateChan := pm.holepunchUpdateChan
    pm.mutex.Unlock()

    logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)

    // Signal the goroutine to apply the new interval if running
    if updateChan != nil {
        select {
        case updateChan <- struct{}{}:
        default:
            // Channel full or closed, skip
        }
    }
}

// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
    pm.mutex.Lock()
    defer pm.mutex.Unlock()

    pm.maxAttempts = attempts
    return pm.holepunchMinInterval, pm.holepunchMaxInterval
}

    // Update max attempts for all existing monitors
    for _, client := range pm.monitors {
        client.SetMaxAttempts(attempts)
func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
    pm.mutex.Lock()
    pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
    pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
    pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
    updateChan := pm.holepunchUpdateChan
    pm.mutex.Unlock()

    logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)

    // Signal the goroutine to apply the new interval if running
    if updateChan != nil {
        select {
        case updateChan <- struct{}{}:
        default:
            // Channel full or closed, skip
        }
    }
}
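
The select with an empty default used here (and again in the monitor client below) is the standard non-blocking send on a capacity-1 channel: at most one wake-up is ever queued, and signaling while a wake-up is already pending is a no-op rather than a blocked goroutine. Isolated, the idiom looks like this:

updateCh := make(chan struct{}, 1) // capacity 1: repeated signals coalesce

// signal requests a refresh without ever blocking the caller.
signal := func() {
    select {
    case updateCh <- struct{}{}:
    default: // a wake-up is already pending; drop this one
    }
}
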
@@ -169,10 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
        return err
    }

    client.SetPacketInterval(pm.interval)
    client.SetTimeout(pm.timeout)
    client.SetMaxAttempts(pm.maxAttempts)

    pm.monitors[siteID] = client

    pm.holepunchEndpoints[siteID] = holepunchEndpoint
@@ -191,12 +236,12 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st

// update holepunch endpoint for a peer
func (pm *PeerMonitor) UpdateHolepunchEndpoint(siteID int, endpoint string) {
    go func() {
        time.Sleep(3 * time.Second)
        pm.mutex.Lock()
        defer pm.mutex.Unlock()
        pm.holepunchEndpoints[siteID] = endpoint
    }()
    // Short delay to allow WireGuard peer reconfiguration to complete
    // The NAT mapping refresh is handled separately by TriggerHolePunch in olm.go
    pm.mutex.Lock()
    defer pm.mutex.Unlock()
    pm.holepunchEndpoints[siteID] = endpoint
    logger.Debug("Updated holepunch endpoint for site %d to %s", siteID, endpoint)
}

// RapidTestPeer performs a rapid connectivity test for a newly added peer.
@@ -294,6 +339,12 @@ func (pm *PeerMonitor) RemovePeer(siteID int) {
    pm.removePeerUnlocked(siteID)
}

func (pm *PeerMonitor) RemoveHolepunchEndpoint(siteID int) {
    pm.mutex.Lock()
    defer pm.mutex.Unlock()
    delete(pm.holepunchEndpoints, siteID)
}

// Start begins monitoring all peers
func (pm *PeerMonitor) Start() {
    pm.mutex.Lock()
@@ -464,31 +515,59 @@ func (pm *PeerMonitor) stopHolepunchMonitor() {
    logger.Info("Stopped holepunch connection monitor")
}

// runHolepunchMonitor runs the holepunch monitoring loop
// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff
func (pm *PeerMonitor) runHolepunchMonitor() {
    ticker := time.NewTicker(pm.holepunchInterval)
    defer ticker.Stop()
    pm.mutex.Lock()
    pm.holepunchCurrentInterval = pm.holepunchMinInterval
    pm.mutex.Unlock()

    // Do initial check immediately
    pm.checkHolepunchEndpoints()
    timer := time.NewTimer(0) // Fire immediately for initial check
    defer timer.Stop()

    for {
        select {
        case <-pm.holepunchStopChan:
            return
        case <-ticker.C:
            pm.checkHolepunchEndpoints()
        case <-pm.holepunchUpdateChan:
            // Interval settings changed, reset to minimum
            pm.mutex.Lock()
            pm.holepunchCurrentInterval = pm.holepunchMinInterval
            currentInterval := pm.holepunchCurrentInterval
            pm.mutex.Unlock()

            timer.Reset(currentInterval)
            logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
        case <-timer.C:
            anyStatusChanged := pm.checkHolepunchEndpoints()

            pm.mutex.Lock()
            if anyStatusChanged {
                // Reset to minimum interval on any status change
                pm.holepunchCurrentInterval = pm.holepunchMinInterval
            } else {
                // Apply exponential backoff when stable
                newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier)
                if newInterval > pm.holepunchMaxInterval {
                    newInterval = pm.holepunchMaxInterval
                }
                pm.holepunchCurrentInterval = newInterval
            }
            currentInterval := pm.holepunchCurrentInterval
            pm.mutex.Unlock()

            timer.Reset(currentInterval)
        }
    }
}
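
The monitor stretches its probe interval geometrically while nothing changes and snaps back to the minimum on any status change, so stable links cost little while flapping links are probed aggressively. The interval update reduces to a capped multiply; as a pure function mirroring the multiplier and cap fields above:

// nextInterval grows cur by mult, clamped to max; callers reset to the
// minimum whenever any peer's status flips.
func nextInterval(cur, max time.Duration, mult float64) time.Duration {
    next := time.Duration(float64(cur) * mult)
    if next > max {
        next = max
    }
    return next
}

// With min=2s, max=30s, mult=1.5 the probe sequence while stable is:
// 2s, 3s, 4.5s, 6.75s, 10.125s, ... capped at 30s.
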
// checkHolepunchEndpoints tests all holepunch endpoints
func (pm *PeerMonitor) checkHolepunchEndpoints() {
// Returns true if any endpoint's status changed
func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
    pm.mutex.Lock()
    // Check if we're still running before doing any work
    if !pm.running {
        pm.mutex.Unlock()
        return
        return false
    }
    endpoints := make(map[int]string, len(pm.holepunchEndpoints))
    for siteID, endpoint := range pm.holepunchEndpoints {
@@ -498,8 +577,10 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
    maxAttempts := pm.holepunchMaxAttempts
    pm.mutex.Unlock()

    anyStatusChanged := false

    for siteID, endpoint := range endpoints {
        logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
        // logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
        result := pm.holepunchTester.TestEndpoint(endpoint, timeout)

        pm.mutex.Lock()
@@ -523,7 +604,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
        pm.mutex.Unlock()

        // Log status changes
        if !exists || previousStatus != result.Success {
        statusChanged := !exists || previousStatus != result.Success
        if statusChanged {
            anyStatusChanged = true
            if result.Success {
                logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
            } else {
@@ -556,7 +639,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
        pm.mutex.Unlock()

        if !stillRunning {
            return // Stop processing if shutdown is in progress
            return anyStatusChanged // Stop processing if shutdown is in progress
        }

        if !result.Success && !isRelayed && failureCount >= maxAttempts {
@@ -573,6 +656,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
            }
        }
    }

    return anyStatusChanged
}

// GetHolepunchStatus returns the current holepunch status for all endpoints
@@ -644,55 +729,55 @@ func (pm *PeerMonitor) Close() {
    logger.Debug("PeerMonitor: Cleanup complete")
}

// TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
    pm.mutex.Lock()
    client, exists := pm.monitors[siteID]
    pm.mutex.Unlock()
// // TestPeer tests connectivity to a specific peer
// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
//     pm.mutex.Lock()
//     client, exists := pm.monitors[siteID]
//     pm.mutex.Unlock()

    if !exists {
        return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
    }
//     if !exists {
//         return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
//     }

    ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
    defer cancel()
//     ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
//     defer cancel()

    connected, rtt := client.TestConnection(ctx)
    return connected, rtt, nil
}
//     connected, rtt := client.TestPeerConnection(ctx)
//     return connected, rtt, nil
// }

// TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
    Connected bool
    RTT       time.Duration
} {
    pm.mutex.Lock()
    peers := make(map[int]*Client, len(pm.monitors))
    for siteID, client := range pm.monitors {
        peers[siteID] = client
    }
    pm.mutex.Unlock()
// // TestAllPeers tests connectivity to all peers
// func (pm *PeerMonitor) TestAllPeers() map[int]struct {
//     Connected bool
//     RTT       time.Duration
// } {
//     pm.mutex.Lock()
//     peers := make(map[int]*Client, len(pm.monitors))
//     for siteID, client := range pm.monitors {
//         peers[siteID] = client
//     }
//     pm.mutex.Unlock()

    results := make(map[int]struct {
        Connected bool
        RTT       time.Duration
    })
    for siteID, client := range peers {
        ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
        connected, rtt := client.TestConnection(ctx)
        cancel()
//     results := make(map[int]struct {
//         Connected bool
//         RTT       time.Duration
//     })
//     for siteID, client := range peers {
//         ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
//         connected, rtt := client.TestPeerConnection(ctx)
//         cancel()

        results[siteID] = struct {
            Connected bool
            RTT       time.Duration
        }{
            Connected: connected,
            RTT:       rtt,
        }
    }
//         results[siteID] = struct {
//             Connected bool
//             RTT       time.Duration
//         }{
//             Connected: connected,
//             RTT:       rtt,
//         }
//     }

    return results
}
//     return results
// }

// initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error {
@@ -764,9 +849,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool {
    }

    // Check if we are listening on this port
    pm.portsLock.Lock()
    pm.portsLock.RLock()
    active := pm.activePorts[uint16(port)]
    pm.portsLock.Unlock()
    pm.portsLock.RUnlock()

    if !active {
        return false
@@ -797,13 +882,12 @@ func (pm *PeerMonitor) runPacketSender() {
    defer pm.nsWg.Done()
    logger.Debug("PeerMonitor: Packet sender goroutine started")

    // Use a ticker to periodically check for packets without blocking indefinitely
    ticker := time.NewTicker(10 * time.Millisecond)
    defer ticker.Stop()

    for {
        select {
        case <-pm.nsCtx.Done():
        // Use blocking ReadContext instead of polling - much more CPU efficient
        // This will block until a packet is available or context is cancelled
        pkt := pm.ep.ReadContext(pm.nsCtx)
        if pkt == nil {
            // Context was cancelled or endpoint closed
            logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
            // Drain any remaining packets before exiting
            for {
@@ -815,36 +899,28 @@ func (pm *PeerMonitor) runPacketSender() {
            }
            logger.Debug("PeerMonitor: Packet sender goroutine exiting")
            return
        case <-ticker.C:
            // Try to read packets in batches
            for i := 0; i < 10; i++ {
                pkt := pm.ep.Read()
                if pkt == nil {
                    break
                }

                // Extract packet data
                slices := pkt.AsSlices()
                if len(slices) > 0 {
                    var totalSize int
                    for _, slice := range slices {
                        totalSize += len(slice)
                    }

                    buf := make([]byte, totalSize)
                    pos := 0
                    for _, slice := range slices {
                        copy(buf[pos:], slice)
                        pos += len(slice)
                    }

                    // Inject into MiddleDevice (outbound to WG)
                    pm.middleDev.InjectOutbound(buf)
                }

                pkt.DecRef()
            }
        }

        // Extract packet data
        slices := pkt.AsSlices()
        if len(slices) > 0 {
            var totalSize int
            for _, slice := range slices {
                totalSize += len(slice)
            }

            buf := make([]byte, totalSize)
            pos := 0
            for _, slice := range slices {
                copy(buf[pos:], slice)
                pos += len(slice)
            }

            // Inject into MiddleDevice (outbound to WG)
            pm.middleDev.InjectOutbound(buf)
        }

        pkt.DecRef()
    }
}
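
Replacing the 10 ms polling ticker with gVisor's blocking ReadContext removes the idle wake-ups entirely: the goroutine sleeps until a packet arrives or the context is cancelled, and a nil return doubles as the shutdown signal. Reduced to its skeleton, the new loop shape is:

for {
    pkt := ep.ReadContext(ctx) // blocks: returns a packet, or nil on cancel/close
    if pkt == nil {
        drainRemaining(ep) // drainRemaining is illustrative: flush leftovers, then exit
        return
    }
    forward(pkt) // forward is illustrative: copy slices out, inject toward WireGuard
    pkt.DecRef()
}
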
@@ -32,10 +32,19 @@ type Client struct {
    monitorLock sync.Mutex
    connLock    sync.Mutex // Protects connection operations
    shutdownCh  chan struct{}
    updateCh    chan struct{}
    packetInterval time.Duration
    timeout        time.Duration
    maxAttempts    int
    dialer         Dialer

    // Exponential backoff fields
    defaultMinInterval   time.Duration // Default minimum interval (initial)
    defaultMaxInterval   time.Duration // Default maximum interval (cap for backoff)
    minInterval          time.Duration // Minimum interval (initial)
    maxInterval          time.Duration // Maximum interval (cap for backoff)
    backoffMultiplier    float64       // Multiplier for each stable check
    stableCountToBackoff int           // Number of stable checks before backing off
}

// Dialer is a function that creates a connection
@@ -50,28 +59,59 @@ type ConnectionStatus struct {
// NewClient creates a new connection test client
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
    return &Client{
        serverAddr: serverAddr,
        shutdownCh: make(chan struct{}),
        packetInterval: 2 * time.Second,
        timeout: 500 * time.Millisecond, // Timeout for individual packets
        maxAttempts: 3, // Default max attempts
        dialer: dialer,
        serverAddr:           serverAddr,
        shutdownCh:           make(chan struct{}),
        updateCh:             make(chan struct{}, 1),
        packetInterval:       2 * time.Second,
        defaultMinInterval:   2 * time.Second,
        defaultMaxInterval:   30 * time.Second,
        minInterval:          2 * time.Second,
        maxInterval:          30 * time.Second,
        backoffMultiplier:    1.5,
        stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
        timeout:              500 * time.Millisecond, // Timeout for individual packets
        maxAttempts:          3, // Default max attempts
        dialer:               dialer,
    }, nil
}

// SetPacketInterval changes how frequently packets are sent in monitor mode
func (c *Client) SetPacketInterval(interval time.Duration) {
    c.packetInterval = interval
func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
    c.monitorLock.Lock()
    c.packetInterval = minInterval
    c.minInterval = minInterval
    c.maxInterval = maxInterval
    updateCh := c.updateCh
    monitorRunning := c.monitorRunning
    c.monitorLock.Unlock()

    // Signal the goroutine to apply the new interval if running
    if monitorRunning && updateCh != nil {
        select {
        case updateCh <- struct{}{}:
        default:
            // Channel full or closed, skip
        }
    }
}

// SetTimeout changes the timeout for waiting for responses
func (c *Client) SetTimeout(timeout time.Duration) {
    c.timeout = timeout
}
func (c *Client) ResetPacketInterval() {
    c.monitorLock.Lock()
    c.packetInterval = c.defaultMinInterval
    c.minInterval = c.defaultMinInterval
    c.maxInterval = c.defaultMaxInterval
    updateCh := c.updateCh
    monitorRunning := c.monitorRunning
    c.monitorLock.Unlock()

// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (c *Client) SetMaxAttempts(attempts int) {
    c.maxAttempts = attempts
    // Signal the goroutine to apply the new interval if running
    if monitorRunning && updateCh != nil {
        select {
        case updateCh <- struct{}{}:
        default:
            // Channel full or closed, skip
        }
    }
}

// UpdateServerAddr updates the server address and resets the connection
@@ -125,9 +165,10 @@ func (c *Client) ensureConnection() error {
    return nil
}

// TestConnection checks if the connection to the server is working
// TestPeerConnection checks if the connection to the server is working
// Returns true if connected, false otherwise
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
    // logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
    if err := c.ensureConnection(); err != nil {
        logger.Warn("Failed to ensure connection: %v", err)
        return false, 0
@@ -138,6 +179,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
    binary.BigEndian.PutUint32(packet[0:4], magicHeader)
    packet[4] = packetTypeRequest

    // Reusable response buffer
    responseBuffer := make([]byte, packetSize)

    // Send multiple attempts as specified
    for attempt := 0; attempt < c.maxAttempts; attempt++ {
        select {
@@ -157,20 +201,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
            return false, 0
        }

        // logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
        _, err := c.conn.Write(packet)
        if err != nil {
            c.connLock.Unlock()
            logger.Info("Error sending packet: %v", err)
            continue
        }
        // logger.Debug("Successfully sent monitor packet")

        // Set read deadline
        c.conn.SetReadDeadline(time.Now().Add(c.timeout))

        // Wait for response
        responseBuffer := make([]byte, packetSize)
        n, err := c.conn.Read(responseBuffer)
        c.connLock.Unlock()

@@ -211,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
    ctx, cancel := context.WithTimeout(context.Background(), timeout)
    defer cancel()
    return c.TestConnection(ctx)
    return c.TestPeerConnection(ctx)
}
// MonitorCallback is the function type for connection status change callbacks
@@ -238,28 +279,61 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
    go func() {
        var lastConnected bool
        firstRun := true
        stableCount := 0
        currentInterval := c.minInterval

        ticker := time.NewTicker(c.packetInterval)
        defer ticker.Stop()
        timer := time.NewTimer(currentInterval)
        defer timer.Stop()

        for {
            select {
            case <-c.shutdownCh:
                return
            case <-ticker.C:
            case <-c.updateCh:
                // Interval settings changed, reset to minimum
                c.monitorLock.Lock()
                currentInterval = c.minInterval
                c.monitorLock.Unlock()

                // Reset backoff state
                stableCount = 0

                timer.Reset(currentInterval)
                logger.Debug("Packet interval updated, reset to %v", currentInterval)
            case <-timer.C:
                ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
                connected, rtt := c.TestConnection(ctx)
                connected, rtt := c.TestPeerConnection(ctx)
                cancel()

                statusChanged := connected != lastConnected

                // Callback if status changed or it's the first check
                if connected != lastConnected || firstRun {
                if statusChanged || firstRun {
                    callback(ConnectionStatus{
                        Connected: connected,
                        RTT:       rtt,
                    })
                    lastConnected = connected
                    firstRun = false
                    // Reset backoff on status change
                    stableCount = 0
                    currentInterval = c.minInterval
                } else {
                    // Status is stable, increment counter
                    stableCount++

                    // Apply exponential backoff after stable threshold
                    if stableCount >= c.stableCountToBackoff {
                        newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier)
                        if newInterval > c.maxInterval {
                            newInterval = c.maxInterval
                        }
                        currentInterval = newInterval
                    }
                }

                // Reset timer with current interval
                timer.Reset(currentInterval)
            }
        }
    }()
@@ -11,7 +11,7 @@ import (
)

// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
    var endpoint string
    if relay && siteConfig.RelayEndpoint != "" {
        endpoint = formatEndpoint(siteConfig.RelayEndpoint)
@@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
    }

    configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
    configBuilder.WriteString("persistent_keepalive_interval=5\n")
    configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive))

    config := configBuilder.String()
    logger.Debug("Configuring peer with config: %s", config)
@@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [
    return nil
}

// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it
func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error {
    var configBuilder strings.Builder
    configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
    configBuilder.WriteString("update_only=true\n")
    configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval))

    config := configBuilder.String()
    logger.Debug("Updating persistent keepalive for peer with config: %s", config)

    err := dev.IpcSet(config)
    if err != nil {
        return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err)
    }

    return nil
}
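
UpdatePersistentKeepalive builds a minimal UAPI operation: update_only=true tells the device to modify the existing peer in place rather than replace it, so the endpoint and allowed IPs stay untouched while only the keepalive changes. A hedged usage sketch, with dev, peer, and logger assumed in scope:

// Switch this peer to a 25-second keepalive without re-adding it.
if err := UpdatePersistentKeepalive(dev, peer.PublicKey, 25); err != nil {
    logger.Error("keepalive update failed: %v", err)
}
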
func formatEndpoint(endpoint string) string {
    if strings.Contains(endpoint, ":") {
        return endpoint
@@ -33,6 +33,7 @@ type PeerRemove struct {
type RelayPeerData struct {
    SiteId        int    `json:"siteId"`
    RelayEndpoint string `json:"relayEndpoint"`
    RelayPort     uint16 `json:"relayPort"`
}

type UnRelayPeerData struct {

@@ -5,6 +5,7 @@ import (
    "crypto/tls"
    "crypto/x509"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "net/http"
@@ -48,12 +49,15 @@ type TokenResponse struct {

type ExitNode struct {
    Endpoint  string `json:"endpoint"`
    RelayPort uint16 `json:"relayPort"`
    PublicKey string `json:"publicKey"`
    SiteIds   []int  `json:"siteIds"`
}

type WSMessage struct {
    Type string `json:"type"`
    Data interface{} `json:"data"`
    Type          string      `json:"type"`
    Data          interface{} `json:"data"`
    ConfigVersion int         `json:"configVersion,omitempty"`
}

// this is not json anymore
@@ -75,6 +79,7 @@ type Client struct {
    handlersMux       sync.RWMutex
    reconnectInterval time.Duration
    isConnected       bool
    isDisconnected    bool // Flag to track if client is intentionally disconnected
    reconnectMux      sync.RWMutex
    pingInterval      time.Duration
    pingTimeout       time.Duration
@@ -85,6 +90,19 @@ type Client struct {
    clientType      string // Type of client (e.g., "newt", "olm")
    tlsConfig       TLSConfig
    configNeedsSave bool // Flag to track if config needs to be saved
    configVersion     int          // Latest config version received from server
    configVersionMux  sync.RWMutex
    token             string       // Cached authentication token
    exitNodes         []ExitNode   // Cached exit nodes from token response
    tokenMux          sync.RWMutex // Protects token and exitNodes
    forceNewToken     bool         // Flag to force fetching a new token on next connection
    processingMessage bool         // Flag to track if a message is currently being processed
    processingMux     sync.RWMutex // Protects processingMessage
    processingWg      sync.WaitGroup // WaitGroup to wait for message processing to complete
    getPingData       func() map[string]any // Callback to get additional ping data
    pingStarted       bool          // Flag to track if ping monitor has been started
    pingStartedMux    sync.Mutex    // Protects pingStarted
    pingDone          chan struct{} // Channel to stop the ping monitor independently
}

type ClientOption func(*Client)
@@ -120,6 +138,13 @@ func WithTLSConfig(config TLSConfig) ClientOption {
    }
}

// WithPingDataProvider sets a callback to provide additional data for ping messages
func WithPingDataProvider(fn func() map[string]any) ClientOption {
    return func(c *Client) {
        c.getPingData = fn
    }
}
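
The functional-option pattern keeps the NewClient signature stable while new knobs like this one are added. A usage sketch; the option is passed as a trailing ClientOption argument to NewClient, and connectedPeerCount is hypothetical, not part of the repository:

// Attach extra fields to every olm/ping message the client sends.
opt := websocket.WithPingDataProvider(func() map[string]any {
    return map[string]any{"peers": connectedPeerCount()}
})
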
func (c *Client) OnConnect(callback func() error) {
    c.onConnect = callback
}
@@ -152,6 +177,7 @@ func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.
        pingInterval: pingInterval,
        pingTimeout:  pingTimeout,
        clientType:   "olm",
        pingDone:     make(chan struct{}),
    }

    // Apply options before loading config
@@ -171,6 +197,9 @@ func (c *Client) GetConfig() *Config {

// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
    if c.isDisconnected {
        c.isDisconnected = false
    }
    go c.connectWithRetry()
    return nil
}
@@ -203,9 +232,31 @@ func (c *Client) Close() error {
    return nil
}

// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
func (c *Client) Disconnect() error {
    c.isDisconnected = true
    c.setConnected(false)

    // Stop the ping monitor
    c.stopPingMonitor()

    // Wait for any message currently being processed to complete
    c.processingWg.Wait()

    if c.conn != nil {
        c.writeMux.Lock()
        c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
        c.writeMux.Unlock()
        err := c.conn.Close()
        c.conn = nil
        return err
    }
    return nil
}
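
Disconnect differs from Close in that it only suspends the session: the isDisconnected flag short-circuits sends, and the interval senders (below) park themselves until Connect clears it. A plausible suspend/resume flow around system sleep; the surrounding event source is hypothetical:

// On suspend: stop pings, wait out in-flight handlers, close the socket.
if err := client.Disconnect(); err != nil {
    logger.Warn("disconnect: %v", err)
}

// On resume: clears isDisconnected and redials in the background.
_ = client.Connect()
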
// SendMessage sends a message through the WebSocket connection
|
||||
func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
if c.conn == nil {
|
||||
if c.isDisconnected || c.conn == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
@@ -214,14 +265,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
Data: data,
|
||||
}
|
||||
|
||||
logger.Debug("Sending message: %s, data: %+v", messageType, data)
|
||||
logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
|
||||
|
||||
c.writeMux.Lock()
|
||||
defer c.writeMux.Unlock()
|
||||
return c.conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
-func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
+func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
	stopChan := make(chan struct{})
	updateChan := make(chan interface{})
	var dataMux sync.Mutex
@@ -229,30 +280,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {

	go func() {
		count := 0
-		maxAttempts := 10
-
-		err := c.SendMessage(messageType, currentData) // Send immediately
-		if err != nil {
-			logger.Error("Failed to send initial message: %v", err)
+		send := func() {
+			if c.isDisconnected || c.conn == nil {
+				return
+			}
+			err := c.SendMessage(messageType, currentData)
+			if err != nil {
+				logger.Error("websocket: Failed to send message: %v", err)
+			}
+			count++
		}
-		count++
+
+		send() // Send immediately

		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
-				if count >= maxAttempts {
-					logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
+				if maxAttempts != -1 && count >= maxAttempts {
+					logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
					return
				}
				dataMux.Lock()
-				err = c.SendMessage(messageType, currentData)
+				send()
				dataMux.Unlock()
-				if err != nil {
-					logger.Error("Failed to send message: %v", err)
-				}
-				count++
			case newData := <-updateChan:
				dataMux.Lock()
				// Merge newData into currentData if both are maps
@@ -275,6 +328,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
			case <-stopChan:
				return
			}
+			// Suspend sending if disconnected
+			for c.isDisconnected {
+				select {
+				case <-stopChan:
+					return
+				case <-time.After(500 * time.Millisecond):
+				}
+			}
		}
	}()
	return func() {
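A standalone sketch of the stop/update closure pattern behind SendMessageInterval, including the new maxAttempts parameter (per the check above, -1 disables the attempt cap). The sendInterval helper and all names are invented for illustration:

package main

import (
	"fmt"
	"time"
)

// Resend on a ticker until the attempt cap is hit (maxAttempts == -1 means
// no cap) or stop() is called; update() swaps the payload between ticks.
func sendInterval(send func(any), data any, interval time.Duration, maxAttempts int) (stop func(), update func(any)) {
	stopCh := make(chan struct{})
	updateCh := make(chan any)
	go func() {
		count := 0
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		send(data) // send immediately
		count++
		for {
			select {
			case <-ticker.C:
				if maxAttempts != -1 && count >= maxAttempts {
					return
				}
				send(data)
				count++
			case d := <-updateCh:
				data = d
			case <-stopCh:
				return
			}
		}
	}()
	return func() { close(stopCh) }, func(d any) { updateCh <- d }
}

func main() {
	stop, update := sendInterval(func(d any) { fmt.Println("send:", d) }, "hello", 50*time.Millisecond, 3)
	update("world") // subsequent ticks send the new payload
	time.Sleep(200 * time.Millisecond)
	stop()
}
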
@@ -321,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
			tlsConfig = &tls.Config{}
		}
		tlsConfig.InsecureSkipVerify = true
-		logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
+		logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
	}

	tokenData := map[string]interface{}{
@@ -350,7 +411,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
	req.Header.Set("X-CSRF-Token", "x-csrf-protection")

	// print out the request for debugging
-	logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
+	logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))

	// Make the request
	client := &http.Client{}
@@ -367,7 +428,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
-		logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
+		logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))

		// Return AuthError for 401/403 status codes
		if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
@@ -383,7 +444,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {

	var tokenResp TokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
-		logger.Error("Failed to decode token response.")
+		logger.Error("websocket: Failed to decode token response.")
		return "", nil, fmt.Errorf("failed to decode token response: %w", err)
	}

@@ -395,7 +456,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
		return "", nil, fmt.Errorf("received empty token from server")
	}

-	logger.Debug("Received token: %s", tokenResp.Data.Token)
+	logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)

	return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
}
@@ -409,7 +470,8 @@ func (c *Client) connectWithRetry() {
		err := c.establishConnection()
		if err != nil {
			// Check if this is an auth error (401/403)
-			if authErr, ok := err.(*AuthError); ok {
+			var authErr *AuthError
+			if errors.As(err, &authErr) {
				logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
				// Trigger auth error callback if set (this should terminate the tunnel)
				if c.onAuthError != nil {
@@ -420,7 +482,7 @@ func (c *Client) connectWithRetry() {
			continue
		}
		// For other errors (5xx, network issues), continue retrying
-		logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
+		logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
		time.Sleep(c.reconnectInterval)
		continue
	}
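Why the switch to errors.As matters: a direct type assertion err.(*AuthError) stops matching once the error is wrapped, and establishConnection wraps with fmt.Errorf("%w"). A runnable sketch; AuthError's fields are taken from this diff, while its Error method body here is an assumption:

package main

import (
	"errors"
	"fmt"
)

// AuthError mirrors the type used in the diff (StatusCode, Message).
type AuthError struct {
	StatusCode int
	Message    string
}

func (e *AuthError) Error() string {
	return fmt.Sprintf("auth error %d: %s", e.StatusCode, e.Message)
}

func main() {
	// Simulate the wrapping done by establishConnection.
	err := fmt.Errorf("failed to get token: %w", &AuthError{StatusCode: 401, Message: "unauthorized"})

	if _, ok := err.(*AuthError); ok {
		fmt.Println("type assertion matched") // never prints: err is wrapped
	}

	var authErr *AuthError
	if errors.As(err, &authErr) {
		fmt.Println("errors.As matched:", authErr.StatusCode) // prints 401
	}
}
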
@@ -430,15 +492,25 @@ func (c *Client) connectWithRetry() {
}

func (c *Client) establishConnection() error {
-	// Get token for authentication
-	token, exitNodes, err := c.getToken()
-	if err != nil {
-		return fmt.Errorf("failed to get token: %w", err)
-	}
+	// Get token for authentication - reuse cached token unless forced to get new one
+	c.tokenMux.Lock()
+	needNewToken := c.token == "" || c.forceNewToken
+	if needNewToken {
+		token, exitNodes, err := c.getToken()
+		if err != nil {
+			c.tokenMux.Unlock()
+			return fmt.Errorf("failed to get token: %w", err)
+		}
+		c.token = token
+		c.exitNodes = exitNodes
+		c.forceNewToken = false

-	if c.onTokenUpdate != nil {
-		c.onTokenUpdate(token, exitNodes)
+		if c.onTokenUpdate != nil {
+			c.onTokenUpdate(token, exitNodes)
+		}
	}
+	token := c.token
+	c.tokenMux.Unlock()

	// Parse the base URL to determine protocol and hostname
	baseURL, err := url.Parse(c.baseURL)
@@ -473,7 +545,7 @@ func (c *Client) establishConnection() error {

	// Use new TLS configuration method
	if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
-		logger.Info("Setting up TLS configuration for WebSocket connection")
+		logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
		tlsConfig, err := c.setupTLS()
		if err != nil {
			return fmt.Errorf("failed to setup TLS configuration: %w", err)
@@ -487,25 +559,38 @@ func (c *Client) establishConnection() error {
			dialer.TLSClientConfig = &tls.Config{}
		}
		dialer.TLSClientConfig.InsecureSkipVerify = true
-		logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
+		logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
	}

-	conn, _, err := dialer.Dial(u.String(), nil)
+	conn, resp, err := dialer.Dial(u.String(), nil)
	if err != nil {
+		// Check if this is an unauthorized error (401)
+		if resp != nil && resp.StatusCode == http.StatusUnauthorized {
+			logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
+			// Force getting a new token on next reconnect attempt
+			c.tokenMux.Lock()
+			c.forceNewToken = true
+			c.tokenMux.Unlock()
+			return &AuthError{
+				StatusCode: http.StatusUnauthorized,
+				Message:    "WebSocket connection unauthorized",
+			}
+		}
		return fmt.Errorf("failed to connect to WebSocket: %w", err)
	}

	c.conn = conn
	c.setConnected(true)

-	// Start the ping monitor
-	go c.pingMonitor()
+	// Note: ping monitor is NOT started here - it will be started when
+	// StartPingMonitor() is called after registration completes

	// Start the read pump with disconnect detection
	go c.readPumpWithDisconnectDetection()

	if c.onConnect != nil {
		if err := c.onConnect(); err != nil {
-			logger.Error("OnConnect callback failed: %v", err)
+			logger.Error("websocket: OnConnect callback failed: %v", err)
		}
	}

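A condensed, runnable sketch of the token-reuse invariant above: fetch only when the cache is empty or the force flag is set (which a 401 on Dial triggers), and clear the flag after a successful fetch. The mutex is omitted for brevity, and fetch is a stand-in for getToken:

package main

import "fmt"

// Field names mirror the diff; everything else is illustrative.
type tokenCache struct {
	token         string
	forceNewToken bool
}

func (tc *tokenCache) get(fetch func() string) string {
	if tc.token == "" || tc.forceNewToken {
		tc.token = fetch()
		tc.forceNewToken = false
	}
	return tc.token
}

func main() {
	tc := &tokenCache{}
	fetches := 0
	fetch := func() string { fetches++; return fmt.Sprintf("tok-%d", fetches) }

	fmt.Println(tc.get(fetch)) // tok-1: first use fetches
	fmt.Println(tc.get(fetch)) // tok-1: cached, no new fetch
	tc.forceNewToken = true    // what a 401 on Dial sets above
	fmt.Println(tc.get(fetch)) // tok-2: refreshed
}
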
@@ -518,9 +603,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {

	// Handle new separate certificate configuration
	if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
-		logger.Info("Loading separate certificate files for mTLS")
-		logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
-		logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
+		logger.Info("websocket: Loading separate certificate files for mTLS")
+		logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
+		logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)

		// Load client certificate and key
		cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
@@ -531,7 +616,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {

	// Load CA certificates for remote validation if specified
	if len(c.tlsConfig.CAFiles) > 0 {
-		logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
+		logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
		caCertPool := x509.NewCertPool()
		for _, caFile := range c.tlsConfig.CAFiles {
			caCert, err := os.ReadFile(caFile)
@@ -557,13 +642,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {

	// Fallback to existing PKCS12 implementation for backward compatibility
	if c.tlsConfig.PKCS12File != "" {
-		logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
+		logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
		return c.setupPKCS12TLS()
	}

	// Legacy fallback using config.TlsClientCert
	if c.config.TlsClientCert != "" {
-		logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
+		logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
		return loadClientCertificate(c.config.TlsClientCert)
	}

@@ -575,6 +660,59 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
	return loadClientCertificate(c.tlsConfig.PKCS12File)
}

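A standalone sketch of the separate-files mTLS path that setupTLS implements: load the client keypair for presentation to the server, then build a CA pool for validating the remote. File paths are placeholders:

package main

import (
	"crypto/tls"
	"crypto/x509"
	"log"
	"os"
)

func buildTLSConfig(certFile, keyFile string, caFiles []string) (*tls.Config, error) {
	// Client certificate and key (what the server uses to authenticate us).
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	cfg := &tls.Config{Certificates: []tls.Certificate{cert}}

	// Optional CA pool (what we use to authenticate the server).
	if len(caFiles) > 0 {
		pool := x509.NewCertPool()
		for _, f := range caFiles {
			pem, err := os.ReadFile(f)
			if err != nil {
				return nil, err
			}
			pool.AppendCertsFromPEM(pem)
		}
		cfg.RootCAs = pool
	}
	return cfg, nil
}

func main() {
	// Placeholder paths; this exits with an error unless the files exist.
	if _, err := buildTLSConfig("client.crt", "client.key", []string{"ca.crt"}); err != nil {
		log.Fatal(err)
	}
}
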
+// sendPing sends a single ping message
+func (c *Client) sendPing() {
+	if c.isDisconnected || c.conn == nil {
+		return
+	}
+	// Skip ping if a message is currently being processed
+	c.processingMux.RLock()
+	isProcessing := c.processingMessage
+	c.processingMux.RUnlock()
+	if isProcessing {
+		logger.Debug("websocket: Skipping ping, message is being processed")
+		return
+	}
+	// Send application-level ping with config version
+	c.configVersionMux.RLock()
+	configVersion := c.configVersion
+	c.configVersionMux.RUnlock()
+
+	pingData := map[string]any{
+		"timestamp": time.Now().Unix(),
+		"userToken": c.config.UserToken,
+	}
+	if c.getPingData != nil {
+		for k, v := range c.getPingData() {
+			pingData[k] = v
+		}
+	}
+
+	pingMsg := WSMessage{
+		Type:          "olm/ping",
+		Data:          pingData,
+		ConfigVersion: configVersion,
+	}
+
+	logger.Debug("websocket: Sending ping: %+v", pingMsg)
+
+	c.writeMux.Lock()
+	err := c.conn.WriteJSON(pingMsg)
+	c.writeMux.Unlock()
+	if err != nil {
+		// Check if we're shutting down before logging error and reconnecting
+		select {
+		case <-c.done:
+			// Expected during shutdown
+			return
+		default:
+			logger.Error("websocket: Ping failed: %v", err)
+			c.reconnect()
+			return
+		}
+	}
+}
+
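For reference, the rough wire shape of the olm/ping message built above. This diff does not show WSMessage's JSON tags, so the serialized field names in this sketch are assumptions:

package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// wsMessage approximates WSMessage; the json tags are guesses, not confirmed
// by this diff.
type wsMessage struct {
	Type          string `json:"type"`
	Data          any    `json:"data"`
	ConfigVersion int    `json:"configVersion"`
}

func main() {
	msg := wsMessage{
		Type: "olm/ping",
		Data: map[string]any{
			"timestamp": time.Now().Unix(),
			"userToken": "REDACTED", // placeholder; real value comes from config
		},
		ConfigVersion: 3,
	}
	b, _ := json.Marshal(msg)
	fmt.Println(string(b))
}
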
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) pingMonitor() {
	ticker := time.NewTicker(c.pingInterval)
@@ -584,29 +722,65 @@ func (c *Client) pingMonitor() {
		select {
		case <-c.done:
			return
+		case <-c.pingDone:
+			return
		case <-ticker.C:
			if c.conn == nil {
				return
			}
-			c.writeMux.Lock()
-			err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
-			c.writeMux.Unlock()
-			if err != nil {
-				// Check if we're shutting down before logging error and reconnecting
-				select {
-				case <-c.done:
-					// Expected during shutdown
-					return
-				default:
-					logger.Error("Ping failed: %v", err)
-					c.reconnect()
-					return
-				}
-			}
+			c.sendPing()
		}
	}
}

+// StartPingMonitor starts the ping monitor goroutine.
+// This should be called after the client is registered and connected.
+// It is safe to call multiple times - only the first call will start the monitor.
+func (c *Client) StartPingMonitor() {
+	c.pingStartedMux.Lock()
+	defer c.pingStartedMux.Unlock()
+
+	if c.pingStarted {
+		return
+	}
+	c.pingStarted = true
+
+	// Create a new pingDone channel for this ping monitor instance
+	c.pingDone = make(chan struct{})
+
+	// Send an initial ping immediately
+	go func() {
+		c.sendPing()
+		c.pingMonitor()
+	}()
+}
+
+// stopPingMonitor stops the ping monitor goroutine if it's running.
+func (c *Client) stopPingMonitor() {
+	c.pingStartedMux.Lock()
+	defer c.pingStartedMux.Unlock()
+
+	if !c.pingStarted {
+		return
+	}
+
+	// Close the pingDone channel to stop the monitor
+	close(c.pingDone)
+	c.pingStarted = false
+}
+
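A standalone sketch of the restartable-monitor pattern used by StartPingMonitor/stopPingMonitor: a guard flag makes Start idempotent, and allocating a fresh done channel on each start is what allows stop-then-start again, since a closed channel cannot be reused. Names are illustrative:

package main

import (
	"fmt"
	"sync"
	"time"
)

type monitor struct {
	mu      sync.Mutex
	started bool
	done    chan struct{}
}

func (m *monitor) Start() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.started {
		return // idempotent: already running
	}
	m.started = true
	m.done = make(chan struct{}) // fresh channel per instance
	go func(done chan struct{}) { // capture this instance's channel
		for {
			select {
			case <-done:
				return
			case <-time.After(20 * time.Millisecond):
				fmt.Println("ping")
			}
		}
	}(m.done)
}

func (m *monitor) Stop() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if !m.started {
		return
	}
	close(m.done)
	m.started = false
}

func main() {
	m := &monitor{}
	m.Start()
	m.Start() // no-op: already running
	time.Sleep(50 * time.Millisecond)
	m.Stop()
	m.Start() // restart works because done was recreated
	time.Sleep(30 * time.Millisecond)
	m.Stop()
}
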
+// GetConfigVersion returns the current config version
+func (c *Client) GetConfigVersion() int {
+	c.configVersionMux.RLock()
+	defer c.configVersionMux.RUnlock()
+	return c.configVersion
+}
+
+// setConfigVersion updates the stored config version
+func (c *Client) setConfigVersion(version int) {
+	c.configVersionMux.Lock()
+	defer c.configVersionMux.Unlock()
+	logger.Debug("websocket: setting config version to %d", version)
+	c.configVersion = version
+}
+
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
func (c *Client) readPumpWithDisconnectDetection() {
	defer func() {
@@ -631,26 +805,47 @@ func (c *Client) readPumpWithDisconnectDetection() {
		var msg WSMessage
		err := c.conn.ReadJSON(&msg)
		if err != nil {
-			// Check if we're shutting down before logging error
+			// Check if we're shutting down or explicitly disconnected before logging error
			select {
			case <-c.done:
				// Expected during shutdown, don't log as error
-				logger.Debug("WebSocket connection closed during shutdown")
+				logger.Debug("websocket: connection closed during shutdown")
				return
			default:
+				// Check if explicitly disconnected
+				if c.isDisconnected {
+					logger.Debug("websocket: connection closed: client was explicitly disconnected")
+					return
+				}
+
+				// Unexpected error during normal operation
				if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
-					logger.Error("WebSocket read error: %v", err)
+					logger.Error("websocket: read error: %v", err)
				} else {
-					logger.Debug("WebSocket connection closed: %v", err)
+					logger.Debug("websocket: connection closed: %v", err)
				}
				return // triggers reconnect via defer
			}
		}

+		// Update config version from incoming message
+		c.setConfigVersion(msg.ConfigVersion)
+
		c.handlersMux.RLock()
		if handler, ok := c.handlers[msg.Type]; ok {
+			// Mark that we're processing a message
+			c.processingMux.Lock()
+			c.processingMessage = true
+			c.processingMux.Unlock()
+			c.processingWg.Add(1)
+
			handler(msg)
+
+			// Mark that we're done processing
+			c.processingWg.Done()
+			c.processingMux.Lock()
+			c.processingMessage = false
+			c.processingMux.Unlock()
		}
		c.handlersMux.RUnlock()
	}
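A standalone sketch of the processing guard in the read pump above: handlers run synchronously on the read loop, a flag lets sendPing skip while a handler is active, and the WaitGroup is what Disconnect drains before closing. All names are illustrative:

package main

import (
	"fmt"
	"sync"
	"time"
)

type pump struct {
	mu         sync.RWMutex
	processing bool
	wg         sync.WaitGroup
}

// handle mirrors the read pump: mark busy, run the handler inline, mark idle.
func (p *pump) handle(fn func()) {
	p.mu.Lock()
	p.processing = true
	p.mu.Unlock()
	p.wg.Add(1)

	fn()

	p.wg.Done()
	p.mu.Lock()
	p.processing = false
	p.mu.Unlock()
}

// tryPing mirrors sendPing's skip-while-busy check.
func (p *pump) tryPing() {
	p.mu.RLock()
	busy := p.processing
	p.mu.RUnlock()
	if busy {
		fmt.Println("skip ping: handler running")
		return
	}
	fmt.Println("ping")
}

func main() {
	p := &pump{}
	go p.handle(func() { time.Sleep(30 * time.Millisecond) })
	time.Sleep(10 * time.Millisecond)
	p.tryPing() // almost certainly skipped: handler still running
	p.wg.Wait() // what Disconnect does before closing the connection
	p.tryPing() // ping
}
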
@@ -664,6 +859,12 @@ func (c *Client) reconnect() {
		c.conn = nil
	}

+	// Don't reconnect if explicitly disconnected
+	if c.isDisconnected {
+		logger.Debug("websocket: Not reconnecting: client was explicitly disconnected")
+		return
+	}
+
	// Only reconnect if we're not shutting down
	select {
	case <-c.done:
@@ -681,7 +882,7 @@ func (c *Client) setConnected(status bool) {

// loadClientCertificate is a helper to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) {
-	logger.Info("Loading tls-client-cert %s", p12Path)
+	logger.Info("websocket: Loading tls-client-cert %s", p12Path)
	// Read the PKCS12 file
	p12Data, err := os.ReadFile(p12Path)
	if err != nil {