mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-08 05:56:40 +00:00
Compare commits
138 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3e73d0189 | ||
|
|
df2fbdf160 | ||
|
|
cb4ac8199d | ||
|
|
dd4b86b3e5 | ||
|
|
bad290aa4e | ||
|
|
8c27d5e3bf | ||
|
|
7e7a37d49c | ||
|
|
d44aa97f32 | ||
|
|
b57ad74589 | ||
|
|
82256a3f6f | ||
|
|
9e140a94db | ||
|
|
d0c9ea5a57 | ||
|
|
c88810ef24 | ||
|
|
463a4eea79 | ||
|
|
4576a2e8a7 | ||
|
|
69c13adcdb | ||
|
|
3886c1a8c1 | ||
|
|
06eb4d4310 | ||
|
|
247c47b27f | ||
|
|
060038c29b | ||
|
|
5414d21dcd | ||
|
|
364fa020aa | ||
|
|
b96ee16fbf | ||
|
|
467d69aa7c | ||
|
|
7c7762ebc5 | ||
|
|
526f9c8b4e | ||
|
|
905983cf61 | ||
|
|
a0879114e2 | ||
|
|
0d54a07973 | ||
|
|
4cb2fde961 | ||
|
|
9602599565 | ||
|
|
11f858b341 | ||
|
|
29b2cb33a2 | ||
|
|
34290ffe09 | ||
|
|
1013d0591e | ||
|
|
2f6d62ab45 | ||
|
|
8d6ba79408 | ||
|
|
208b434cb7 | ||
|
|
39ce0ac407 | ||
|
|
72bee56412 | ||
|
|
b32da3a714 | ||
|
|
971452e5d3 | ||
|
|
bba4345b0f | ||
|
|
b2392fb250 | ||
|
|
697f4131e7 | ||
|
|
e282715251 | ||
|
|
709df6db3e | ||
|
|
cf2b436470 | ||
|
|
2a29021572 | ||
|
|
a3f9a89079 | ||
|
|
ee27bf3153 | ||
|
|
a90f681957 | ||
|
|
3afc82ef9a | ||
|
|
d3a16f4c59 | ||
|
|
2a1911a66f | ||
|
|
08341b2385 | ||
|
|
6cde07d479 | ||
|
|
06b1e84f99 | ||
|
|
2b7e93ec92 | ||
|
|
ca23ae7a30 | ||
|
|
661fd86305 | ||
|
|
594a499b95 | ||
|
|
44aed84827 | ||
|
|
bf038eb4a2 | ||
|
|
6da3129b4e | ||
|
|
ac0f9b6a82 | ||
|
|
16aef10cca | ||
|
|
19031ebdfd | ||
|
|
0eebbc51d5 | ||
|
|
d321a8ba7e | ||
|
|
3ea86222ca | ||
|
|
c3ebe930d9 | ||
|
|
f2b96f2a38 | ||
|
|
9038239bbe | ||
|
|
3e64eb9c4f | ||
|
|
92992b8c14 | ||
|
|
4ee9d77532 | ||
|
|
bd7a5bd4b0 | ||
|
|
1cd49f8ee3 | ||
|
|
7a919d867b | ||
|
|
ce50c627a7 | ||
|
|
691d5f0271 | ||
|
|
56151089e3 | ||
|
|
af7c1caf98 | ||
|
|
dd208ab67c | ||
|
|
8189d41a45 | ||
|
|
ea3477c8ce | ||
|
|
a8a0f92c9b | ||
|
|
7040a9436e | ||
|
|
04361242fe | ||
|
|
554b1d55dc | ||
|
|
b03f8911a5 | ||
|
|
47589570c9 | ||
|
|
9f5b8dea26 | ||
|
|
f6a1e1e27c | ||
|
|
f983a8f141 | ||
|
|
efce3cb0b2 | ||
|
|
6eeebd81b2 | ||
|
|
c970fd5a18 | ||
|
|
09bd02456d | ||
|
|
c24537af36 | ||
|
|
9de3f14799 | ||
|
|
0908f75f5f | ||
|
|
10958f8c55 | ||
|
|
b1840fd5c3 | ||
|
|
1df5eb19ff | ||
|
|
f71f183886 | ||
|
|
8922ca9736 | ||
|
|
38483f4a26 | ||
|
|
78c768e497 | ||
|
|
fc7df8a530 | ||
|
|
50b42059ac | ||
|
|
825f7fcf60 | ||
|
|
8c8ec72b40 | ||
|
|
c61b7fc4fb | ||
|
|
96e3376147 | ||
|
|
e47a7c80d1 | ||
|
|
f1e373f2d8 | ||
|
|
ef4d0db475 | ||
|
|
b6b97f5ed3 | ||
|
|
dff267a42e | ||
|
|
bb98db7f5e | ||
|
|
f1016200b3 | ||
|
|
f1ab8094cf | ||
|
|
ad2bc0d397 | ||
|
|
a78d141ca3 | ||
|
|
10b1ad2a5a | ||
|
|
8a9f29043a | ||
|
|
05c9d851f4 | ||
|
|
c9a6b85e1d | ||
|
|
a16021cd86 | ||
|
|
9506b545f4 | ||
|
|
17b87e6707 | ||
|
|
993f5f86c5 | ||
|
|
093a4c21f2 | ||
|
|
f7c0bb9135 | ||
|
|
a145b77f79 | ||
|
|
7b3f7d2b12 |
@@ -6,4 +6,5 @@ README.md
|
||||
Makefile
|
||||
public/
|
||||
LICENSE
|
||||
CONTRIBUTING.md
|
||||
CONTRIBUTING.md
|
||||
.git
|
||||
|
||||
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Summary
|
||||
description: A clear and concise summary of the requested feature.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Motivation
|
||||
description: |
|
||||
Why is this feature important?
|
||||
Explain the problem this feature would solve or what use case it would enable.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Proposed Solution
|
||||
description: |
|
||||
How would you like to see this feature implemented?
|
||||
Provide as much detail as possible about the desired behavior, configuration, or changes.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
description: Describe any alternative solutions or workarounds you've thought about.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context, mockups, or screenshots about the feature request here.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before submitting, please:
|
||||
- Check if there is an existing issue for this feature.
|
||||
- Clearly explain the benefit and use case.
|
||||
- Be as specific as possible to help contributors evaluate and implement.
|
||||
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
name: Bug Report
|
||||
description: Create a bug report
|
||||
labels: []
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Describe the Bug
|
||||
description: A clear and concise description of what the bug is.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Environment
|
||||
description: Please fill out the relevant details below for your environment.
|
||||
value: |
|
||||
- OS Type & Version: (e.g., Ubuntu 22.04)
|
||||
- Pangolin Version:
|
||||
- Gerbil Version:
|
||||
- Traefik Version:
|
||||
- Newt Version:
|
||||
- Olm Version: (if applicable)
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: To Reproduce
|
||||
description: |
|
||||
Steps to reproduce the behavior, please provide a clear description of how to reproduce the issue, based on the linked minimal reproduction. Screenshots can be provided in the issue body below.
|
||||
|
||||
If using code blocks, make sure syntax highlighting is correct and double-check that the rendered preview is not broken.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: A clear and concise description of what you expected to happen.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before posting the issue go through the steps you've written down to make sure the steps provided are detailed and clear.
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Contributors should be able to follow the steps provided in order to reproduce the bug.
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Need help or have questions?
|
||||
url: https://github.com/orgs/fosrl/discussions
|
||||
about: Ask questions, get help, and discuss with other community members
|
||||
- name: Request a Feature
|
||||
url: https://github.com/orgs/fosrl/discussions/new?category=feature-requests
|
||||
about: Feature requests should be opened as discussions so others can upvote and comment
|
||||
40
.github/dependabot.yml
vendored
Normal file
40
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "gomod"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
groups:
|
||||
dev-patch-updates:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "patch"
|
||||
dev-minor-updates:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
prod-patch-updates:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "patch"
|
||||
prod-minor-updates:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
|
||||
- package-ecosystem: "docker"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
groups:
|
||||
patch-updates:
|
||||
update-types:
|
||||
- "patch"
|
||||
minor-updates:
|
||||
update-types:
|
||||
- "minor"
|
||||
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
179
.github/workflows/cicd.yml
vendored
179
.github/workflows/cicd.yml
vendored
@@ -1,52 +1,161 @@
|
||||
name: CI/CD Pipeline
|
||||
|
||||
# CI/CD workflow for building, publishing, mirroring, signing container images and building release binaries.
|
||||
# Actions are pinned to specific SHAs to reduce supply-chain risk. This workflow triggers on tag push events.
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write # for GHCR push
|
||||
id-token: write # for Cosign Keyless (OIDC) Signing
|
||||
|
||||
# Required secrets:
|
||||
# - DOCKER_HUB_USERNAME / DOCKER_HUB_ACCESS_TOKEN: push to Docker Hub
|
||||
# - GITHUB_TOKEN: used for GHCR login and OIDC keyless signing
|
||||
# - COSIGN_PRIVATE_KEY / COSIGN_PASSWORD / COSIGN_PUBLIC_KEY: for key-based signing
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- "[0-9]+.[0-9]+.[0-9]+"
|
||||
- "[0-9]+.[0-9]+.[0-9]+.rc.[0-9]+"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-latest
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: amd64-runner
|
||||
# Job-level timeout to avoid runaway or stuck runs
|
||||
timeout-minutes: 120
|
||||
env:
|
||||
# Target images
|
||||
DOCKERHUB_IMAGE: docker.io/fosrl/${{ github.event.repository.name }}
|
||||
GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: docker.io
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
|
||||
- name: Extract tag name
|
||||
id: get-tag
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
- name: Extract tag name
|
||||
id: get-tag
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
shell: bash
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: 1.23.1
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0
|
||||
with:
|
||||
go-version: 1.25
|
||||
|
||||
- name: Build and push Docker images
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
make docker-build-release tag=$TAG
|
||||
- name: Update version in main.go
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
if [ -f main.go ]; then
|
||||
sed -i 's/version_replaceme/'"$TAG"'/' main.go
|
||||
echo "Updated main.go with version $TAG"
|
||||
else
|
||||
echo "main.go not found"
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Build binaries
|
||||
run: |
|
||||
make go-build-release
|
||||
- name: Build and push Docker images (Docker Hub)
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
make docker-build-release tag=$TAG
|
||||
echo "Built & pushed to: ${{ env.DOCKERHUB_IMAGE }}:${TAG}"
|
||||
shell: bash
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
- name: Login in to GHCR
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install skopeo + jq
|
||||
# skopeo: copy/inspect images between registries
|
||||
# jq: JSON parsing tool used to extract digest values
|
||||
run: |
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y skopeo jq
|
||||
skopeo --version
|
||||
shell: bash
|
||||
|
||||
- name: Copy tag from Docker Hub to GHCR
|
||||
# Mirror the already-built image (all architectures) to GHCR so we can sign it
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG=${{ env.TAG }}
|
||||
echo "Copying ${{ env.DOCKERHUB_IMAGE }}:${TAG} -> ${{ env.GHCR_IMAGE }}:${TAG}"
|
||||
skopeo copy --all --retry-times 3 \
|
||||
docker://$DOCKERHUB_IMAGE:$TAG \
|
||||
docker://$GHCR_IMAGE:$TAG
|
||||
shell: bash
|
||||
|
||||
- name: Install cosign
|
||||
# cosign is used to sign and verify container images (key and keyless)
|
||||
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||
|
||||
- name: Dual-sign and verify (GHCR & Docker Hub)
|
||||
# Sign each image by digest using keyless (OIDC) and key-based signing,
|
||||
# then verify both the public key signature and the keyless OIDC signature.
|
||||
env:
|
||||
TAG: ${{ env.TAG }}
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
COSIGN_YES: "true"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
issuer="https://token.actions.githubusercontent.com"
|
||||
id_regex="^https://github.com/${{ github.repository }}/.+" # accept this repo (all workflows/refs)
|
||||
|
||||
for IMAGE in "${GHCR_IMAGE}" "${DOCKERHUB_IMAGE}"; do
|
||||
echo "Processing ${IMAGE}:${TAG}"
|
||||
|
||||
DIGEST="$(skopeo inspect --retry-times 3 docker://${IMAGE}:${TAG} | jq -r '.Digest')"
|
||||
REF="${IMAGE}@${DIGEST}"
|
||||
echo "Resolved digest: ${REF}"
|
||||
|
||||
echo "==> cosign sign (keyless) --recursive ${REF}"
|
||||
cosign sign --recursive "${REF}"
|
||||
|
||||
echo "==> cosign sign (key) --recursive ${REF}"
|
||||
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${REF}"
|
||||
|
||||
echo "==> cosign verify (public key) ${REF}"
|
||||
cosign verify --key env://COSIGN_PUBLIC_KEY "${REF}" -o text
|
||||
|
||||
echo "==> cosign verify (keyless policy) ${REF}"
|
||||
cosign verify \
|
||||
--certificate-oidc-issuer "${issuer}" \
|
||||
--certificate-identity-regexp "${id_regex}" \
|
||||
"${REF}" -o text
|
||||
done
|
||||
shell: bash
|
||||
|
||||
- name: Build binaries
|
||||
run: |
|
||||
make go-build-release
|
||||
shell: bash
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
|
||||
132
.github/workflows/mirror.yaml
vendored
Normal file
132
.github/workflows/mirror.yaml
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
name: Mirror & Sign (Docker Hub to GHCR)
|
||||
|
||||
on:
|
||||
workflow_dispatch: {}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write # for keyless OIDC
|
||||
|
||||
env:
|
||||
SOURCE_IMAGE: docker.io/fosrl/gerbil
|
||||
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||
|
||||
jobs:
|
||||
mirror-and-dual-sign:
|
||||
runs-on: amd64-runner
|
||||
steps:
|
||||
- name: Install skopeo + jq
|
||||
run: |
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y skopeo jq
|
||||
skopeo --version
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||
|
||||
- name: Input check
|
||||
run: |
|
||||
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
|
||||
echo "Source : ${SOURCE_IMAGE}"
|
||||
echo "Target : ${DEST_IMAGE}"
|
||||
|
||||
# Auth for skopeo (containers-auth)
|
||||
- name: Skopeo login to GHCR
|
||||
run: |
|
||||
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
# Auth for cosign (docker-config)
|
||||
- name: Docker login to GHCR (for cosign)
|
||||
run: |
|
||||
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
|
||||
|
||||
- name: List source tags
|
||||
run: |
|
||||
set -euo pipefail
|
||||
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
|
||||
| jq -r '.Tags[]' | sort -u > src-tags.txt
|
||||
echo "Found source tags: $(wc -l < src-tags.txt)"
|
||||
head -n 20 src-tags.txt || true
|
||||
|
||||
- name: List destination tags (skip existing)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
|
||||
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
|
||||
else
|
||||
: > dst-tags.txt
|
||||
fi
|
||||
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
|
||||
|
||||
- name: Mirror, dual-sign, and verify
|
||||
env:
|
||||
# keyless
|
||||
COSIGN_YES: "true"
|
||||
# key-based
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
# verify
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
copied=0; skipped=0; v_ok=0; errs=0
|
||||
|
||||
issuer="https://token.actions.githubusercontent.com"
|
||||
id_regex="^https://github.com/${{ github.repository }}/.+"
|
||||
|
||||
while read -r tag; do
|
||||
[ -z "$tag" ] && continue
|
||||
|
||||
if grep -Fxq "$tag" dst-tags.txt; then
|
||||
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
|
||||
skipped=$((skipped+1))
|
||||
continue
|
||||
fi
|
||||
|
||||
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
|
||||
if ! skopeo copy --all --retry-times 3 \
|
||||
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
|
||||
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
|
||||
errs=$((errs+1)); continue
|
||||
fi
|
||||
copied=$((copied+1))
|
||||
|
||||
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
|
||||
ref="${DEST_IMAGE}@${digest}"
|
||||
|
||||
echo "==> cosign sign (keyless) --recursive ${ref}"
|
||||
if ! cosign sign --recursive "${ref}"; then
|
||||
echo "::warning title=Keyless sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign sign (key) --recursive ${ref}"
|
||||
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
|
||||
echo "::warning title=Key sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (public key) ${ref}"
|
||||
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
|
||||
echo "::warning title=Verify(pubkey) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (keyless policy) ${ref}"
|
||||
if ! cosign verify \
|
||||
--certificate-oidc-issuer "${issuer}" \
|
||||
--certificate-identity-regexp "${id_regex}" \
|
||||
"${ref}" -o text; then
|
||||
echo "::warning title=Verify(keyless) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
else
|
||||
v_ok=$((v_ok+1))
|
||||
fi
|
||||
done < src-tags.txt
|
||||
|
||||
echo "---- Summary ----"
|
||||
echo "Copied : $copied"
|
||||
echo "Skipped : $skipped"
|
||||
echo "Verified OK : $v_ok"
|
||||
echo "Errors : $errs"
|
||||
31
.github/workflows/test.yml
vendored
Normal file
31
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
name: Run Tests
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: amd64-runner
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0
|
||||
with:
|
||||
go-version: 1.25
|
||||
|
||||
- name: Build go
|
||||
run: go build
|
||||
|
||||
- name: Build Docker image
|
||||
run: make build
|
||||
|
||||
- name: Build binaries
|
||||
run: make go-build-release
|
||||
1
.go-version
Normal file
1
.go-version
Normal file
@@ -0,0 +1 @@
|
||||
1.25
|
||||
@@ -4,11 +4,7 @@ Contributions are welcome!
|
||||
|
||||
Please see the contribution and local development guide on the docs page before getting started:
|
||||
|
||||
https://docs.fossorial.io/development
|
||||
|
||||
For ideas about what features to work on and our future plans, please see the roadmap:
|
||||
|
||||
https://docs.fossorial.io/roadmap
|
||||
https://docs.pangolin.net/development/contributing
|
||||
|
||||
### Licensing Considerations
|
||||
|
||||
|
||||
11
Dockerfile
11
Dockerfile
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.23.1-alpine AS builder
|
||||
FROM golang:1.25-alpine AS builder
|
||||
|
||||
# Set the working directory inside the container
|
||||
WORKDIR /app
|
||||
@@ -16,18 +16,13 @@ COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /gerbil
|
||||
|
||||
# Start a new stage from scratch
|
||||
FROM ubuntu:22.04 AS runner
|
||||
FROM alpine:3.23 AS runner
|
||||
|
||||
RUN apt-get update && apt-get install -y iptables iproute2 && rm -rf /var/lib/apt/lists/*
|
||||
RUN apk add --no-cache iptables iproute2
|
||||
|
||||
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
||||
COPY --from=builder /gerbil /usr/local/bin/
|
||||
COPY entrypoint.sh /
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
# Copy the entrypoint script
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
# Command to run the executable
|
||||
CMD ["gerbil"]
|
||||
2
Makefile
2
Makefile
@@ -3,7 +3,7 @@ all: build push
|
||||
|
||||
docker-build-release:
|
||||
@if [ -z "$(tag)" ]; then \
|
||||
echo "Error: tag is required. Usage: make build-all tag=<tag>"; \
|
||||
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
|
||||
exit 1; \
|
||||
fi
|
||||
docker buildx build --platform linux/arm64,linux/amd64 -t fosrl/gerbil:latest -f Dockerfile --push .
|
||||
|
||||
81
README.md
81
README.md
@@ -4,16 +4,9 @@ Gerbil is a simple [WireGuard](https://www.wireguard.com/) interface management
|
||||
|
||||
### Installation and Documentation
|
||||
|
||||
Gerbil can be used standalone with your own API, a static JSON file, or with Pangolin and Newt as part of the larger system. See documentation below:
|
||||
Gerbil works with Pangolin, Newt, and Olm as part of the larger system. See documentation below:
|
||||
|
||||
- [Installation Instructions](https://docs.fossorial.io)
|
||||
- [Full Documentation](https://docs.fossorial.io)
|
||||
|
||||
## Preview
|
||||
|
||||
<img src="public/screenshots/preview.png" alt="Preview"/>
|
||||
|
||||
_Sample output of a Gerbil container connected to Pangolin and terminating various peers._
|
||||
- [Full Documentation](https://docs.pangolin.net)
|
||||
|
||||
## Key Functions
|
||||
|
||||
@@ -27,30 +20,71 @@ Gerbil will create the peers defined in the config on the WireGuard interface. T
|
||||
|
||||
### Report Bandwidth
|
||||
|
||||
Bytes transmitted in and out of each peer are collected every 10 seconds, and incremental usage is reported via the "reportBandwidthTo" endpoint. This can be used to track data usage of each peer on the remote server.
|
||||
Bytes transmitted in and out of each peer are collected every 10 seconds, and incremental usage is reported via the api endpoint. This can be used to track data usage of each peer on the remote server.
|
||||
|
||||
### Handle client relaying
|
||||
|
||||
Gerbil listens on port 21820 for incoming UDP hole punch packets to orchestrate NAT hole punching between olm and newt clients. Additionally, it handles relaying data through the gerbil server down to the newt. This is accomplished by scanning each packet for headers and handling them appropriately.
|
||||
|
||||
### SNI Proxy
|
||||
|
||||
Gerbil includes an SNI (Server Name Indication) proxy that enables intelligent routing of HTTPS traffic between Pangolin nodes. When a TLS connection comes in, the proxy extracts the hostname from the SNI extension and queries Pangolin to determine the correct routing destination. This allows seamless routing of web traffic through the WireGuard mesh network:
|
||||
|
||||
- If the hostname is configured for local handling (via local overrides or local SNIs), traffic is routed to the local proxy
|
||||
- Otherwise, the proxy queries Pangolin's routing API to determine which node should handle the traffic
|
||||
- Supports caching of routing decisions to improve performance
|
||||
- Handles connection pooling and graceful shutdown
|
||||
- Optional PROXY protocol v1 support to preserve original client IP addresses when forwarding to downstream proxies (HAProxy, Nginx, etc.)
|
||||
|
||||
The PROXY protocol allows downstream proxies to know the real client IP address instead of seeing the SNI proxy's IP. When enabled with `--proxy-protocol`, the SNI proxy will prepend a PROXY protocol header to each connection containing the original client's IP and port information.
|
||||
|
||||
In single node (self hosted) Pangolin deployments this can be bypassed by using port 443:443 to route to Traefik instead of the SNI proxy at 8443.
|
||||
|
||||
## CLI Args
|
||||
|
||||
Important:
|
||||
- `reachableAt`: How should the remote server reach Gerbil's API?
|
||||
- `generateAndSaveKeyTo`: Where to save the generated WireGuard private key to persist across restarts.
|
||||
- `remoteConfig` (optional): Remote config location to HTTP get the JSON based config from. See `example_config.json`
|
||||
- `config` (optional): Local JSON file path to load config. Used if remote config is not supplied. See `example_config.json`
|
||||
- `remoteConfig`: Remote config location to HTTP get the JSON based config from.
|
||||
|
||||
Note: You must use either `config` or `remoteConfig` to configure WireGuard.
|
||||
|
||||
- `reportBandwidthTo` (optional): Remote HTTP endpoint to send peer bandwidth data
|
||||
Others:
|
||||
- `reportBandwidthTo` (optional): **DEPRECATED** - Use `remoteConfig` instead. Remote HTTP endpoint to send peer bandwidth data
|
||||
- `interface` (optional): Name of the WireGuard interface created by Gerbil. Default: `wg0`
|
||||
- `listen` (optional): Port to listen on for HTTP server. Default: `3003`
|
||||
- `log-level` (optional): The log level to use. Default: INFO
|
||||
- `listen` (optional): Port to listen on for HTTP server. Default: `:3004`
|
||||
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: `INFO`
|
||||
- `mtu` (optional): MTU of the WireGuard interface. Default: `1280`
|
||||
- `notify` (optional): URL to notify on peer changes
|
||||
- `sni-port` (optional): Port for the SNI proxy to listen on. Default: `8443`
|
||||
- `local-proxy` (optional): Address for local proxy when routing local traffic. Default: `localhost`
|
||||
- `local-proxy-port` (optional): Port for local proxy when routing local traffic. Default: `443`
|
||||
- `local-overrides` (optional): Comma-separated list of domain names that should always be routed to the local proxy
|
||||
- `proxy-protocol` (optional): Enable PROXY protocol v1 for preserving client IP addresses when forwarding to downstream proxies. Default: `false`
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All CLI arguments can also be provided via environment variables:
|
||||
|
||||
- `INTERFACE`: Name of the WireGuard interface
|
||||
- `REMOTE_CONFIG`: URL of the remote config server
|
||||
- `LISTEN`: Address to listen on for HTTP server
|
||||
- `GENERATE_AND_SAVE_KEY_TO`: Path to save generated private key
|
||||
- `REACHABLE_AT`: Endpoint of the HTTP server to tell remote config about
|
||||
- `LOG_LEVEL`: Log level (DEBUG, INFO, WARN, ERROR, FATAL)
|
||||
- `MTU`: MTU of the WireGuard interface
|
||||
- `NOTIFY_URL`: URL to notify on peer changes
|
||||
- `SNI_PORT`: Port for the SNI proxy to listen on
|
||||
- `LOCAL_PROXY`: Address for local proxy when routing local traffic
|
||||
- `LOCAL_PROXY_PORT`: Port for local proxy when routing local traffic
|
||||
- `LOCAL_OVERRIDES`: Comma-separated list of domain names that should always be routed to the local proxy
|
||||
- `PROXY_PROTOCOL`: Enable PROXY protocol v1 for preserving client IP addresses (true/false)
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
./gerbil \
|
||||
--reachableAt=http://gerbil:3003 \
|
||||
--reachableAt=http://gerbil:3004 \
|
||||
--generateAndSaveKeyTo=/var/config/key \
|
||||
--remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config \
|
||||
--reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth
|
||||
--remoteConfig=http://pangolin:3001/api/v1/
|
||||
```
|
||||
|
||||
```yaml
|
||||
@@ -60,10 +94,9 @@ services:
|
||||
container_name: gerbil
|
||||
restart: unless-stopped
|
||||
command:
|
||||
- --reachableAt=http://gerbil:3003
|
||||
- --reachableAt=http://gerbil:3004
|
||||
- --generateAndSaveKeyTo=/var/config/key
|
||||
- --remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config
|
||||
- --reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth
|
||||
- --remoteConfig=http://pangolin:3001/api/v1/
|
||||
volumes:
|
||||
- ./config/:/var/config
|
||||
cap_add:
|
||||
@@ -71,6 +104,8 @@ services:
|
||||
- SYS_MODULE
|
||||
ports:
|
||||
- 51820:51820/udp
|
||||
- 21820:21820/udp
|
||||
- 443:8443/tcp # SNI proxy port
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||
|
||||
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
|
||||
- Description and location of the vulnerability.
|
||||
- Potential impact of the vulnerability.
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
{
|
||||
"privateKey": "kBGTgk7c+zncEEoSnMl+jsLjVh5ZVoL/HwBSQem+d1M=",
|
||||
"listenPort": 51820,
|
||||
"ipAddress": "10.0.0.1/24",
|
||||
"peers": [
|
||||
{
|
||||
"publicKey": "5UzzoeveFVSzuqK3nTMS5bA1jIMs1fQffVQzJ8MXUQM=",
|
||||
"allowedIps": ["10.0.0.0/28"]
|
||||
},
|
||||
{
|
||||
"publicKey": "kYrZpuO2NsrFoBh1GMNgkhd1i9Rgtu1rAjbJ7qsfngU=",
|
||||
"allowedIps": ["10.0.0.16/28"]
|
||||
},
|
||||
{
|
||||
"publicKey": "1YfPUVr9ZF4zehkbI2BQhCxaRLz+Vtwa4vJwH+mpK0A=",
|
||||
"allowedIps": ["10.0.0.32/28"]
|
||||
},
|
||||
{
|
||||
"publicKey": "2/U4oyZ+sai336Dal/yExCphL8AxyqvIxMk4qsUy4iI=",
|
||||
"allowedIps": ["10.0.0.48/28"]
|
||||
}
|
||||
]
|
||||
}
|
||||
16
go.mod
16
go.mod
@@ -1,10 +1,12 @@
|
||||
module github.com/fosrl/gerbil
|
||||
|
||||
go 1.23.1
|
||||
go 1.25
|
||||
|
||||
toolchain go1.23.2
|
||||
require (
|
||||
github.com/vishvananda/netlink v1.3.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/sync v0.1.0
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
)
|
||||
|
||||
@@ -14,10 +16,8 @@ require (
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/crypto v0.8.0 // indirect
|
||||
golang.org/x/net v0.9.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.10.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect
|
||||
)
|
||||
|
||||
25
go.sum
25
go.sum
@@ -8,21 +8,24 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
||||
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
|
||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
|
||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
||||
|
||||
645
main.go
645
main.go
@@ -2,21 +2,32 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"github.com/fosrl/gerbil/proxy"
|
||||
"github.com/fosrl/gerbil/relay"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@@ -27,11 +38,16 @@ var (
|
||||
mtuInt int
|
||||
lastReadings = make(map[string]PeerReading)
|
||||
mu sync.Mutex
|
||||
wgMu sync.Mutex // Protects WireGuard operations
|
||||
notifyURL string
|
||||
proxyRelay *relay.UDPProxyServer
|
||||
proxySNI *proxy.SNIProxy
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
ListenPort int `json:"listenPort"`
|
||||
RelayPort int `json:"relayPort"`
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Peers []Peer `json:"peers"`
|
||||
}
|
||||
@@ -57,6 +73,31 @@ var (
|
||||
wgClient *wgctrl.Client
|
||||
)
|
||||
|
||||
// Add this new type at the top with other type definitions
|
||||
type ClientEndpoint struct {
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type HolePunchMessage struct {
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
}
|
||||
|
||||
type ProxyMappingUpdate struct {
|
||||
OldDestination relay.PeerDestination `json:"oldDestination"`
|
||||
NewDestination relay.PeerDestination `json:"newDestination"`
|
||||
}
|
||||
|
||||
type UpdateDestinationsRequest struct {
|
||||
SourceIP string `json:"sourceIp"`
|
||||
SourcePort int `json:"sourcePort"`
|
||||
Destinations []relay.PeerDestination `json:"destinations"`
|
||||
}
|
||||
|
||||
func parseLogLevel(level string) logger.LogLevel {
|
||||
switch strings.ToUpper(level) {
|
||||
case "DEBUG":
|
||||
@@ -75,27 +116,41 @@ func parseLogLevel(level string) logger.LogLevel {
|
||||
}
|
||||
|
||||
func main() {
|
||||
go monitorMemory(1024 * 1024 * 512) // trigger if memory usage exceeds 512MB
|
||||
|
||||
var (
|
||||
err error
|
||||
wgconfig WgConfig
|
||||
configFile string
|
||||
remoteConfigURL string
|
||||
reportBandwidthTo string
|
||||
generateAndSaveKeyTo string
|
||||
reachableAt string
|
||||
logLevel string
|
||||
mtu string
|
||||
sniProxyPort int
|
||||
localProxyAddr string
|
||||
localProxyPort int
|
||||
localOverridesStr string
|
||||
trustedUpstreamsStr string
|
||||
proxyProtocol bool
|
||||
)
|
||||
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
configFile = os.Getenv("CONFIG")
|
||||
remoteConfigURL = os.Getenv("REMOTE_CONFIG")
|
||||
listenAddr = os.Getenv("LISTEN")
|
||||
reportBandwidthTo = os.Getenv("REPORT_BANDWIDTH_TO")
|
||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||
reachableAt = os.Getenv("REACHABLE_AT")
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
mtu = os.Getenv("MTU")
|
||||
notifyURL = os.Getenv("NOTIFY_URL")
|
||||
|
||||
sniProxyPortStr := os.Getenv("SNI_PORT")
|
||||
localProxyAddr = os.Getenv("LOCAL_PROXY")
|
||||
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
|
||||
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
|
||||
trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS")
|
||||
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
|
||||
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||
@@ -104,31 +159,94 @@ func main() {
|
||||
flag.StringVar(&configFile, "config", "", "Path to local configuration file")
|
||||
}
|
||||
if remoteConfigURL == "" {
|
||||
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration")
|
||||
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL of the Pangolin server")
|
||||
}
|
||||
if listenAddr == "" {
|
||||
flag.StringVar(&listenAddr, "listen", ":3003", "Address to listen on")
|
||||
}
|
||||
if reportBandwidthTo == "" {
|
||||
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on")
|
||||
flag.StringVar(&listenAddr, "listen", "", "DEPRECATED (overridden by reachableAt): Address to listen on")
|
||||
}
|
||||
// DEPRECATED AND UNSED: reportBandwidthTo
|
||||
// allow reportBandwidthTo to be passed but dont do anything with it just thow it away
|
||||
reportBandwidthTo := ""
|
||||
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "DEPRECATED: Use remoteConfig instead")
|
||||
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
}
|
||||
|
||||
if reachableAt == "" {
|
||||
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
|
||||
}
|
||||
|
||||
if logLevel == "" {
|
||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
if mtu == "" {
|
||||
flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface")
|
||||
}
|
||||
if notifyURL == "" {
|
||||
flag.StringVar(¬ifyURL, "notify", "", "URL to notify on peer changes")
|
||||
}
|
||||
|
||||
if sniProxyPortStr != "" {
|
||||
if port, err := strconv.Atoi(sniProxyPortStr); err == nil {
|
||||
sniProxyPort = port
|
||||
}
|
||||
}
|
||||
if sniProxyPortStr == "" {
|
||||
flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on")
|
||||
}
|
||||
|
||||
if localProxyAddr == "" {
|
||||
flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address")
|
||||
}
|
||||
|
||||
if localProxyPortStr != "" {
|
||||
if port, err := strconv.Atoi(localProxyPortStr); err == nil {
|
||||
localProxyPort = port
|
||||
}
|
||||
}
|
||||
if localProxyPortStr == "" {
|
||||
flag.IntVar(&localProxyPort, "local-proxy-port", 443, "Local proxy port")
|
||||
}
|
||||
if localOverridesStr != "" {
|
||||
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
|
||||
}
|
||||
if trustedUpstreamsStr == "" {
|
||||
flag.StringVar(&trustedUpstreamsStr, "trusted-upstreams", "", "Comma-separated list of trusted upstream proxy domain names/IPs that can send PROXY protocol")
|
||||
}
|
||||
|
||||
if proxyProtocolStr != "" {
|
||||
proxyProtocol = strings.ToLower(proxyProtocolStr) == "true"
|
||||
}
|
||||
if proxyProtocolStr == "" {
|
||||
flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP")
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
logger.Init()
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
// Base context for the application; cancel on SIGINT/SIGTERM
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
|
||||
if reachableAt != "" && listenAddr == "" {
|
||||
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
|
||||
parts := strings.Split(reachableAt, ":")
|
||||
if len(parts) == 3 {
|
||||
port := parts[2]
|
||||
if strings.Contains(port, "/") {
|
||||
port = strings.Split(port, "/")[0]
|
||||
}
|
||||
listenAddr = ":" + port
|
||||
}
|
||||
}
|
||||
} else if listenAddr == "" {
|
||||
listenAddr = ":3003"
|
||||
}
|
||||
|
||||
mtuInt, err = strconv.Atoi(mtu)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse MTU: %v", err)
|
||||
@@ -144,6 +262,10 @@ func main() {
|
||||
logger.Fatal("You must provide either a config file or a remote config URL, not both")
|
||||
}
|
||||
|
||||
// clean up the reomte config URL for backwards compatibility
|
||||
remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/gerbil/get-config")
|
||||
remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/")
|
||||
|
||||
var key wgtypes.Key
|
||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||
if generateAndSaveKeyTo != "" {
|
||||
@@ -191,8 +313,8 @@ func main() {
|
||||
} else {
|
||||
// loop until we get the config
|
||||
for wgconfig.PrivateKey == "" {
|
||||
logger.Info("Fetching remote config from %s", remoteConfigURL)
|
||||
wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt)
|
||||
logger.Info("Fetching remote config from %s", remoteConfigURL+"/gerbil/get-config")
|
||||
wgconfig, err = loadRemoteConfig(remoteConfigURL+"/gerbil/get-config", key, reachableAt)
|
||||
if err != nil {
|
||||
logger.Error("Failed to load configuration: %v", err)
|
||||
time.Sleep(5 * time.Second)
|
||||
@@ -216,13 +338,97 @@ func main() {
|
||||
// Ensure the WireGuard peers exist
|
||||
ensureWireguardPeers(wgconfig.Peers)
|
||||
|
||||
if reportBandwidthTo != "" {
|
||||
go periodicBandwidthCheck(reportBandwidthTo)
|
||||
// Child error group derived from base context
|
||||
group, groupCtx := errgroup.WithContext(ctx)
|
||||
|
||||
// Periodic bandwidth reporting
|
||||
group.Go(func() error {
|
||||
return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
|
||||
})
|
||||
|
||||
// Start the UDP proxy server
|
||||
relayPort := wgconfig.RelayPort
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820 // in case there is no relay port set, use 21820
|
||||
}
|
||||
proxyRelay = relay.NewUDPProxyServer(groupCtx, fmt.Sprintf(":%d", relayPort), remoteConfigURL, key, reachableAt)
|
||||
err = proxyRelay.Start()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to start UDP proxy server: %v", err)
|
||||
}
|
||||
defer proxyRelay.Stop()
|
||||
|
||||
// TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING
|
||||
// SO YOU DON'T NEED TO SET THIS SEPARATELY
|
||||
// Parse local overrides
|
||||
var localOverrides []string
|
||||
if localOverridesStr != "" {
|
||||
localOverrides = strings.Split(localOverridesStr, ",")
|
||||
for i, domain := range localOverrides {
|
||||
localOverrides[i] = strings.TrimSpace(domain)
|
||||
}
|
||||
logger.Info("Local overrides configured: %v", localOverrides)
|
||||
}
|
||||
|
||||
var trustedUpstreams []string
|
||||
if trustedUpstreamsStr != "" {
|
||||
trustedUpstreams = strings.Split(trustedUpstreamsStr, ",")
|
||||
for i, upstream := range trustedUpstreams {
|
||||
trustedUpstreams[i] = strings.TrimSpace(upstream)
|
||||
}
|
||||
logger.Info("Trusted upstreams configured: %v", trustedUpstreams)
|
||||
}
|
||||
|
||||
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol, trustedUpstreams)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create proxy: %v", err)
|
||||
}
|
||||
|
||||
if err := proxySNI.Start(); err != nil {
|
||||
logger.Fatal("Failed to start proxy: %v", err)
|
||||
}
|
||||
|
||||
// Set up HTTP server
|
||||
http.HandleFunc("/peer", handlePeer)
|
||||
logger.Info("Starting server on %s", listenAddr)
|
||||
logger.Fatal("Failed to start server: %v", http.ListenAndServe(listenAddr, nil))
|
||||
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
|
||||
http.HandleFunc("/update-destinations", handleUpdateDestinations)
|
||||
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
|
||||
http.HandleFunc("/healthz", handleHealthz)
|
||||
logger.Info("Starting HTTP server on %s", listenAddr)
|
||||
|
||||
// HTTP server with graceful shutdown on context cancel
|
||||
server := &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: nil,
|
||||
}
|
||||
group.Go(func() error {
|
||||
// http.ErrServerClosed is returned on graceful shutdown; not an error for us
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
group.Go(func() error {
|
||||
<-groupCtx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = server.Shutdown(shutdownCtx)
|
||||
// Stop background components as the context is canceled
|
||||
if proxySNI != nil {
|
||||
_ = proxySNI.Stop()
|
||||
}
|
||||
if proxyRelay != nil {
|
||||
proxyRelay.Stop()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Wait for all goroutines to finish
|
||||
if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
logger.Error("Service exited with error: %v", err)
|
||||
} else if errors.Is(err, context.Canceled) {
|
||||
logger.Info("Context cancelled, shutting down")
|
||||
}
|
||||
}
|
||||
|
||||
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
|
||||
@@ -349,6 +555,10 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
}
|
||||
|
||||
if err := ensureWireguardFirewall(); err != nil {
|
||||
logger.Warn("Failed to ensure WireGuard firewall rules: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
||||
|
||||
return nil
|
||||
@@ -377,6 +587,9 @@ func assignIPAddress(ipAddress string) error {
|
||||
}
|
||||
|
||||
func ensureWireguardPeers(peers []Peer) error {
|
||||
wgMu.Lock()
|
||||
defer wgMu.Unlock()
|
||||
|
||||
// get the current peers
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err != nil {
|
||||
@@ -399,8 +612,8 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := removePeer(peer)
|
||||
if err != nil {
|
||||
// Note: We need to call the internal removal logic without re-acquiring the lock
|
||||
if err := removePeerInternal(peer); err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -416,8 +629,8 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := addPeer(configPeer)
|
||||
if err != nil {
|
||||
// Note: We need to call the internal addition logic without re-acquiring the lock
|
||||
if err := addPeerInternal(configPeer); err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -476,8 +689,8 @@ func ensureMSSClamping() error {
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
errors = append(errors, fmt.Errorf(errMsg))
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -493,8 +706,8 @@ func ensureMSSClamping() error {
|
||||
if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
errors = append(errors, fmt.Errorf(errMsg))
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -514,6 +727,113 @@ func ensureMSSClamping() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureWireguardFirewall() error {
|
||||
// Rules to enforce:
|
||||
// 1. Allow established/related connections (responses to our outbound traffic)
|
||||
// 2. Allow ICMP ping packets
|
||||
// 3. Drop all other inbound traffic from peers
|
||||
|
||||
// Define the rules we want to ensure exist
|
||||
rules := [][]string{
|
||||
// Allow established and related connections (responses to outbound traffic)
|
||||
{
|
||||
"-A", "INPUT",
|
||||
"-i", interfaceName,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "ESTABLISHED,RELATED",
|
||||
"-j", "ACCEPT",
|
||||
},
|
||||
// Allow ICMP ping requests
|
||||
{
|
||||
"-A", "INPUT",
|
||||
"-i", interfaceName,
|
||||
"-p", "icmp",
|
||||
"--icmp-type", "8",
|
||||
"-j", "ACCEPT",
|
||||
},
|
||||
// Drop all other inbound traffic from WireGuard interface
|
||||
{
|
||||
"-A", "INPUT",
|
||||
"-i", interfaceName,
|
||||
"-j", "DROP",
|
||||
},
|
||||
}
|
||||
|
||||
// First, try to delete any existing rules for this interface
|
||||
for _, rule := range rules {
|
||||
deleteArgs := make([]string, len(rule))
|
||||
copy(deleteArgs, rule)
|
||||
// Change -A to -D for deletion
|
||||
for i, arg := range deleteArgs {
|
||||
if arg == "-A" {
|
||||
deleteArgs[i] = "-D"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
deleteCmd := exec.Command("/usr/sbin/iptables", deleteArgs...)
|
||||
logger.Debug("Attempting to delete existing firewall rule: %v", deleteArgs)
|
||||
|
||||
// Try deletion multiple times to handle multiple existing rules
|
||||
for i := 0; i < 5; i++ {
|
||||
out, err := deleteCmd.CombinedOutput()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
logger.Debug("Deletion stopped: %v (output: %s)", exitErr.String(), string(out))
|
||||
}
|
||||
break // No more rules to delete
|
||||
}
|
||||
logger.Info("Deleted existing firewall rule (attempt %d)", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Now add the rules
|
||||
var errors []error
|
||||
for i, rule := range rules {
|
||||
addCmd := exec.Command("/usr/sbin/iptables", rule...)
|
||||
logger.Info("Adding WireGuard firewall rule %d: %v", i+1, rule)
|
||||
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to add firewall rule %d: %v (output: %s)", i+1, err, string(out))
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify the rule was added by checking
|
||||
checkArgs := make([]string, len(rule))
|
||||
copy(checkArgs, rule)
|
||||
// Change -A to -C for check
|
||||
for j, arg := range checkArgs {
|
||||
if arg == "-A" {
|
||||
checkArgs[j] = "-C"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
checkCmd := exec.Command("/usr/sbin/iptables", checkArgs...)
|
||||
if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Rule verification failed for rule %d: %v (output: %s)", i+1, err, string(out))
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("Successfully added and verified WireGuard firewall rule %d", i+1)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
var errMsgs []string
|
||||
for _, err := range errors {
|
||||
errMsgs = append(errMsgs, err.Error())
|
||||
}
|
||||
return fmt.Errorf("WireGuard firewall setup encountered errors:\n%s", strings.Join(errMsgs, "\n"))
|
||||
}
|
||||
|
||||
logger.Info("WireGuard firewall rules successfully configured for interface %s", interfaceName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func handlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
@@ -525,6 +845,15 @@ func handlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func handleHealthz(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func handleAddPeer(w http.ResponseWriter, r *http.Request) {
|
||||
var peer Peer
|
||||
if err := json.NewDecoder(r.Body).Decode(&peer); err != nil {
|
||||
@@ -538,11 +867,20 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Notify if notifyURL is set
|
||||
go notifyPeerChange("add", peer.PublicKey)
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "Peer added successfully"})
|
||||
}
|
||||
|
||||
func addPeer(peer Peer) error {
|
||||
wgMu.Lock()
|
||||
defer wgMu.Unlock()
|
||||
return addPeerInternal(peer)
|
||||
}
|
||||
|
||||
func addPeerInternal(peer Peer) error {
|
||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
@@ -550,12 +888,15 @@ func addPeer(peer Peer) error {
|
||||
|
||||
// parse allowed IPs into array of net.IPNet
|
||||
var allowedIPs []net.IPNet
|
||||
var wgIPs []string
|
||||
for _, ipStr := range peer.AllowedIPs {
|
||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
||||
}
|
||||
allowedIPs = append(allowedIPs, *ipNet)
|
||||
// Extract the IP address from the CIDR for relay cleanup
|
||||
wgIPs = append(wgIPs, ipNet.IP.String())
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
@@ -571,6 +912,13 @@ func addPeer(peer Peer) error {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
|
||||
// Clear relay connections for the peer's WireGuard IPs
|
||||
if proxyRelay != nil {
|
||||
for _, wgIP := range wgIPs {
|
||||
proxyRelay.OnPeerAdded(wgIP)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Peer %s added successfully", peer.PublicKey)
|
||||
|
||||
return nil
|
||||
@@ -589,16 +937,42 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Notify if notifyURL is set
|
||||
go notifyPeerChange("remove", publicKey)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "Peer removed successfully"})
|
||||
}
|
||||
|
||||
func removePeer(publicKey string) error {
|
||||
wgMu.Lock()
|
||||
defer wgMu.Unlock()
|
||||
return removePeerInternal(publicKey)
|
||||
}
|
||||
|
||||
func removePeerInternal(publicKey string) error {
|
||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
// Get current peer info before removing to clear relay connections
|
||||
var wgIPs []string
|
||||
if proxyRelay != nil {
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err == nil {
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
// Extract WireGuard IPs from this peer's allowed IPs
|
||||
for _, allowedIP := range peer.AllowedIPs {
|
||||
wgIPs = append(wgIPs, allowedIP.IP.String())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
Remove: true,
|
||||
@@ -612,24 +986,184 @@ func removePeer(publicKey string) error {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
|
||||
// Clear relay connections for the peer's WireGuard IPs
|
||||
if proxyRelay != nil {
|
||||
for _, wgIP := range wgIPs {
|
||||
proxyRelay.OnPeerRemoved(wgIP)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Peer %s removed successfully", publicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func periodicBandwidthCheck(endpoint string) {
|
||||
func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
logger.Error("Invalid method: %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var update ProxyMappingUpdate
|
||||
if err := json.NewDecoder(r.Body).Decode(&update); err != nil {
|
||||
logger.Error("Failed to decode request body: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the update request
|
||||
if update.OldDestination.DestinationIP == "" || update.NewDestination.DestinationIP == "" {
|
||||
logger.Error("Both old and new destination IP addresses are required")
|
||||
http.Error(w, "Both old and new destination IP addresses are required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if update.OldDestination.DestinationPort <= 0 || update.NewDestination.DestinationPort <= 0 {
|
||||
logger.Error("Both old and new destination ports must be positive integers")
|
||||
http.Error(w, "Both old and new destination ports must be positive integers", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the proxy mappings in the relay server
|
||||
if proxyRelay == nil {
|
||||
logger.Error("Proxy server is not available")
|
||||
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
|
||||
|
||||
logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d",
|
||||
updatedCount,
|
||||
update.OldDestination.DestinationIP, update.OldDestination.DestinationPort,
|
||||
update.NewDestination.DestinationIP, update.NewDestination.DestinationPort)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "Proxy mappings updated successfully",
|
||||
"updatedCount": updatedCount,
|
||||
"oldDestination": update.OldDestination,
|
||||
"newDestination": update.NewDestination,
|
||||
})
|
||||
}
|
||||
|
||||
func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
logger.Error("Invalid method: %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var request UpdateDestinationsRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
logger.Error("Failed to decode request body: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the request
|
||||
if request.SourceIP == "" {
|
||||
logger.Error("Source IP address is required")
|
||||
http.Error(w, "Source IP address is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if request.SourcePort <= 0 {
|
||||
logger.Error("Source port must be a positive integer")
|
||||
http.Error(w, "Source port must be a positive integer", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(request.Destinations) == 0 {
|
||||
logger.Error("At least one destination is required")
|
||||
http.Error(w, "At least one destination is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate each destination
|
||||
for i, dest := range request.Destinations {
|
||||
if dest.DestinationIP == "" {
|
||||
logger.Error("Destination IP is required for destination %d", i)
|
||||
http.Error(w, fmt.Sprintf("Destination IP is required for destination %d", i), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if dest.DestinationPort <= 0 {
|
||||
logger.Error("Destination port must be a positive integer for destination %d", i)
|
||||
http.Error(w, fmt.Sprintf("Destination port must be a positive integer for destination %d", i), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update the proxy mappings in the relay server
|
||||
if proxyRelay == nil {
|
||||
logger.Error("Proxy server is not available")
|
||||
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
|
||||
|
||||
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
|
||||
request.SourceIP, request.SourcePort, len(request.Destinations))
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "Destinations updated successfully",
|
||||
"sourceIP": request.SourceIP,
|
||||
"sourcePort": request.SourcePort,
|
||||
"destinationCount": len(request.Destinations),
|
||||
"destinations": request.Destinations,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLocalSNIsRequest represents the JSON payload for updating local SNIs
|
||||
type UpdateLocalSNIsRequest struct {
|
||||
FullDomains []string `json:"fullDomains"`
|
||||
}
|
||||
|
||||
func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
logger.Error("Invalid method: %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateLocalSNIsRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
proxySNI.UpdateLocalSNIs(req.FullDomains)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "Local SNIs updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := reportPeerBandwidth(endpoint); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := reportPeerBandwidth(endpoint); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
wgMu.Lock()
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
wgMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
@@ -640,8 +1174,13 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Track the set of peers currently present on the device to prune stale readings efficiently
|
||||
currentPeerKeys := make(map[string]struct{}, len(device.Peers))
|
||||
|
||||
for _, peer := range device.Peers {
|
||||
publicKey := peer.PublicKey.String()
|
||||
currentPeerKeys[publicKey] = struct{}{}
|
||||
|
||||
currentReading := PeerReading{
|
||||
BytesReceived: peer.ReceiveBytes,
|
||||
BytesTransmitted: peer.TransmitBytes,
|
||||
@@ -698,14 +1237,7 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
|
||||
// Clean up old peers
|
||||
for publicKey := range lastReadings {
|
||||
found := false
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if _, exists := currentPeerKeys[publicKey]; !exists {
|
||||
delete(lastReadings, publicKey)
|
||||
}
|
||||
}
|
||||
@@ -736,3 +1268,50 @@ func reportPeerBandwidth(apiURL string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// notifyPeerChange sends a POST request to notifyURL with the action and public key.
|
||||
func notifyPeerChange(action, publicKey string) {
|
||||
if notifyURL == "" {
|
||||
return
|
||||
}
|
||||
payload := map[string]string{
|
||||
"action": action,
|
||||
"publicKey": publicKey,
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to marshal notify payload: %v", err)
|
||||
return
|
||||
}
|
||||
resp, err := http.Post(notifyURL, "application/json", bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to notify peer change: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Warn("Notify server returned non-OK: %s", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func monitorMemory(limit uint64) {
|
||||
var m runtime.MemStats
|
||||
for {
|
||||
runtime.ReadMemStats(&m)
|
||||
if m.Alloc > limit {
|
||||
fmt.Printf("Memory spike detected (%d bytes). Dumping profile...\n", m.Alloc)
|
||||
|
||||
f, err := os.Create(fmt.Sprintf("/var/config/heap/heap-spike-%d.pprof", time.Now().Unix()))
|
||||
if err != nil {
|
||||
log.Println("could not create profile:", err)
|
||||
} else {
|
||||
pprof.WriteHeapProfile(f)
|
||||
f.Close()
|
||||
}
|
||||
|
||||
// Wait a while before checking again to avoid spamming profiles
|
||||
time.Sleep(5 * time.Minute)
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
847
proxy/proxy.go
Normal file
847
proxy/proxy.go
Normal file
@@ -0,0 +1,847 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// RouteRecord represents a routing configuration
|
||||
type RouteRecord struct {
|
||||
Hostname string
|
||||
TargetHost string
|
||||
TargetPort int
|
||||
}
|
||||
|
||||
// RouteAPIResponse represents the response from the route API
|
||||
type RouteAPIResponse struct {
|
||||
Endpoints []string `json:"endpoints"`
|
||||
}
|
||||
|
||||
// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header
|
||||
type ProxyProtocolInfo struct {
|
||||
Protocol string // TCP4 or TCP6
|
||||
SrcIP string
|
||||
DestIP string
|
||||
SrcPort int
|
||||
DestPort int
|
||||
OriginalConn net.Conn // The original connection after PROXY protocol parsing
|
||||
}
|
||||
|
||||
// SNIProxy represents the main proxy server
|
||||
type SNIProxy struct {
|
||||
port int
|
||||
cache *cache.Cache
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
localProxyAddr string
|
||||
localProxyPort int
|
||||
remoteConfigURL string
|
||||
publicKey string
|
||||
proxyProtocol bool // Enable PROXY protocol v1
|
||||
|
||||
// New fields for fast local SNI lookup
|
||||
localSNIs map[string]struct{}
|
||||
localSNIsLock sync.RWMutex
|
||||
|
||||
// Local overrides for domains that should always use local proxy
|
||||
localOverrides map[string]struct{}
|
||||
|
||||
// Track active tunnels by SNI
|
||||
activeTunnels map[string]*activeTunnel
|
||||
activeTunnelsLock sync.Mutex
|
||||
|
||||
// Trusted upstream proxies that can send PROXY protocol
|
||||
trustedUpstreams map[string]struct{}
|
||||
}
|
||||
|
||||
type activeTunnel struct {
|
||||
conns []net.Conn
|
||||
}
|
||||
|
||||
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
|
||||
type readOnlyConn struct {
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) }
|
||||
func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
|
||||
func (conn readOnlyConn) Close() error { return nil }
|
||||
func (conn readOnlyConn) LocalAddr() net.Addr { return nil }
|
||||
func (conn readOnlyConn) RemoteAddr() net.Addr { return nil }
|
||||
func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection
|
||||
func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) {
|
||||
// Check if the connection comes from a trusted upstream
|
||||
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
|
||||
}
|
||||
|
||||
// Resolve the remote IP to hostname to check if it's trusted
|
||||
// For simplicity, we'll check the IP directly in trusted upstreams
|
||||
// In production, you might want to do reverse DNS lookup
|
||||
if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted {
|
||||
// Not from trusted upstream, return original connection
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Set read timeout for PROXY protocol parsing
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// Read the first line (PROXY protocol header)
|
||||
buffer := make([]byte, 512) // PROXY protocol header should be much smaller
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
// If we can't read from trusted upstream, treat as regular connection
|
||||
logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err)
|
||||
// Clear read timeout before returning
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Find the end of the first line (CRLF)
|
||||
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
|
||||
if headerEnd == -1 {
|
||||
// No PROXY protocol header found, treat as regular TLS connection
|
||||
// Return the connection with the buffered data prepended
|
||||
logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost)
|
||||
|
||||
// Clear read timeout
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", err)
|
||||
}
|
||||
|
||||
// Create a reader that includes the buffered data + original connection
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
wrappedConn := &proxyProtocolConn{
|
||||
Conn: conn,
|
||||
reader: newReader,
|
||||
}
|
||||
return nil, wrappedConn, nil
|
||||
}
|
||||
|
||||
headerLine := string(buffer[:headerEnd])
|
||||
remainingData := buffer[headerEnd+2 : n]
|
||||
|
||||
// Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort"
|
||||
parts := strings.Fields(headerLine)
|
||||
if len(parts) != 6 || parts[0] != "PROXY" {
|
||||
// Check for PROXY UNKNOWN
|
||||
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
|
||||
// PROXY UNKNOWN - use original connection info
|
||||
return nil, conn, nil
|
||||
}
|
||||
// Invalid PROXY protocol, but might be regular TLS - treat as such
|
||||
logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine)
|
||||
|
||||
// Clear read timeout
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", err)
|
||||
}
|
||||
|
||||
// Return the connection with all buffered data prepended
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
wrappedConn := &proxyProtocolConn{
|
||||
Conn: conn,
|
||||
reader: newReader,
|
||||
}
|
||||
return nil, wrappedConn, nil
|
||||
}
|
||||
|
||||
protocol := parts[1]
|
||||
srcIP := parts[2]
|
||||
destIP := parts[3]
|
||||
srcPort, err := strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
|
||||
}
|
||||
destPort, err := strconv.Atoi(parts[5])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
|
||||
}
|
||||
|
||||
// Create a new reader that includes remaining data + original connection
|
||||
var newReader io.Reader
|
||||
if len(remainingData) > 0 {
|
||||
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
|
||||
} else {
|
||||
newReader = conn
|
||||
}
|
||||
|
||||
// Create a wrapper connection that reads from the combined reader
|
||||
wrappedConn := &proxyProtocolConn{
|
||||
Conn: conn,
|
||||
reader: newReader,
|
||||
}
|
||||
|
||||
proxyInfo := &ProxyProtocolInfo{
|
||||
Protocol: protocol,
|
||||
SrcIP: srcIP,
|
||||
DestIP: destIP,
|
||||
SrcPort: srcPort,
|
||||
DestPort: destPort,
|
||||
OriginalConn: wrappedConn,
|
||||
}
|
||||
|
||||
// Clear read timeout
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err)
|
||||
}
|
||||
|
||||
return proxyInfo, wrappedConn, nil
|
||||
}
|
||||
|
||||
// proxyProtocolConn wraps a connection to read from a custom reader
|
||||
type proxyProtocolConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (c *proxyProtocolConn) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo
|
||||
func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string {
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
// Use the original client information from the PROXY protocol
|
||||
var targetIP string
|
||||
var protocol string
|
||||
|
||||
// Parse source IP to determine protocol family
|
||||
srcIP := net.ParseIP(proxyInfo.SrcIP)
|
||||
if srcIP == nil {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
if srcIP.To4() != nil {
|
||||
// Source is IPv4, use TCP4 protocol
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is also IPv4, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is IPv6, but we need IPv4 for consistent protocol family
|
||||
if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
targetIP = "127.0.0.1" // Safe fallback
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Source is IPv6, use TCP6 protocol
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is IPv4, convert to IPv6 representation
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is also IPv6, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol,
|
||||
proxyInfo.SrcIP,
|
||||
targetIP,
|
||||
proxyInfo.SrcPort,
|
||||
targetTCP.Port)
|
||||
}
|
||||
|
||||
// buildProxyProtocolHeader creates a PROXY protocol v1 header
|
||||
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
|
||||
clientTCP, ok := clientAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
// Determine protocol family based on client IP and normalize target IP accordingly
|
||||
var protocol string
|
||||
var targetIP string
|
||||
|
||||
if clientTCP.IP.To4() != nil {
|
||||
// Client is IPv4, use TCP4 protocol
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is also IPv4, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is IPv6, but we need IPv4 for consistent protocol family
|
||||
// Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1
|
||||
if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
// For non-loopback IPv6 targets, we could try to extract embedded IPv4
|
||||
// or fall back to a sensible IPv4 address based on the target
|
||||
targetIP = "127.0.0.1" // Safe fallback
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Client is IPv6, use TCP6 protocol
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is IPv4, convert to IPv6 representation
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is also IPv6, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol,
|
||||
clientTCP.IP.String(),
|
||||
targetIP,
|
||||
clientTCP.Port,
|
||||
targetTCP.Port)
|
||||
}
|
||||
|
||||
// NewSNIProxy creates a new SNI proxy instance
|
||||
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create local overrides map
|
||||
overridesMap := make(map[string]struct{})
|
||||
for _, domain := range localOverrides {
|
||||
if domain != "" {
|
||||
overridesMap[domain] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Create trusted upstreams map
|
||||
trustedMap := make(map[string]struct{})
|
||||
for _, upstream := range trustedUpstreams {
|
||||
if upstream != "" {
|
||||
// Add both the domain and potentially resolved IPs
|
||||
trustedMap[upstream] = struct{}{}
|
||||
|
||||
// Try to resolve the domain to IPs and add them too
|
||||
if ips, err := net.LookupIP(upstream); err == nil {
|
||||
for _, ip := range ips {
|
||||
trustedMap[ip.String()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
proxy := &SNIProxy{
|
||||
port: port,
|
||||
cache: cache.New(3*time.Second, 10*time.Minute),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
localProxyAddr: localProxyAddr,
|
||||
localProxyPort: localProxyPort,
|
||||
remoteConfigURL: remoteConfigURL,
|
||||
publicKey: publicKey,
|
||||
proxyProtocol: proxyProtocol,
|
||||
localSNIs: make(map[string]struct{}),
|
||||
localOverrides: overridesMap,
|
||||
activeTunnels: make(map[string]*activeTunnel),
|
||||
trustedUpstreams: trustedMap,
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// Start begins listening for connections
|
||||
func (p *SNIProxy) Start() error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
|
||||
}
|
||||
|
||||
p.listener = listener
|
||||
logger.Debug("SNI Proxy listening on port %d", p.port)
|
||||
|
||||
// Accept connections in a goroutine
|
||||
go p.acceptConnections()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the proxy
|
||||
func (p *SNIProxy) Stop() error {
|
||||
log.Println("Stopping SNI Proxy...")
|
||||
|
||||
p.cancel()
|
||||
|
||||
if p.listener != nil {
|
||||
p.listener.Close()
|
||||
}
|
||||
|
||||
// Wait for all goroutines to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
p.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Println("All connections closed gracefully")
|
||||
case <-time.After(30 * time.Second):
|
||||
log.Println("Timeout waiting for connections to close")
|
||||
}
|
||||
|
||||
log.Println("SNI Proxy stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptConnections handles incoming connections
|
||||
func (p *SNIProxy) acceptConnections() {
|
||||
for {
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
logger.Debug("Accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
p.wg.Add(1)
|
||||
go p.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// readClientHello reads and parses the TLS ClientHello message
|
||||
func (p *SNIProxy) readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
|
||||
var hello *tls.ClientHelloInfo
|
||||
err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
hello = new(tls.ClientHelloInfo)
|
||||
*hello = *argHello
|
||||
return nil, nil
|
||||
},
|
||||
}).Handshake()
|
||||
if hello == nil {
|
||||
return nil, err
|
||||
}
|
||||
return hello, nil
|
||||
}
|
||||
|
||||
// peekClientHello reads the ClientHello while preserving the data for forwarding
|
||||
func (p *SNIProxy) peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
|
||||
peekedBytes := new(bytes.Buffer)
|
||||
hello, err := p.readClientHello(io.TeeReader(reader, peekedBytes))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return hello, io.MultiReader(peekedBytes, reader), nil
|
||||
}
|
||||
|
||||
// extractSNI extracts the SNI hostname from the TLS ClientHello
|
||||
func (p *SNIProxy) extractSNI(conn net.Conn) (string, io.Reader, error) {
|
||||
clientHello, clientReader, err := p.peekClientHello(conn)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to peek ClientHello: %w", err)
|
||||
}
|
||||
|
||||
if clientHello.ServerName == "" {
|
||||
return "", clientReader, fmt.Errorf("no SNI hostname found in ClientHello")
|
||||
}
|
||||
|
||||
return clientHello.ServerName, clientReader, nil
|
||||
}
|
||||
|
||||
// handleConnection processes a single client connection
|
||||
func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
defer p.wg.Done()
|
||||
defer clientConn.Close()
|
||||
|
||||
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
||||
|
||||
// Check for PROXY protocol from trusted upstream
|
||||
var proxyInfo *ProxyProtocolInfo
|
||||
var actualClientConn net.Conn = clientConn
|
||||
|
||||
if len(p.trustedUpstreams) > 0 {
|
||||
var err error
|
||||
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to parse PROXY protocol: %v", err)
|
||||
return
|
||||
}
|
||||
if proxyInfo != nil {
|
||||
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
|
||||
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
|
||||
} else {
|
||||
// No PROXY protocol detected, but connection is from trusted upstream
|
||||
// This is fine - treat as regular connection
|
||||
logger.Debug("No PROXY protocol detected from trusted upstream, treating as regular connection")
|
||||
}
|
||||
}
|
||||
|
||||
// Set read timeout for SNI extraction
|
||||
if err := actualClientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
logger.Debug("Failed to set read deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract SNI hostname
|
||||
hostname, clientReader, err := p.extractSNI(actualClientConn)
|
||||
if err != nil {
|
||||
logger.Debug("SNI extraction failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
log.Println("No SNI hostname found")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("SNI hostname detected: %s", hostname)
|
||||
|
||||
// Remove read timeout for normal operation
|
||||
if err := actualClientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get routing information - use original client address if available from PROXY protocol
|
||||
var clientAddrStr string
|
||||
if proxyInfo != nil {
|
||||
clientAddrStr = fmt.Sprintf("%s:%d", proxyInfo.SrcIP, proxyInfo.SrcPort)
|
||||
} else {
|
||||
clientAddrStr = clientConn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
route, err := p.getRoute(hostname, clientAddrStr)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to get route for %s: %v", hostname, err)
|
||||
return
|
||||
}
|
||||
|
||||
if route == nil {
|
||||
logger.Debug("No route found for hostname: %s", hostname)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort)
|
||||
|
||||
// Connect to target server
|
||||
targetConn, err := net.DialTimeout("tcp",
|
||||
fmt.Sprintf("%s:%d", route.TargetHost, route.TargetPort),
|
||||
10*time.Second)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to connect to target %s:%d: %v",
|
||||
route.TargetHost, route.TargetPort, err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
|
||||
|
||||
// Send PROXY protocol header if enabled
|
||||
if p.proxyProtocol {
|
||||
var proxyHeader string
|
||||
if proxyInfo != nil {
|
||||
// Use original client info from PROXY protocol
|
||||
proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr())
|
||||
} else {
|
||||
// Use direct client connection info
|
||||
proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
||||
}
|
||||
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
|
||||
|
||||
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
|
||||
logger.Debug("Failed to send PROXY protocol header: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Track this tunnel by SNI
|
||||
p.activeTunnelsLock.Lock()
|
||||
tunnel, ok := p.activeTunnels[hostname]
|
||||
if !ok {
|
||||
tunnel = &activeTunnel{}
|
||||
p.activeTunnels[hostname] = tunnel
|
||||
}
|
||||
tunnel.conns = append(tunnel.conns, actualClientConn)
|
||||
p.activeTunnelsLock.Unlock()
|
||||
|
||||
defer func() {
|
||||
// Remove this conn from active tunnels
|
||||
p.activeTunnelsLock.Lock()
|
||||
if tunnel, ok := p.activeTunnels[hostname]; ok {
|
||||
newConns := make([]net.Conn, 0, len(tunnel.conns))
|
||||
for _, c := range tunnel.conns {
|
||||
if c != actualClientConn {
|
||||
newConns = append(newConns, c)
|
||||
}
|
||||
}
|
||||
if len(newConns) == 0 {
|
||||
delete(p.activeTunnels, hostname)
|
||||
} else {
|
||||
tunnel.conns = newConns
|
||||
}
|
||||
}
|
||||
p.activeTunnelsLock.Unlock()
|
||||
}()
|
||||
|
||||
// Start bidirectional data transfer
|
||||
p.pipe(actualClientConn, targetConn, clientReader)
|
||||
}
|
||||
|
||||
// getRoute retrieves routing information for a hostname
|
||||
func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
||||
// Check local overrides first
|
||||
if _, isOverride := p.localOverrides[hostname]; isOverride {
|
||||
logger.Debug("Local override matched for hostname: %s", hostname)
|
||||
return &RouteRecord{
|
||||
Hostname: hostname,
|
||||
TargetHost: p.localProxyAddr,
|
||||
TargetPort: p.localProxyPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Fast path: check if hostname is in localSNIs
|
||||
p.localSNIsLock.RLock()
|
||||
_, isLocal := p.localSNIs[hostname]
|
||||
p.localSNIsLock.RUnlock()
|
||||
if isLocal {
|
||||
return &RouteRecord{
|
||||
Hostname: hostname,
|
||||
TargetHost: p.localProxyAddr,
|
||||
TargetPort: p.localProxyPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
if cached, found := p.cache.Get(hostname); found {
|
||||
if cached == nil {
|
||||
return nil, nil // Cached negative result
|
||||
}
|
||||
logger.Debug("Cache hit for hostname: %s", hostname)
|
||||
return cached.(*RouteRecord), nil
|
||||
}
|
||||
|
||||
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
|
||||
|
||||
// Query API with timeout
|
||||
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Construct API URL (without hostname in path)
|
||||
apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname", p.remoteConfigURL)
|
||||
|
||||
// Create request body with hostname and public key
|
||||
requestBody := map[string]string{
|
||||
"hostname": hostname,
|
||||
"publicKey": p.publicKey,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Make HTTP request
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// Cache negative result for shorter time (1 minute)
|
||||
p.cache.Set(hostname, nil, 1*time.Minute)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var apiResponse RouteAPIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode API response: %w", err)
|
||||
}
|
||||
|
||||
endpoints := apiResponse.Endpoints
|
||||
|
||||
// Default target configuration
|
||||
targetHost := p.localProxyAddr
|
||||
targetPort := p.localProxyPort
|
||||
|
||||
// If no endpoints returned, use local node
|
||||
if len(endpoints) == 0 {
|
||||
logger.Debug("No endpoints returned for hostname: %s, using local node", hostname)
|
||||
} else {
|
||||
// Select endpoint using consistent hashing for stickiness
|
||||
selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints)
|
||||
targetHost = selectedEndpoint
|
||||
targetPort = 443 // Default HTTPS port
|
||||
logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr)
|
||||
}
|
||||
|
||||
route := &RouteRecord{
|
||||
Hostname: hostname,
|
||||
TargetHost: targetHost,
|
||||
TargetPort: targetPort,
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
p.cache.Set(hostname, route, cache.DefaultExpiration)
|
||||
logger.Debug("Cached route for hostname: %s", hostname)
|
||||
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// selectStickyEndpoint selects an endpoint using consistent hashing to ensure
|
||||
// the same client always routes to the same endpoint for load balancing
|
||||
func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string {
|
||||
if len(endpoints) == 0 {
|
||||
return p.localProxyAddr
|
||||
}
|
||||
if len(endpoints) == 1 {
|
||||
return endpoints[0]
|
||||
}
|
||||
|
||||
// Use FNV hash for consistent selection based on client address
|
||||
hash := fnv.New32a()
|
||||
hash.Write([]byte(clientAddr))
|
||||
index := hash.Sum32() % uint32(len(endpoints))
|
||||
|
||||
return endpoints[index]
|
||||
}
|
||||
|
||||
// pipe handles bidirectional data transfer between connections
|
||||
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// closeOnce ensures we only close connections once
|
||||
var closeOnce sync.Once
|
||||
closeConns := func() {
|
||||
closeOnce.Do(func() {
|
||||
// Close both connections to unblock any pending reads
|
||||
clientConn.Close()
|
||||
targetConn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
// Copy data from client to target (using the buffered reader)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer closeConns()
|
||||
|
||||
// Use a large buffer for better performance
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
||||
if err != nil && err != io.EOF {
|
||||
logger.Debug("Copy client->target error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Copy data from target to client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer closeConns()
|
||||
|
||||
// Use a large buffer for better performance
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
||||
if err != nil && err != io.EOF {
|
||||
logger.Debug("Copy target->client error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// GetCacheStats returns cache statistics
|
||||
func (p *SNIProxy) GetCacheStats() (int, int) {
|
||||
return p.cache.ItemCount(), len(p.cache.Items())
|
||||
}
|
||||
|
||||
// ClearCache clears all cached entries
|
||||
func (p *SNIProxy) ClearCache() {
|
||||
p.cache.Flush()
|
||||
log.Println("Cache cleared")
|
||||
}
|
||||
|
||||
// UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains
|
||||
func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
|
||||
newSNIs := make(map[string]struct{})
|
||||
for _, domain := range fullDomains {
|
||||
newSNIs[domain] = struct{}{}
|
||||
// Invalidate any cached route for this domain
|
||||
p.cache.Delete(domain)
|
||||
}
|
||||
|
||||
// Update localSNIs
|
||||
p.localSNIsLock.Lock()
|
||||
removed := make([]string, 0)
|
||||
for sni := range p.localSNIs {
|
||||
if _, stillLocal := newSNIs[sni]; !stillLocal {
|
||||
removed = append(removed, sni)
|
||||
}
|
||||
}
|
||||
p.localSNIs = newSNIs
|
||||
p.localSNIsLock.Unlock()
|
||||
|
||||
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
|
||||
|
||||
// Terminate tunnels for removed SNIs
|
||||
if len(removed) > 0 {
|
||||
p.activeTunnelsLock.Lock()
|
||||
for _, sni := range removed {
|
||||
if tunnels, ok := p.activeTunnels[sni]; ok {
|
||||
for _, conn := range tunnels.conns {
|
||||
conn.Close()
|
||||
}
|
||||
delete(p.activeTunnels, sni)
|
||||
logger.Debug("Closed tunnels for SNI target change: %s", sni)
|
||||
}
|
||||
}
|
||||
p.activeTunnelsLock.Unlock()
|
||||
}
|
||||
}
|
||||
119
proxy/proxy_test.go
Normal file
119
proxy/proxy_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildProxyProtocolHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientAddr string
|
||||
targetAddr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 client and target",
|
||||
clientAddr: "192.168.1.100:12345",
|
||||
targetAddr: "10.0.0.1:443",
|
||||
expected: "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv6 client and target",
|
||||
clientAddr: "[2001:db8::1]:12345",
|
||||
targetAddr: "[2001:db8::2]:443",
|
||||
expected: "PROXY TCP6 2001:db8::1 2001:db8::2 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv4 client with IPv6 loopback target",
|
||||
clientAddr: "192.168.1.100:12345",
|
||||
targetAddr: "[::1]:443",
|
||||
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv4 client with IPv6 target",
|
||||
clientAddr: "192.168.1.100:12345",
|
||||
targetAddr: "[2001:db8::2]:443",
|
||||
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv6 client with IPv4 target",
|
||||
clientAddr: "[2001:db8::1]:12345",
|
||||
targetAddr: "10.0.0.1:443",
|
||||
expected: "PROXY TCP6 2001:db8::1 ::ffff:10.0.0.1 12345 443\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientTCP, err := net.ResolveTCPAddr("tcp", tt.clientAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve client address: %v", err)
|
||||
}
|
||||
|
||||
targetTCP, err := net.ResolveTCPAddr("tcp", tt.targetAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve target address: %v", err)
|
||||
}
|
||||
|
||||
result := buildProxyProtocolHeader(clientTCP, targetTCP)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
|
||||
// Test with non-TCP address type
|
||||
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}
|
||||
targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443}
|
||||
|
||||
result := buildProxyProtocolHeader(clientAddr, targetAddr)
|
||||
expected := "PROXY UNKNOWN\r\n"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected %q, got %q", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
|
||||
proxy, err := NewSNIProxy(8443, "", "", "127.0.0.1", 443, nil, true, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SNI proxy: %v", err)
|
||||
}
|
||||
|
||||
// Test IPv4 case
|
||||
proxyInfo := &ProxyProtocolInfo{
|
||||
Protocol: "TCP4",
|
||||
SrcIP: "10.0.0.1",
|
||||
DestIP: "192.168.1.100",
|
||||
SrcPort: 12345,
|
||||
DestPort: 443,
|
||||
}
|
||||
|
||||
targetAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
|
||||
header := proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
|
||||
|
||||
expected := "PROXY TCP4 10.0.0.1 127.0.0.1 12345 8080\r\n"
|
||||
if header != expected {
|
||||
t.Errorf("Expected header '%s', got '%s'", expected, header)
|
||||
}
|
||||
|
||||
// Test IPv6 case
|
||||
proxyInfo = &ProxyProtocolInfo{
|
||||
Protocol: "TCP6",
|
||||
SrcIP: "2001:db8::1",
|
||||
DestIP: "2001:db8::2",
|
||||
SrcPort: 12345,
|
||||
DestPort: 443,
|
||||
}
|
||||
|
||||
targetAddr, _ = net.ResolveTCPAddr("tcp6", "[::1]:8080")
|
||||
header = proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
|
||||
|
||||
expected = "PROXY TCP6 2001:db8::1 ::1 12345 8080\r\n"
|
||||
if header != expected {
|
||||
t.Errorf("Expected header '%s', got '%s'", expected, header)
|
||||
}
|
||||
}
|
||||
1060
relay/relay.go
Normal file
1060
relay/relay.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user