Merge branch 'main' into dev

This commit is contained in:
Owen
2025-12-06 12:15:34 -05:00
9 changed files with 899 additions and 142 deletions

View File

@@ -1,52 +1,160 @@
name: CI/CD Pipeline
# CI/CD workflow for building, publishing, mirroring, signing container images and building release binaries.
# Actions are pinned to specific SHAs to reduce supply-chain risk. This workflow triggers on tag push events.
permissions:
contents: read
packages: write # for GHCR push
id-token: write # for Cosign Keyless (OIDC) Signing
# Required secrets:
# - DOCKER_HUB_USERNAME / DOCKER_HUB_ACCESS_TOKEN: push to Docker Hub
# - GITHUB_TOKEN: used for GHCR login and OIDC keyless signing
# - COSIGN_PRIVATE_KEY / COSIGN_PASSWORD / COSIGN_PUBLIC_KEY: for key-based signing
on:
push:
tags:
- "*"
concurrency:
group: ${{ github.ref }}
cancel-in-progress: true
jobs:
release:
name: Build and Release
runs-on: amd64-runner
# Job-level timeout to avoid runaway or stuck runs
timeout-minutes: 120
env:
# Target images
DOCKERHUB_IMAGE: docker.io/${{ secrets.DOCKER_HUB_USERNAME }}/${{ github.event.repository.name }}
GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
steps:
- name: Checkout code
uses: actions/checkout@v5
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
- name: Log in to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: docker.io
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
- name: Extract tag name
id: get-tag
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
shell: bash
- name: Install Go
uses: actions/setup-go@v6
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
with:
go-version: 1.25
- name: Build and push Docker images
- name: Update version in main.go
run: |
TAG=${{ env.TAG }}
if [ -f main.go ]; then
sed -i 's/version_replaceme/'"$TAG"'/' main.go
echo "Updated main.go with version $TAG"
else
echo "main.go not found"
fi
shell: bash
- name: Build and push Docker images (Docker Hub)
run: |
TAG=${{ env.TAG }}
make docker-build-release tag=$TAG
echo "Built & pushed to: ${{ env.DOCKERHUB_IMAGE }}:${TAG}"
shell: bash
- name: Login in to GHCR
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Install skopeo + jq
# skopeo: copy/inspect images between registries
# jq: JSON parsing tool used to extract digest values
run: |
sudo apt-get update -y
sudo apt-get install -y skopeo jq
skopeo --version
shell: bash
- name: Copy tag from Docker Hub to GHCR
# Mirror the already-built image (all architectures) to GHCR so we can sign it
run: |
set -euo pipefail
TAG=${{ env.TAG }}
echo "Copying ${{ env.DOCKERHUB_IMAGE }}:${TAG} -> ${{ env.GHCR_IMAGE }}:${TAG}"
skopeo copy --all --retry-times 3 \
docker://$DOCKERHUB_IMAGE:$TAG \
docker://$GHCR_IMAGE:$TAG
shell: bash
- name: Install cosign
# cosign is used to sign and verify container images (key and keyless)
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
- name: Dual-sign and verify (GHCR & Docker Hub)
# Sign each image by digest using keyless (OIDC) and key-based signing,
# then verify both the public key signature and the keyless OIDC signature.
env:
TAG: ${{ env.TAG }}
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
COSIGN_YES: "true"
run: |
set -euo pipefail
issuer="https://token.actions.githubusercontent.com"
id_regex="^https://github.com/${{ github.repository }}/.+" # accept this repo (all workflows/refs)
for IMAGE in "${GHCR_IMAGE}" "${DOCKERHUB_IMAGE}"; do
echo "Processing ${IMAGE}:${TAG}"
DIGEST="$(skopeo inspect --retry-times 3 docker://${IMAGE}:${TAG} | jq -r '.Digest')"
REF="${IMAGE}@${DIGEST}"
echo "Resolved digest: ${REF}"
echo "==> cosign sign (keyless) --recursive ${REF}"
cosign sign --recursive "${REF}"
echo "==> cosign sign (key) --recursive ${REF}"
cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${REF}"
echo "==> cosign verify (public key) ${REF}"
cosign verify --key env://COSIGN_PUBLIC_KEY "${REF}" -o text
echo "==> cosign verify (keyless policy) ${REF}"
cosign verify \
--certificate-oidc-issuer "${issuer}" \
--certificate-identity-regexp "${id_regex}" \
"${REF}" -o text
done
shell: bash
- name: Build binaries
run: |
make go-build-release
shell: bash
- name: Upload artifacts from /bin
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
with:
name: binaries
path: bin/

