mirror of https://github.com/fosrl/gerbil.git
synced 2026-02-07 21:46:40 +00:00
Merge branch 'main' into dev
182 .github/workflows/cicd.yml (vendored)
@@ -1,52 +1,160 @@
name: CI/CD Pipeline

# CI/CD workflow for building, publishing, mirroring, signing container images and building release binaries.
# Actions are pinned to specific SHAs to reduce supply-chain risk. This workflow triggers on tag push events.

permissions:
contents: read
packages: write # for GHCR push
id-token: write # for Cosign Keyless (OIDC) Signing

# Required secrets:
# - DOCKER_HUB_USERNAME / DOCKER_HUB_ACCESS_TOKEN: push to Docker Hub
# - GITHUB_TOKEN: used for GHCR login and OIDC keyless signing
# - COSIGN_PRIVATE_KEY / COSIGN_PASSWORD / COSIGN_PUBLIC_KEY: for key-based signing

on:
push:
tags:
- "*"
push:
tags:
- "*"

concurrency:
group: ${{ github.ref }}
cancel-in-progress: true

jobs:
release:
name: Build and Release
runs-on: amd64-runner
release:
name: Build and Release
runs-on: amd64-runner
# Job-level timeout to avoid runaway or stuck runs
timeout-minutes: 120
env:
# Target images
DOCKERHUB_IMAGE: docker.io/${{ secrets.DOCKER_HUB_USERNAME }}/${{ github.event.repository.name }}
GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}

steps:
- name: Checkout code
uses: actions/checkout@v5
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0

- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1

- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
- name: Log in to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: docker.io
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}

- name: Extract tag name
id: get-tag
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
- name: Extract tag name
id: get-tag
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
shell: bash

- name: Install Go
uses: actions/setup-go@v6
with:
go-version: 1.25
- name: Install Go
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
with:
go-version: 1.25

- name: Build and push Docker images
run: |
TAG=${{ env.TAG }}
make docker-build-release tag=$TAG
- name: Update version in main.go
run: |
TAG=${{ env.TAG }}
if [ -f main.go ]; then
sed -i 's/version_replaceme/'"$TAG"'/' main.go
echo "Updated main.go with version $TAG"
else
echo "main.go not found"
fi
shell: bash

- name: Build binaries
run: |
make go-build-release
- name: Build and push Docker images (Docker Hub)
run: |
TAG=${{ env.TAG }}
make docker-build-release tag=$TAG
echo "Built & pushed to: ${{ env.DOCKERHUB_IMAGE }}:${TAG}"
shell: bash

- name: Upload artifacts from /bin
uses: actions/upload-artifact@v4
with:
name: binaries
path: bin/
- name: Login in to GHCR
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Install skopeo + jq
# skopeo: copy/inspect images between registries
# jq: JSON parsing tool used to extract digest values
run: |
sudo apt-get update -y
sudo apt-get install -y skopeo jq
skopeo --version
shell: bash

- name: Copy tag from Docker Hub to GHCR
# Mirror the already-built image (all architectures) to GHCR so we can sign it
run: |
set -euo pipefail
TAG=${{ env.TAG }}
echo "Copying ${{ env.DOCKERHUB_IMAGE }}:${TAG} -> ${{ env.GHCR_IMAGE }}:${TAG}"
skopeo copy --all --retry-times 3 \
docker://$DOCKERHUB_IMAGE:$TAG \
docker://$GHCR_IMAGE:$TAG
shell: bash

- name: Install cosign
# cosign is used to sign and verify container images (key and keyless)
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0

- name: Dual-sign and verify (GHCR & Docker Hub)
# Sign each image by digest using keyless (OIDC) and key-based signing,
# then verify both the public key signature and the keyless OIDC signature.
env:
TAG: ${{ env.TAG }}
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
COSIGN_YES: "true"
run: |
set -euo pipefail

issuer="https://token.actions.githubusercontent.com"
id_regex="^https://github.com/${{ github.repository }}/.+" # accept this repo (all workflows/refs)

for IMAGE in "${GHCR_IMAGE}" "${DOCKERHUB_IMAGE}"; do
echo "Processing ${IMAGE}:${TAG}"

DIGEST="$(skopeo inspect --retry-times 3 docker://${IMAGE}:${TAG} | jq -r '.Digest')"
REF="${IMAGE}@${DIGEST}"
echo "Resolved digest: ${REF}"

echo "==> cosign sign (keyless) --recursive ${REF}"
cosign sign --recursive "${REF}"

echo "==> cosign sign (key) --recursive ${REF}"
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${REF}"

echo "==> cosign verify (public key) ${REF}"
cosign verify --key env://COSIGN_PUBLIC_KEY "${REF}" -o text

echo "==> cosign verify (keyless policy) ${REF}"
cosign verify \
--certificate-oidc-issuer "${issuer}" \
--certificate-identity-regexp "${id_regex}" \
"${REF}" -o text

done
shell: bash

- name: Build binaries
run: |
make go-build-release
shell: bash

- name: Upload artifacts from /bin
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
with:
name: binaries
path: bin/
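The "Update version in main.go" step above rewrites a version_replaceme placeholder before building. A minimal sketch of what such a placeholder can look like on the Go side; the variable name and layout here are assumptions for illustration, not taken from this diff:

```go
package main

import "fmt"

// version is stamped by CI: the release workflow's sed step replaces the
// "version_replaceme" placeholder with the pushed tag before building.
var version = "version_replaceme"

func main() {
	fmt.Println("gerbil", version)
}
```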
132 .github/workflows/mirror.yaml (vendored, new file)
@@ -0,0 +1,132 @@
name: Mirror & Sign (Docker Hub to GHCR)

on:
workflow_dispatch: {}

permissions:
contents: read
packages: write
id-token: write # for keyless OIDC

env:
SOURCE_IMAGE: docker.io/fosrl/gerbil
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}

jobs:
mirror-and-dual-sign:
runs-on: amd64-runner
steps:
- name: Install skopeo + jq
run: |
sudo apt-get update -y
sudo apt-get install -y skopeo jq
skopeo --version

- name: Install cosign
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0

- name: Input check
run: |
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
echo "Source : ${SOURCE_IMAGE}"
echo "Target : ${DEST_IMAGE}"

# Auth for skopeo (containers-auth)
- name: Skopeo login to GHCR
run: |
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"

# Auth for cosign (docker-config)
- name: Docker login to GHCR (for cosign)
run: |
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin

- name: List source tags
run: |
set -euo pipefail
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
| jq -r '.Tags[]' | sort -u > src-tags.txt
echo "Found source tags: $(wc -l < src-tags.txt)"
head -n 20 src-tags.txt || true

- name: List destination tags (skip existing)
run: |
set -euo pipefail
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
else
: > dst-tags.txt
fi
echo "Existing destination tags: $(wc -l < dst-tags.txt)"

- name: Mirror, dual-sign, and verify
env:
# keyless
COSIGN_YES: "true"
# key-based
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
# verify
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
run: |
set -euo pipefail
copied=0; skipped=0; v_ok=0; errs=0

issuer="https://token.actions.githubusercontent.com"
id_regex="^https://github.com/${{ github.repository }}/.+"

while read -r tag; do
[ -z "$tag" ] && continue

if grep -Fxq "$tag" dst-tags.txt; then
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
skipped=$((skipped+1))
continue
fi

echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
if ! skopeo copy --all --retry-times 3 \
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
errs=$((errs+1)); continue
fi
copied=$((copied+1))

digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
ref="${DEST_IMAGE}@${digest}"

echo "==> cosign sign (keyless) --recursive ${ref}"
if ! cosign sign --recursive "${ref}"; then
echo "::warning title=Keyless sign failed::${ref}"
errs=$((errs+1))
fi

echo "==> cosign sign (key) --recursive ${ref}"
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
echo "::warning title=Key sign failed::${ref}"
errs=$((errs+1))
fi

echo "==> cosign verify (public key) ${ref}"
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
echo "::warning title=Verify(pubkey) failed::${ref}"
errs=$((errs+1))
fi

echo "==> cosign verify (keyless policy) ${ref}"
if ! cosign verify \
--certificate-oidc-issuer "${issuer}" \
--certificate-identity-regexp "${id_regex}" \
"${ref}" -o text; then
echo "::warning title=Verify(keyless) failed::${ref}"
errs=$((errs+1))
else
v_ok=$((v_ok+1))
fi

done < src-tags.txt

echo "---- Summary ----"
echo "Copied : $copied"
echo "Skipped : $skipped"
echo "Verified OK : $v_ok"
echo "Errors : $errs"
11 .github/workflows/test.yml (vendored)
@@ -1,5 +1,8 @@
name: Run Tests

permissions:
contents: read

on:
pull_request:
branches:
@@ -8,15 +11,15 @@ on:

jobs:
test:
runs-on: ubuntu-latest
runs-on: amd64-runner

steps:
- uses: actions/checkout@v5
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0

- name: Set up Go
uses: actions/setup-go@v6
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
with:
go-version: '1.25'
go-version: 1.25

- name: Build go
run: go build
385 37.diff (new file)
@@ -0,0 +1,385 @@
diff --git a/main.go b/main.go
index 7a99c4d..61c186f 100644
--- a/main.go
+++ b/main.go
@@ -2,7 +2,9 @@ package main

import (
"bytes"
+ "context"
"encoding/json"
+ "errors"
"flag"
"fmt"
"io"
@@ -21,6 +23,7 @@ import (
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink"
+ "golang.org/x/sync/errgroup"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -217,6 +220,10 @@ func main() {
logger.Init()
logger.GetLogger().SetLevel(parseLogLevel(logLevel))

+ // Base context for the application; cancel on SIGINT/SIGTERM
+ ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+ defer stop()
+
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
if reachableAt != "" && listenAddr == "" {
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
@@ -324,10 +331,16 @@ func main() {
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)

- go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
+ // Child error group derived from base context
+ group, groupCtx := errgroup.WithContext(ctx)
+
+ // Periodic bandwidth reporting
+ group.Go(func() error {
+ return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
+ })

// Start the UDP proxy server
- proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
+ proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt)
err = proxyRelay.Start()
if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err)
@@ -371,18 +384,39 @@ func main() {
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
logger.Info("Starting HTTP server on %s", listenAddr)

- // Run HTTP server in a goroutine
- go func() {
- if err := http.ListenAndServe(listenAddr, nil); err != nil {
- logger.Error("HTTP server failed: %v", err)
+ // HTTP server with graceful shutdown on context cancel
+ server := &http.Server{
+ Addr: listenAddr,
+ Handler: nil,
+ }
+ group.Go(func() error {
+ // http.ErrServerClosed is returned on graceful shutdown; not an error for us
+ if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+ return err
+ }
+ return nil
+ })
+ group.Go(func() error {
+ <-groupCtx.Done()
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ _ = server.Shutdown(shutdownCtx)
+ // Stop background components as the context is canceled
+ if proxySNI != nil {
+ _ = proxySNI.Stop()
+ }
+ if proxyRelay != nil {
+ proxyRelay.Stop()
}
- }()
+ return nil
+ })

- // Keep the main goroutine running
- sigCh := make(chan os.Signal, 1)
- signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
- <-sigCh
- logger.Info("Shutting down servers...")
+ // Wait for all goroutines to finish
+ if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
+ logger.Error("Service exited with error: %v", err)
+ } else if errors.Is(err, context.Canceled) {
+ logger.Info("Context cancelled, shutting down")
+ }
}

func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
@@ -639,7 +673,7 @@ func ensureMSSClamping() error {
if out, err := addCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
chain, err, string(out))
- logger.Error(errMsg)
+ logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -656,7 +690,7 @@ func ensureMSSClamping() error {
if out, err := checkCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
chain, err, string(out))
- logger.Error(errMsg)
+ logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
})
}

-func periodicBandwidthCheck(endpoint string) {
+func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()

- for range ticker.C {
- if err := reportPeerBandwidth(endpoint); err != nil {
- logger.Info("Failed to report peer bandwidth: %v", err)
+ for {
+ select {
+ case <-ticker.C:
+ if err := reportPeerBandwidth(endpoint); err != nil {
+ logger.Info("Failed to report peer bandwidth: %v", err)
+ }
+ case <-ctx.Done():
+ return ctx.Err()
}
}
}
diff --git a/relay/relay.go b/relay/relay.go
index e74ed87..e3fef04 100644
--- a/relay/relay.go
+++ b/relay/relay.go
@@ -1,6 +1,7 @@
package relay

import (
+ "context"
"bytes"
"encoding/binary"
"encoding/json"
@@ -112,6 +113,8 @@ type UDPProxyServer struct {
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
privateKey wgtypes.Key
packetChan chan Packet
+ ctx context.Context
+ cancel context.CancelFunc

// Session tracking for WireGuard peers
// Key format: "senderIndex:receiverIndex"
@@ -123,14 +126,17 @@ type UDPProxyServer struct {
ReachableAt string
}

-// NewUDPProxyServer initializes the server with a buffered packet channel.
-func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
+// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
+func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
+ ctx, cancel := context.WithCancel(parentCtx)
return &UDPProxyServer{
addr: addr,
serverURL: serverURL,
privateKey: privateKey,
packetChan: make(chan Packet, 1000),
ReachableAt: reachableAt,
+ ctx: ctx,
+ cancel: cancel,
}
}

@@ -177,17 +183,51 @@ func (s *UDPProxyServer) Start() error {
}

func (s *UDPProxyServer) Stop() {
- s.conn.Close()
+ // Signal all background goroutines to stop
+ if s.cancel != nil {
+ s.cancel()
+ }
+ // Close listener to unblock reads
+ if s.conn != nil {
+ _ = s.conn.Close()
+ }
+ // Close all downstream UDP connections
+ s.connections.Range(func(key, value interface{}) bool {
+ if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
+ _ = dc.conn.Close()
+ }
+ return true
+ })
+ // Close packet channel to stop workers
+ select {
+ case <-s.ctx.Done():
+ default:
+ }
+ close(s.packetChan)
}

// readPackets continuously reads from the UDP socket and pushes packets into the channel.
func (s *UDPProxyServer) readPackets() {
for {
+ // Exit promptly if context is canceled
+ select {
+ case <-s.ctx.Done():
+ return
+ default:
+ }
buf := bufferPool.Get().([]byte)
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
if err != nil {
- logger.Error("Error reading UDP packet: %v", err)
- continue
+ // If we're shutting down, exit
+ select {
+ case <-s.ctx.Done():
+ bufferPool.Put(buf[:1500])
+ return
+ default:
+ logger.Error("Error reading UDP packet: %v", err)
+ bufferPool.Put(buf[:1500])
+ continue
+ }
}
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
}
@@ -588,49 +628,67 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
// Add a cleanup method to periodically remove idle connections
func (s *UDPProxyServer) cleanupIdleConnections() {
ticker := time.NewTicker(5 * time.Minute)
- for range ticker.C {
- now := time.Now()
- s.connections.Range(func(key, value interface{}) bool {
- destConn := value.(*DestinationConn)
- if now.Sub(destConn.lastUsed) > 10*time.Minute {
- destConn.conn.Close()
- s.connections.Delete(key)
- }
- return true
- })
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ now := time.Now()
+ s.connections.Range(func(key, value interface{}) bool {
+ destConn := value.(*DestinationConn)
+ if now.Sub(destConn.lastUsed) > 10*time.Minute {
+ destConn.conn.Close()
+ s.connections.Delete(key)
+ }
+ return true
+ })
+ case <-s.ctx.Done():
+ return
+ }
}
}

// New method to periodically remove idle sessions
func (s *UDPProxyServer) cleanupIdleSessions() {
ticker := time.NewTicker(5 * time.Minute)
- for range ticker.C {
- now := time.Now()
- s.wgSessions.Range(func(key, value interface{}) bool {
- session := value.(*WireGuardSession)
- if now.Sub(session.LastSeen) > 15*time.Minute {
- s.wgSessions.Delete(key)
- logger.Debug("Removed idle session: %s", key)
- }
- return true
- })
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ now := time.Now()
+ s.wgSessions.Range(func(key, value interface{}) bool {
+ session := value.(*WireGuardSession)
+ if now.Sub(session.LastSeen) > 15*time.Minute {
+ s.wgSessions.Delete(key)
+ logger.Debug("Removed idle session: %s", key)
+ }
+ return true
+ })
+ case <-s.ctx.Done():
+ return
+ }
}
}

// New method to periodically remove idle proxy mappings
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
ticker := time.NewTicker(10 * time.Minute)
- for range ticker.C {
- now := time.Now()
- s.proxyMappings.Range(func(key, value interface{}) bool {
- mapping := value.(ProxyMapping)
- // Remove mappings that haven't been used in 30 minutes
- if now.Sub(mapping.LastUsed) > 30*time.Minute {
- s.proxyMappings.Delete(key)
- logger.Debug("Removed idle proxy mapping: %s", key)
- }
- return true
- })
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ now := time.Now()
+ s.proxyMappings.Range(func(key, value interface{}) bool {
+ mapping := value.(ProxyMapping)
+ // Remove mappings that haven't been used in 30 minutes
+ if now.Sub(mapping.LastUsed) > 30*time.Minute {
+ s.proxyMappings.Delete(key)
+ logger.Debug("Removed idle proxy mapping: %s", key)
+ }
+ return true
+ })
+ case <-s.ctx.Done():
+ return
+ }
}
}

@@ -943,23 +1001,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
ticker := time.NewTicker(10 * time.Minute)
- for range ticker.C {
- now := time.Now()
- s.commPatterns.Range(func(key, value interface{}) bool {
- pattern := value.(*CommunicationPattern)
-
- // Get the most recent activity
- lastActivity := pattern.LastFromClient
- if pattern.LastFromDest.After(lastActivity) {
- lastActivity = pattern.LastFromDest
- }
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ now := time.Now()
+ s.commPatterns.Range(func(key, value interface{}) bool {
+ pattern := value.(*CommunicationPattern)
+
+ // Get the most recent activity
+ lastActivity := pattern.LastFromClient
+ if pattern.LastFromDest.After(lastActivity) {
+ lastActivity = pattern.LastFromDest
+ }

- // Remove patterns that haven't had activity in 20 minutes
- if now.Sub(lastActivity) > 20*time.Minute {
- s.commPatterns.Delete(key)
- logger.Debug("Removed idle communication pattern: %s", key)
- }
- return true
- })
+ // Remove patterns that haven't had activity in 20 minutes
+ if now.Sub(lastActivity) > 20*time.Minute {
+ s.commPatterns.Delete(key)
+ logger.Debug("Removed idle communication pattern: %s", key)
+ }
+ return true
+ })
+ case <-s.ctx.Done():
+ return
+ }
}
}
Dockerfile
@@ -16,7 +16,7 @@ COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o /gerbil

# Start a new stage from scratch
FROM alpine:3.22 AS runner
FROM alpine:3.23 AS runner

RUN apk add --no-cache iptables iproute2
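For orientation, here is a minimal, self-contained sketch of the shutdown pattern 37.diff introduces: signal.NotifyContext feeds an errgroup that owns the HTTP server, a shutdown watcher, and a context-aware ticker loop. The port, interval, and log messages below are illustrative only, not taken from the repository.

```go
package main

import (
	"context"
	"errors"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"golang.org/x/sync/errgroup"
)

func main() {
	// Cancelled on SIGINT/SIGTERM, mirroring the patch's base context.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	group, groupCtx := errgroup.WithContext(ctx)
	srv := &http.Server{Addr: ":8080"}

	// Serve until Shutdown is called; ErrServerClosed is the normal exit.
	group.Go(func() error {
		if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			return err
		}
		return nil
	})

	// Shut the server down once the group context is cancelled.
	group.Go(func() error {
		<-groupCtx.Done()
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		return srv.Shutdown(shutdownCtx)
	})

	// A context-aware periodic task, like periodicBandwidthCheck in the patch.
	group.Go(func() error {
		ticker := time.NewTicker(10 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				log.Println("tick")
			case <-groupCtx.Done():
				return groupCtx.Err()
			}
		}
	})

	if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
		log.Printf("exited with error: %v", err)
	}
}
```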
6 go.mod
@@ -5,7 +5,7 @@ go 1.25
require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.43.0
golang.org/x/crypto v0.45.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
)

@@ -16,8 +16,8 @@ require (
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/net v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.37.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect
)
12 go.sum
@@ -16,16 +16,16 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
89 main.go
@@ -2,7 +2,9 @@ package main

import (
"bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
@@ -21,6 +23,7 @@ import (
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink"
"golang.org/x/sync/errgroup"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -217,6 +220,10 @@ func main() {
logger.Init()
logger.GetLogger().SetLevel(parseLogLevel(logLevel))

// Base context for the application; cancel on SIGINT/SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
if reachableAt != "" && listenAddr == "" {
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
@@ -324,10 +331,16 @@ func main() {
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)

go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
// Child error group derived from base context
group, groupCtx := errgroup.WithContext(ctx)

// Periodic bandwidth reporting
group.Go(func() error {
return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
})

// Start the UDP proxy server
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt)
err = proxyRelay.Start()
if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err)
@@ -371,18 +384,39 @@ func main() {
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
logger.Info("Starting HTTP server on %s", listenAddr)

// Run HTTP server in a goroutine
go func() {
if err := http.ListenAndServe(listenAddr, nil); err != nil {
logger.Error("HTTP server failed: %v", err)
// HTTP server with graceful shutdown on context cancel
server := &http.Server{
Addr: listenAddr,
Handler: nil,
}
group.Go(func() error {
// http.ErrServerClosed is returned on graceful shutdown; not an error for us
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
}()
return nil
})
group.Go(func() error {
<-groupCtx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = server.Shutdown(shutdownCtx)
// Stop background components as the context is canceled
if proxySNI != nil {
_ = proxySNI.Stop()
}
if proxyRelay != nil {
proxyRelay.Stop()
}
return nil
})

// Keep the main goroutine running
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
logger.Info("Shutting down servers...")
// Wait for all goroutines to finish
if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
logger.Error("Service exited with error: %v", err)
} else if errors.Is(err, context.Canceled) {
logger.Info("Context cancelled, shutting down")
}
}

func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
@@ -639,7 +673,7 @@ func ensureMSSClamping() error {
if out, err := addCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -656,7 +690,7 @@ func ensureMSSClamping() error {
if out, err := checkCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
})
}

func periodicBandwidthCheck(endpoint string) {
func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()

for range ticker.C {
if err := reportPeerBandwidth(endpoint); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
for {
select {
case <-ticker.C:
if err := reportPeerBandwidth(endpoint); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
}
case <-ctx.Done():
return ctx.Err()
}
}
}
@@ -1003,8 +1042,13 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
mu.Lock()
defer mu.Unlock()

// Track the set of peers currently present on the device to prune stale readings efficiently
currentPeerKeys := make(map[string]struct{}, len(device.Peers))

for _, peer := range device.Peers {
publicKey := peer.PublicKey.String()
currentPeerKeys[publicKey] = struct{}{}

currentReading := PeerReading{
BytesReceived: peer.ReceiveBytes,
BytesTransmitted: peer.TransmitBytes,
@@ -1061,14 +1105,7 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {

// Clean up old peers
for publicKey := range lastReadings {
found := false
for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey {
found = true
break
}
}
if !found {
if _, exists := currentPeerKeys[publicKey]; !exists {
delete(lastReadings, publicKey)
}
}
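The final hunk above swaps a nested loop over device.Peers for a set lookup when pruning stale entries from lastReadings. A tiny stand-alone sketch of that prune pattern, with generic names rather than the repository's types:

```go
package main

import "fmt"

// pruneStale deletes entries from readings whose key is not in current.
// Deleting from a map while ranging over it is safe in Go.
func pruneStale(readings map[string]int64, current map[string]struct{}) {
	for key := range readings {
		if _, exists := current[key]; !exists {
			delete(readings, key)
		}
	}
}

func main() {
	readings := map[string]int64{"peerA": 10, "peerB": 20, "peerC": 30}
	current := map[string]struct{}{"peerA": {}, "peerC": {}}

	pruneStale(readings, current)
	fmt.Println(readings) // map[peerA:10 peerC:30]
}
```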
222 relay/relay.go
@@ -2,6 +2,7 @@ package relay

import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
@@ -60,12 +61,41 @@ type DestinationConn struct {

// Type for storing WireGuard handshake information
type WireGuardSession struct {
mu sync.RWMutex
ReceiverIndex uint32
SenderIndex uint32
DestAddr *net.UDPAddr
LastSeen time.Time
}

// GetSenderIndex returns the SenderIndex in a thread-safe manner
func (s *WireGuardSession) GetSenderIndex() uint32 {
s.mu.RLock()
defer s.mu.RUnlock()
return s.SenderIndex
}

// GetDestAddr returns the DestAddr in a thread-safe manner
func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
s.mu.RLock()
defer s.mu.RUnlock()
return s.DestAddr
}

// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
func (s *WireGuardSession) GetLastSeen() time.Time {
s.mu.RLock()
defer s.mu.RUnlock()
return s.LastSeen
}

// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
func (s *WireGuardSession) UpdateLastSeen() {
s.mu.Lock()
defer s.mu.Unlock()
s.LastSeen = time.Now()
}

// Type for tracking bidirectional communication patterns to rebuild sessions
type CommunicationPattern struct {
FromClient *net.UDPAddr // The client address
@@ -114,6 +144,8 @@ type UDPProxyServer struct {
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
privateKey wgtypes.Key
packetChan chan Packet
ctx context.Context
cancel context.CancelFunc

// Session tracking for WireGuard peers
// Key format: "senderIndex:receiverIndex"
@@ -125,14 +157,17 @@ type UDPProxyServer struct {
ReachableAt string
}

// NewUDPProxyServer initializes the server with a buffered packet channel.
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
ctx, cancel := context.WithCancel(parentCtx)
return &UDPProxyServer{
addr: addr,
serverURL: serverURL,
privateKey: privateKey,
packetChan: make(chan Packet, 1000),
ReachableAt: reachableAt,
ctx: ctx,
cancel: cancel,
}
}

@@ -179,17 +214,51 @@ func (s *UDPProxyServer) Start() error {
}

func (s *UDPProxyServer) Stop() {
s.conn.Close()
// Signal all background goroutines to stop
if s.cancel != nil {
s.cancel()
}
// Close listener to unblock reads
if s.conn != nil {
_ = s.conn.Close()
}
// Close all downstream UDP connections
s.connections.Range(func(key, value interface{}) bool {
if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
_ = dc.conn.Close()
}
return true
})
// Close packet channel to stop workers
select {
case <-s.ctx.Done():
default:
}
close(s.packetChan)
}

// readPackets continuously reads from the UDP socket and pushes packets into the channel.
func (s *UDPProxyServer) readPackets() {
for {
// Exit promptly if context is canceled
select {
case <-s.ctx.Done():
return
default:
}
buf := bufferPool.Get().([]byte)
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
if err != nil {
logger.Error("Error reading UDP packet: %v", err)
continue
// If we're shutting down, exit
select {
case <-s.ctx.Done():
bufferPool.Put(buf[:1500])
return
default:
logger.Error("Error reading UDP packet: %v", err)
bufferPool.Put(buf[:1500])
continue
}
}
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
}
@@ -445,13 +514,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
// First check for existing sessions to see if we know where to send this packet
s.wgSessions.Range(func(k, v interface{}) bool {
session := v.(*WireGuardSession)
if session.SenderIndex == receiverIndex {
// Found matching session
destAddr = session.DestAddr

// Update last seen time
session.LastSeen = time.Now()
s.wgSessions.Store(k, session)
// Check if session matches (read lock for check)
if session.GetSenderIndex() == receiverIndex {
// Found matching session - get dest addr and update last seen
destAddr = session.GetDestAddr()
session.UpdateLastSeen()
return false // stop iteration
}
return true // continue iteration
@@ -591,49 +658,69 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
// Add a cleanup method to periodically remove idle connections
func (s *UDPProxyServer) cleanupIdleConnections() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
now := time.Now()
s.connections.Range(func(key, value interface{}) bool {
destConn := value.(*DestinationConn)
if now.Sub(destConn.lastUsed) > 10*time.Minute {
destConn.conn.Close()
s.connections.Delete(key)
}
return true
})
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.connections.Range(func(key, value interface{}) bool {
destConn := value.(*DestinationConn)
if now.Sub(destConn.lastUsed) > 10*time.Minute {
destConn.conn.Close()
s.connections.Delete(key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}

// New method to periodically remove idle sessions
func (s *UDPProxyServer) cleanupIdleSessions() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession)
if now.Sub(session.LastSeen) > 15*time.Minute {
s.wgSessions.Delete(key)
logger.Debug("Removed idle session: %s", key)
}
return true
})

defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession)
// Use thread-safe method to read LastSeen
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
s.wgSessions.Delete(key)
logger.Debug("Removed idle session: %s", key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}

// New method to periodically remove idle proxy mappings
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
ticker := time.NewTicker(10 * time.Minute)
for range ticker.C {
now := time.Now()
s.proxyMappings.Range(func(key, value interface{}) bool {
mapping := value.(ProxyMapping)
// Remove mappings that haven't been used in 30 minutes
if now.Sub(mapping.LastUsed) > 30*time.Minute {
s.proxyMappings.Delete(key)
logger.Debug("Removed idle proxy mapping: %s", key)
}
return true
})
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.proxyMappings.Range(func(key, value interface{}) bool {
mapping := value.(ProxyMapping)
// Remove mappings that haven't been used in 30 minutes
if now.Sub(mapping.LastUsed) > 30*time.Minute {
s.proxyMappings.Delete(key)
logger.Debug("Removed idle proxy mapping: %s", key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}

@@ -738,8 +825,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
keyStr := key.(string)
session := value.(*WireGuardSession)

// Check if the session's destination address contains the WG IP
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
// Check if the session's destination address contains the WG IP (thread-safe)
destAddr := session.GetDestAddr()
if destAddr != nil && destAddr.IP.String() == ip {
keysToDelete = append(keysToDelete, keyStr)
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
}
@@ -929,14 +1017,12 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {

// Check if we already have this session
if _, exists := s.wgSessions.Load(sessionKey); !exists {
session := &WireGuardSession{
s.wgSessions.Store(sessionKey, &WireGuardSession{
ReceiverIndex: pattern.DestIndex,
SenderIndex: pattern.ClientIndex,
DestAddr: pattern.ToDestination,
LastSeen: time.Now(),
}

s.wgSessions.Store(sessionKey, session)
})
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
}
@@ -946,23 +1032,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
ticker := time.NewTicker(10 * time.Minute)
for range ticker.C {
now := time.Now()
s.commPatterns.Range(func(key, value interface{}) bool {
pattern := value.(*CommunicationPattern)
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.commPatterns.Range(func(key, value interface{}) bool {
pattern := value.(*CommunicationPattern)

// Get the most recent activity
lastActivity := pattern.LastFromClient
if pattern.LastFromDest.After(lastActivity) {
lastActivity = pattern.LastFromDest
}
// Get the most recent activity
lastActivity := pattern.LastFromClient
if pattern.LastFromDest.After(lastActivity) {
lastActivity = pattern.LastFromDest
}

// Remove patterns that haven't had activity in 20 minutes
if now.Sub(lastActivity) > 20*time.Minute {
s.commPatterns.Delete(key)
logger.Debug("Removed idle communication pattern: %s", key)
}
return true
})
// Remove patterns that haven't had activity in 20 minutes
if now.Sub(lastActivity) > 20*time.Minute {
s.commPatterns.Delete(key)
logger.Debug("Removed idle communication pattern: %s", key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}
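The new WireGuardSession accessors above wrap every read in RLock and the write in Lock. A compact sketch of the same RWMutex accessor pattern, using a hypothetical Session type rather than the repository's struct:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// Session guards its fields with an RWMutex so concurrent readers
// never race with the writer that refreshes lastSeen.
type Session struct {
	mu       sync.RWMutex
	senderID uint32
	lastSeen time.Time
}

func (s *Session) SenderID() uint32 {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.senderID
}

func (s *Session) LastSeen() time.Time {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.lastSeen
}

func (s *Session) Touch() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.lastSeen = time.Now()
}

func main() {
	s := &Session{senderID: 42}
	s.Touch()
	fmt.Println(s.SenderID(), !s.LastSeen().IsZero())
}
```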