132
.github/workflows/mirror.yaml vendored Normal file
View File

@@ -0,0 +1,132 @@
name: Mirror & Sign (Docker Hub to GHCR)
on:
workflow_dispatch: {}
permissions:
contents: read
packages: write
id-token: write # for keyless OIDC
env:
SOURCE_IMAGE: docker.io/fosrl/gerbil
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
jobs:
mirror-and-dual-sign:
runs-on: amd64-runner
steps:
- name: Install skopeo + jq
run: |
sudo apt-get update -y
sudo apt-get install -y skopeo jq
skopeo --version
- name: Install cosign
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
- name: Input check
run: |
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
echo "Source : ${SOURCE_IMAGE}"
echo "Target : ${DEST_IMAGE}"
# Auth for skopeo (containers-auth)
- name: Skopeo login to GHCR
run: |
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
# Auth for cosign (docker-config)
- name: Docker login to GHCR (for cosign)
run: |
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
- name: List source tags
run: |
set -euo pipefail
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
| jq -r '.Tags[]' | sort -u > src-tags.txt
echo "Found source tags: $(wc -l < src-tags.txt)"
head -n 20 src-tags.txt || true
- name: List destination tags (skip existing)
run: |
set -euo pipefail
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
else
: > dst-tags.txt
fi
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
- name: Mirror, dual-sign, and verify
env:
# keyless
COSIGN_YES: "true"
# key-based
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
# verify
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
run: |
set -euo pipefail
copied=0; skipped=0; v_ok=0; errs=0
issuer="https://token.actions.githubusercontent.com"
id_regex="^https://github.com/${{ github.repository }}/.+"
while read -r tag; do
[ -z "$tag" ] && continue
if grep -Fxq "$tag" dst-tags.txt; then
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
skipped=$((skipped+1))
continue
fi
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
if ! skopeo copy --all --retry-times 3 \
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
errs=$((errs+1)); continue
fi
copied=$((copied+1))
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
ref="${DEST_IMAGE}@${digest}"
echo "==> cosign sign (keyless) --recursive ${ref}"
if ! cosign sign --recursive "${ref}"; then
echo "::warning title=Keyless sign failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign sign (key) --recursive ${ref}"
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
echo "::warning title=Key sign failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign verify (public key) ${ref}"
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
echo "::warning title=Verify(pubkey) failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign verify (keyless policy) ${ref}"
if ! cosign verify \
--certificate-oidc-issuer "${issuer}" \
--certificate-identity-regexp "${id_regex}" \
"${ref}" -o text; then
echo "::warning title=Verify(keyless) failed::${ref}"
errs=$((errs+1))
else
v_ok=$((v_ok+1))
fi
done < src-tags.txt
echo "---- Summary ----"
echo "Copied : $copied"
echo "Skipped : $skipped"
echo "Verified OK : $v_ok"
echo "Errors : $errs"

View File

@@ -1,5 +1,8 @@
name: Run Tests
permissions:
contents: read
on:
pull_request:
branches:
@@ -8,15 +11,15 @@ on:
jobs:
test:
runs-on: ubuntu-latest
runs-on: amd64-runner
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- name: Set up Go
uses: actions/setup-go@v6
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
with:
go-version: '1.25'
go-version: 1.25
- name: Build go
run: go build

385
37.diff Normal file
View File

@@ -0,0 +1,385 @@
diff --git a/main.go b/main.go
index 7a99c4d..61c186f 100644
--- a/main.go
+++ b/main.go
@@ -2,7 +2,9 @@ package main
import (
"bytes"
+ "context"
"encoding/json"
+ "errors"
"flag"
"fmt"
"io"
@@ -21,6 +23,7 @@ import (
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink"
+ "golang.org/x/sync/errgroup"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -217,6 +220,10 @@ func main() {
logger.Init()
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
+ // Base context for the application; cancel on SIGINT/SIGTERM
+ ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+ defer stop()
+
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
if reachableAt != "" && listenAddr == "" {
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
@@ -324,10 +331,16 @@ func main() {
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)
- go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
+ // Child error group derived from base context
+ group, groupCtx := errgroup.WithContext(ctx)
+
+ // Periodic bandwidth reporting
+ group.Go(func() error {
+ return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
+ })
// Start the UDP proxy server
- proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
+ proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt)
err = proxyRelay.Start()
if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err)
@@ -371,18 +384,39 @@ func main() {
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
logger.Info("Starting HTTP server on %s", listenAddr)
- // Run HTTP server in a goroutine
- go func() {
- if err := http.ListenAndServe(listenAddr, nil); err != nil {
- logger.Error("HTTP server failed: %v", err)
+ // HTTP server with graceful shutdown on context cancel
+ server := &http.Server{
+ Addr: listenAddr,
+ Handler: nil,
+ }
+ group.Go(func() error {
+ // http.ErrServerClosed is returned on graceful shutdown; not an error for us
+ if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+ return err
+ }
+ return nil
+ })
+ group.Go(func() error {
+ <-groupCtx.Done()
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ _ = server.Shutdown(shutdownCtx)
+ // Stop background components as the context is canceled
+ if proxySNI != nil {
+ _ = proxySNI.Stop()
+ }
+ if proxyRelay != nil {
+ proxyRelay.Stop()
}
- }()
+ return nil
+ })
- // Keep the main goroutine running
- sigCh := make(chan os.Signal, 1)
- signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
- <-sigCh
- logger.Info("Shutting down servers...")
+ // Wait for all goroutines to finish
+ if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
+ logger.Error("Service exited with error: %v", err)
+ } else if errors.Is(err, context.Canceled) {
+ logger.Info("Context cancelled, shutting down")
+ }
}
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
@@ -639,7 +673,7 @@ func ensureMSSClamping() error {
if out, err := addCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
chain, err, string(out))
- logger.Error(errMsg)
+ logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -656,7 +690,7 @@ func ensureMSSClamping() error {
if out, err := checkCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
chain, err, string(out))
- logger.Error(errMsg)
+ logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
})
}
-func periodicBandwidthCheck(endpoint string) {
+func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
- for range ticker.C {
- if err := reportPeerBandwidth(endpoint); err != nil {
- logger.Info("Failed to report peer bandwidth: %v", err)
+ for {
+ select {
+ case <-ticker.C:
+ if err := reportPeerBandwidth(endpoint); err != nil {
+ logger.Info("Failed to report peer bandwidth: %v", err)
+ }
+ case <-ctx.Done():
+ return ctx.Err()
}
}
}
diff --git a/relay/relay.go b/relay/relay.go
index e74ed87..e3fef04 100644
--- a/relay/relay.go
+++ b/relay/relay.go
@@ -1,6 +1,7 @@
package relay
import (
+ "context"
"bytes"
"encoding/binary"
"encoding/json"
@@ -112,6 +113,8 @@ type UDPProxyServer struct {
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
privateKey wgtypes.Key
packetChan chan Packet
+ ctx context.Context
+ cancel context.CancelFunc
// Session tracking for WireGuard peers
// Key format: "senderIndex:receiverIndex"
@@ -123,14 +126,17 @@ type UDPProxyServer struct {
ReachableAt string
}
-// NewUDPProxyServer initializes the server with a buffered packet channel.
-func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
+// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
+func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
+ ctx, cancel := context.WithCancel(parentCtx)
return &UDPProxyServer{
addr: addr,
serverURL: serverURL,
privateKey: privateKey,
packetChan: make(chan Packet, 1000),
ReachableAt: reachableAt,
+ ctx: ctx,
+ cancel: cancel,
}
}
@@ -177,17 +183,51 @@ func (s *UDPProxyServer) Start() error {
}
func (s *UDPProxyServer) Stop() {
- s.conn.Close()
+ // Signal all background goroutines to stop
+ if s.cancel != nil {
+ s.cancel()
+ }
+ // Close listener to unblock reads
+ if s.conn != nil {
+ _ = s.conn.Close()
+ }
+ // Close all downstream UDP connections
+ s.connections.Range(func(key, value interface{}) bool {
+ if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
+ _ = dc.conn.Close()
+ }
+ return true
+ })
+ // Close packet channel to stop workers
+ select {
+ case <-s.ctx.Done():
+ default:
+ }
+ close(s.packetChan)
}
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
func (s *UDPProxyServer) readPackets() {
for {
+ // Exit promptly if context is canceled
+ select {
+ case <-s.ctx.Done():
+ return
+ default:
+ }
buf := bufferPool.Get().([]byte)
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
if err != nil {
- logger.Error("Error reading UDP packet: %v", err)
- continue
+ // If we're shutting down, exit
+ select {
+ case <-s.ctx.Done():
+ bufferPool.Put(buf[:1500])
+ return
+ default:
+ logger.Error("Error reading UDP packet: %v", err)
+ bufferPool.Put(buf[:1500])
+ continue
+ }
}
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
}
@@ -588,49 +628,67 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
// Add a cleanup method to periodically remove idle connections
func (s *UDPProxyServer) cleanupIdleConnections() {
ticker := time.NewTicker(5 * time.Minute)
- for range ticker.C {
- now := time.Now()
- s.connections.Range(func(key, value interface{}) bool {
- destConn := value.(*DestinationConn)
- if now.Sub(destConn.lastUsed) > 10*time.Minute {
- destConn.conn.Close()
- s.connections.Delete(key)
- }
- return true
- })
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ now := time.Now()
+ s.connections.Range(func(key, value interface{}) bool {
+ destConn := value.(*DestinationConn)
+ if now.Sub(destConn.lastUsed) > 10*time.Minute {
+ destConn.conn.Close()
+ s.connections.Delete(key)
+ }
+ return true
+ })
+ case <-s.ctx.Done():
+ return
+ }
}
}
// New method to periodically remove idle sessions
func (s *UDPProxyServer) cleanupIdleSessions() {
ticker := time.NewTicker(5 * time.Minute)
- for range ticker.C {
- now := time.Now()
- s.wgSessions.Range(func(key, value interface{}) bool {
- session := value.(*WireGuardSession)
- if now.Sub(session.LastSeen) > 15*time.Minute {
- s.wgSessions.Delete(key)
- logger.Debug("Removed idle session: %s", key)
- }
- return true
- })
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ now := time.Now()
+ s.wgSessions.Range(func(key, value interface{}) bool {
+ session := value.(*WireGuardSession)
+ if now.Sub(session.LastSeen) > 15*time.Minute {
+ s.wgSessions.Delete(key)
+ logger.Debug("Removed idle session: %s", key)
+ }
+ return true
+ })
+ case <-s.ctx.Done():
+ return
+ }
}
}
// New method to periodically remove idle proxy mappings
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
ticker := time.NewTicker(10 * time.Minute)
- for range ticker.C {
- now := time.Now()
- s.proxyMappings.Range(func(key, value interface{}) bool {
- mapping := value.(ProxyMapping)
- // Remove mappings that haven't been used in 30 minutes
- if now.Sub(mapping.LastUsed) > 30*time.Minute {
- s.proxyMappings.Delete(key)
- logger.Debug("Removed idle proxy mapping: %s", key)
- }
- return true
- })
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ now := time.Now()
+ s.proxyMappings.Range(func(key, value interface{}) bool {
+ mapping := value.(ProxyMapping)
+ // Remove mappings that haven't been used in 30 minutes
+ if now.Sub(mapping.LastUsed) > 30*time.Minute {
+ s.proxyMappings.Delete(key)
+ logger.Debug("Removed idle proxy mapping: %s", key)
+ }
+ return true
+ })
+ case <-s.ctx.Done():
+ return
+ }
}
}
@@ -943,23 +1001,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
ticker := time.NewTicker(10 * time.Minute)
- for range ticker.C {
- now := time.Now()
- s.commPatterns.Range(func(key, value interface{}) bool {
- pattern := value.(*CommunicationPattern)
-
- // Get the most recent activity
- lastActivity := pattern.LastFromClient
- if pattern.LastFromDest.After(lastActivity) {
- lastActivity = pattern.LastFromDest
- }
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ now := time.Now()
+ s.commPatterns.Range(func(key, value interface{}) bool {
+ pattern := value.(*CommunicationPattern)
+
+ // Get the most recent activity
+ lastActivity := pattern.LastFromClient
+ if pattern.LastFromDest.After(lastActivity) {
+ lastActivity = pattern.LastFromDest
+ }
- // Remove patterns that haven't had activity in 20 minutes
- if now.Sub(lastActivity) > 20*time.Minute {
- s.commPatterns.Delete(key)
- logger.Debug("Removed idle communication pattern: %s", key)
- }
- return true
- })
+ // Remove patterns that haven't had activity in 20 minutes
+ if now.Sub(lastActivity) > 20*time.Minute {
+ s.commPatterns.Delete(key)
+ logger.Debug("Removed idle communication pattern: %s", key)
+ }
+ return true
+ })
+ case <-s.ctx.Done():
+ return
+ }
}
}

View File

@@ -16,7 +16,7 @@ COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o /gerbil
# Start a new stage from scratch
FROM alpine:3.22 AS runner
FROM alpine:3.23 AS runner
RUN apk add --no-cache iptables iproute2

6
go.mod
View File

@@ -5,7 +5,7 @@ go 1.25
require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.43.0
golang.org/x/crypto v0.45.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
)
@@ -16,8 +16,8 @@ require (
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/net v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.37.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect
)

12
go.sum
View File

@@ -16,16 +16,16 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=

85
main.go
View File

@@ -2,7 +2,9 @@ package main
import (
"bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
@@ -21,6 +23,7 @@ import (
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink"
"golang.org/x/sync/errgroup"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -217,6 +220,10 @@ func main() {
logger.Init()
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
// Base context for the application; cancel on SIGINT/SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
if reachableAt != "" && listenAddr == "" {
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
@@ -324,10 +331,16 @@ func main() {
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
// Child error group derived from base context
group, groupCtx := errgroup.WithContext(ctx)
// Periodic bandwidth reporting
group.Go(func() error {
return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
})
// Start the UDP proxy server
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt)
err = proxyRelay.Start()
if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err)
@@ -371,18 +384,39 @@ func main() {
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
logger.Info("Starting HTTP server on %s", listenAddr)
// Run HTTP server in a goroutine
go func() {
if err := http.ListenAndServe(listenAddr, nil); err != nil {
logger.Error("HTTP server failed: %v", err)
// HTTP server with graceful shutdown on context cancel
server := &http.Server{
Addr: listenAddr,
Handler: nil,
}
}()
group.Go(func() error {
// http.ErrServerClosed is returned on graceful shutdown; not an error for us
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
})
group.Go(func() error {
<-groupCtx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = server.Shutdown(shutdownCtx)
// Stop background components as the context is canceled
if proxySNI != nil {
_ = proxySNI.Stop()
}
if proxyRelay != nil {
proxyRelay.Stop()
}
return nil
})
// Keep the main goroutine running
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
logger.Info("Shutting down servers...")
// Wait for all goroutines to finish
if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
logger.Error("Service exited with error: %v", err)
} else if errors.Is(err, context.Canceled) {
logger.Info("Context cancelled, shutting down")
}
}
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
@@ -639,7 +673,7 @@ func ensureMSSClamping() error {
if out, err := addCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -656,7 +690,7 @@ func ensureMSSClamping() error {
if out, err := checkCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -977,14 +1011,19 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
})
}
func periodicBandwidthCheck(endpoint string) {
func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
for {
select {
case <-ticker.C:
if err := reportPeerBandwidth(endpoint); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
}
case <-ctx.Done():
return ctx.Err()
}
}
}
@@ -1003,8 +1042,13 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
mu.Lock()
defer mu.Unlock()
// Track the set of peers currently present on the device to prune stale readings efficiently
currentPeerKeys := make(map[string]struct{}, len(device.Peers))
for _, peer := range device.Peers {
publicKey := peer.PublicKey.String()
currentPeerKeys[publicKey] = struct{}{}
currentReading := PeerReading{
BytesReceived: peer.ReceiveBytes,
BytesTransmitted: peer.TransmitBytes,
@@ -1061,14 +1105,7 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
// Clean up old peers
for publicKey := range lastReadings {
found := false
for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey {
found = true
break
}
}
if !found {
if _, exists := currentPeerKeys[publicKey]; !exists {
delete(lastReadings, publicKey)
}
}

View File

@@ -2,6 +2,7 @@ package relay
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
@@ -60,12 +61,41 @@ type DestinationConn struct {
// Type for storing WireGuard handshake information
type WireGuardSession struct {
mu sync.RWMutex
ReceiverIndex uint32
SenderIndex uint32
DestAddr *net.UDPAddr
LastSeen time.Time
}
// GetSenderIndex returns the SenderIndex in a thread-safe manner
func (s *WireGuardSession) GetSenderIndex() uint32 {
s.mu.RLock()
defer s.mu.RUnlock()
return s.SenderIndex
}
// GetDestAddr returns the DestAddr in a thread-safe manner
func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
s.mu.RLock()
defer s.mu.RUnlock()
return s.DestAddr
}
// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
func (s *WireGuardSession) GetLastSeen() time.Time {
s.mu.RLock()
defer s.mu.RUnlock()
return s.LastSeen
}
// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
func (s *WireGuardSession) UpdateLastSeen() {
s.mu.Lock()
defer s.mu.Unlock()
s.LastSeen = time.Now()
}
// Type for tracking bidirectional communication patterns to rebuild sessions
type CommunicationPattern struct {
FromClient *net.UDPAddr // The client address
@@ -114,6 +144,8 @@ type UDPProxyServer struct {
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
privateKey wgtypes.Key
packetChan chan Packet
ctx context.Context
cancel context.CancelFunc
// Session tracking for WireGuard peers
// Key format: "senderIndex:receiverIndex"
@@ -125,14 +157,17 @@ type UDPProxyServer struct {
ReachableAt string
}
// NewUDPProxyServer initializes the server with a buffered packet channel.
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
ctx, cancel := context.WithCancel(parentCtx)
return &UDPProxyServer{
addr: addr,
serverURL: serverURL,
privateKey: privateKey,
packetChan: make(chan Packet, 1000),
ReachableAt: reachableAt,
ctx: ctx,
cancel: cancel,
}
}
@@ -179,18 +214,52 @@ func (s *UDPProxyServer) Start() error {
}
func (s *UDPProxyServer) Stop() {
s.conn.Close()
// Signal all background goroutines to stop
if s.cancel != nil {
s.cancel()
}
// Close listener to unblock reads
if s.conn != nil {
_ = s.conn.Close()
}
// Close all downstream UDP connections
s.connections.Range(func(key, value interface{}) bool {
if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
_ = dc.conn.Close()
}
return true
})
// Close packet channel to stop workers
select {
case <-s.ctx.Done():
default:
}
close(s.packetChan)
}
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
func (s *UDPProxyServer) readPackets() {
for {
// Exit promptly if context is canceled
select {
case <-s.ctx.Done():
return
default:
}
buf := bufferPool.Get().([]byte)
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
if err != nil {
// If we're shutting down, exit
select {
case <-s.ctx.Done():
bufferPool.Put(buf[:1500])
return
default:
logger.Error("Error reading UDP packet: %v", err)
bufferPool.Put(buf[:1500])
continue
}
}
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
}
}
@@ -445,13 +514,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
// First check for existing sessions to see if we know where to send this packet
s.wgSessions.Range(func(k, v interface{}) bool {
session := v.(*WireGuardSession)
if session.SenderIndex == receiverIndex {
// Found matching session
destAddr = session.DestAddr
// Update last seen time
session.LastSeen = time.Now()
s.wgSessions.Store(k, session)
// Check if session matches (read lock for check)
if session.GetSenderIndex() == receiverIndex {
// Found matching session - get dest addr and update last seen
destAddr = session.GetDestAddr()
session.UpdateLastSeen()
return false // stop iteration
}
return true // continue iteration
@@ -591,7 +658,10 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
// Add a cleanup method to periodically remove idle connections
func (s *UDPProxyServer) cleanupIdleConnections() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.connections.Range(func(key, value interface{}) bool {
destConn := value.(*DestinationConn)
@@ -601,29 +671,43 @@ func (s *UDPProxyServer) cleanupIdleConnections() {
}
return true
})
case <-s.ctx.Done():
return
}
}
}
// New method to periodically remove idle sessions
func (s *UDPProxyServer) cleanupIdleSessions() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession)
if now.Sub(session.LastSeen) > 15*time.Minute {
// Use thread-safe method to read LastSeen
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
s.wgSessions.Delete(key)
logger.Debug("Removed idle session: %s", key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}
// New method to periodically remove idle proxy mappings
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
ticker := time.NewTicker(10 * time.Minute)
for range ticker.C {
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.proxyMappings.Range(func(key, value interface{}) bool {
mapping := value.(ProxyMapping)
@@ -634,6 +718,9 @@ func (s *UDPProxyServer) cleanupIdleProxyMappings() {
}
return true
})
case <-s.ctx.Done():
return
}
}
}
@@ -738,8 +825,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
keyStr := key.(string)
session := value.(*WireGuardSession)
// Check if the session's destination address contains the WG IP
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
// Check if the session's destination address contains the WG IP (thread-safe)
destAddr := session.GetDestAddr()
if destAddr != nil && destAddr.IP.String() == ip {
keysToDelete = append(keysToDelete, keyStr)
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
}
@@ -929,14 +1017,12 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
// Check if we already have this session
if _, exists := s.wgSessions.Load(sessionKey); !exists {
session := &WireGuardSession{
s.wgSessions.Store(sessionKey, &WireGuardSession{
ReceiverIndex: pattern.DestIndex,
SenderIndex: pattern.ClientIndex,
DestAddr: pattern.ToDestination,
LastSeen: time.Now(),
}
s.wgSessions.Store(sessionKey, session)
})
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
}
@@ -946,7 +1032,10 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
ticker := time.NewTicker(10 * time.Minute)
for range ticker.C {
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.commPatterns.Range(func(key, value interface{}) bool {
pattern := value.(*CommunicationPattern)
@@ -964,5 +1053,8 @@ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
}
return true
})
case <-s.ctx.Done():
return
}
}
}