diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml
index 3402a0a..bf6dfc5 100644
--- a/.github/workflows/cicd.yml
+++ b/.github/workflows/cicd.yml
@@ -1,52 +1,160 @@
 name: CI/CD Pipeline
+# CI/CD workflow that builds, publishes, mirrors, and signs container images, and builds release binaries.
+# Actions are pinned to specific SHAs to reduce supply-chain risk. This workflow triggers on tag push events.
+
+permissions:
+  contents: read
+  packages: write # for GHCR push
+  id-token: write # for cosign keyless (OIDC) signing
+
+# Required secrets:
+# - DOCKER_HUB_USERNAME / DOCKER_HUB_ACCESS_TOKEN: push to Docker Hub
+# - GITHUB_TOKEN: used for GHCR login and OIDC keyless signing (provided automatically by Actions)
+# - COSIGN_PRIVATE_KEY / COSIGN_PASSWORD / COSIGN_PUBLIC_KEY: for key-based signing
+
 on:
-    push:
-        tags:
-            - "*"
+  push:
+    tags:
+      - "*"
+
+concurrency:
+  group: ${{ github.ref }}
+  cancel-in-progress: true
 
 jobs:
-    release:
-        name: Build and Release
-        runs-on: amd64-runner
+  release:
+    name: Build and Release
+    runs-on: amd64-runner
+    # Job-level timeout to avoid runaway or stuck runs
+    timeout-minutes: 120
+    env:
+      # Target images
+      DOCKERHUB_IMAGE: docker.io/${{ secrets.DOCKER_HUB_USERNAME }}/${{ github.event.repository.name }}
+      GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
 
-        steps:
-            - name: Checkout code
-              uses: actions/checkout@v5
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
 
-            - name: Set up QEMU
-              uses: docker/setup-qemu-action@v3
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
 
-            - name: Set up Docker Buildx
-              uses: docker/setup-buildx-action@v3
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
 
-            - name: Log in to Docker Hub
-              uses: docker/login-action@v3
-              with:
-                  username: ${{ secrets.DOCKER_HUB_USERNAME }}
-                  password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
+      - name: Log in to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
+        with:
+          registry: docker.io
+          username: ${{ secrets.DOCKER_HUB_USERNAME }}
+          password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
 
-            - name: Extract tag name
-              id: get-tag
-              run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
+      - name: Extract tag name
+        id: get-tag
+        run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
+        shell: bash
 
-            - name: Install Go
-              uses: actions/setup-go@v6
-              with:
-                  go-version: 1.25
+      - name: Install Go
+        uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
+        with:
+          go-version: 1.25
 
-            - name: Build and push Docker images
-              run: |
-                  TAG=${{ env.TAG }}
-                  make docker-build-release tag=$TAG
+      - name: Update version in main.go
+        run: |
+          TAG=${{ env.TAG }}
+          if [ -f main.go ]; then
+            sed -i 's/version_replaceme/'"$TAG"'/' main.go
+            echo "Updated main.go with version $TAG"
+          else
+            echo "main.go not found"
+          fi
+        shell: bash
 
-            - name: Build binaries
-              run: |
-                  make go-build-release
+      - name: Build and push Docker images (Docker Hub)
+        run: |
+          TAG=${{ env.TAG }}
+          make docker-build-release tag=$TAG
+          echo "Built & pushed to: ${{ env.DOCKERHUB_IMAGE }}:${TAG}"
+        shell: bash
 
-            - name: Upload artifacts from /bin
-              uses: actions/upload-artifact@v4
-              with:
-                  name: binaries
-                  path: bin/
+      - name: Log in to GHCR
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Install skopeo + jq
+        # skopeo: copy/inspect images between registries
+        # jq: JSON parsing tool used to extract digest values
+        run: |
+          sudo apt-get update -y
+          sudo apt-get install -y skopeo jq
+          skopeo --version
+        shell: bash
+
+      - name: Copy tag from Docker Hub to GHCR
+        # Mirror the already-built image (all architectures) to GHCR so we can sign it
+        run: |
+          set -euo pipefail
+          TAG=${{ env.TAG }}
+          echo "Copying ${{ env.DOCKERHUB_IMAGE }}:${TAG} -> ${{ env.GHCR_IMAGE }}:${TAG}"
+          skopeo copy --all --retry-times 3 \
+            docker://$DOCKERHUB_IMAGE:$TAG \
+            docker://$GHCR_IMAGE:$TAG
+        shell: bash
+
+      - name: Install cosign
+        # cosign is used to sign and verify container images (key and keyless)
+        uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
+
+      - name: Dual-sign and verify (GHCR & Docker Hub)
+        # Sign each image by digest using keyless (OIDC) and key-based signing,
+        # then verify both the public key signature and the keyless OIDC signature.
+        env:
+          TAG: ${{ env.TAG }}
+          COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
+          COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
+          COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
+          COSIGN_YES: "true"
+        run: |
+          set -euo pipefail
+
+          issuer="https://token.actions.githubusercontent.com"
+          id_regex="^https://github.com/${{ github.repository }}/.+" # accept this repo (all workflows/refs)
+
+          for IMAGE in "${GHCR_IMAGE}" "${DOCKERHUB_IMAGE}"; do
+            echo "Processing ${IMAGE}:${TAG}"
+
+            DIGEST="$(skopeo inspect --retry-times 3 docker://${IMAGE}:${TAG} | jq -r '.Digest')"
+            REF="${IMAGE}@${DIGEST}"
+            echo "Resolved digest: ${REF}"
+
+            echo "==> cosign sign (keyless) --recursive ${REF}"
+            cosign sign --recursive "${REF}"
+
+            echo "==> cosign sign (key) --recursive ${REF}"
+            cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${REF}"
+
+            echo "==> cosign verify (public key) ${REF}"
+            cosign verify --key env://COSIGN_PUBLIC_KEY "${REF}" -o text
+
+            echo "==> cosign verify (keyless policy) ${REF}"
+            cosign verify \
+              --certificate-oidc-issuer "${issuer}" \
+              --certificate-identity-regexp "${id_regex}" \
+              "${REF}" -o text
+          done
+        shell: bash
+
+      - name: Build binaries
+        run: |
+          make go-build-release
+        shell: bash
+
+      - name: Upload artifacts from /bin
+        uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
+        with:
+          name: binaries
+          path: bin/
diff --git a/.github/workflows/mirror.yaml b/.github/workflows/mirror.yaml
new file mode 100644
index 0000000..793073e
--- /dev/null
+++ b/.github/workflows/mirror.yaml
@@ -0,0 +1,132 @@
+name: Mirror & Sign (Docker Hub to GHCR)
+
+on:
+  workflow_dispatch: {}
+
+permissions:
+  contents: read
+  packages: write
+  id-token: write # for keyless OIDC
+
+env:
+  SOURCE_IMAGE: docker.io/fosrl/gerbil
+  DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
+
+jobs:
+  mirror-and-dual-sign:
+    runs-on: amd64-runner
+    steps:
+      - name: Install skopeo + jq
+        run: |
+          sudo apt-get update -y
+          sudo apt-get install -y skopeo jq
+          skopeo --version
+
+      - name: Install cosign
+        uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
+
+      - name: Input check
+        run: |
+          test -n "${SOURCE_IMAGE}" || { echo "SOURCE_IMAGE is empty"; exit 1; }
+          echo "Source : ${SOURCE_IMAGE}"
+          echo "Target : ${DEST_IMAGE}"
+
+      # Auth for skopeo (containers-auth)
+      - name: Skopeo login to GHCR
+        run: |
+          skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
+
+      # Auth for cosign (docker-config)
+      - name: Docker login to GHCR (for cosign)
+        run: |
+          echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
+
+      - name: List source tags
+        run: |
+          set -euo pipefail
+          skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
+            | jq -r '.Tags[]' | sort -u > src-tags.txt
+          echo "Found source tags: $(wc -l < src-tags.txt)"
+          head -n 20 src-tags.txt || true
+
+      - name: List destination tags (skip existing)
+        run: |
+          set -euo pipefail
+          if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
+            jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
+          else
+            : > dst-tags.txt
+          fi
+          echo "Existing destination tags: $(wc -l < dst-tags.txt)"
+
+      - name: Mirror, dual-sign, and verify
+        env:
+          # keyless
+          COSIGN_YES: "true"
+          # key-based
+          COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
+          COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
+          # verify
+          COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
+        run: |
+          set -euo pipefail
+          copied=0; skipped=0; v_ok=0; errs=0
+
+          issuer="https://token.actions.githubusercontent.com"
+          id_regex="^https://github.com/${{ github.repository }}/.+"
+
+          while read -r tag; do
+            [ -z "$tag" ] && continue
+
+            if grep -Fxq "$tag" dst-tags.txt; then
+              echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
+              skipped=$((skipped+1))
+              continue
+            fi
+
+            echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
+            if ! skopeo copy --all --retry-times 3 \
+              docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
+              echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
+              errs=$((errs+1)); continue
+            fi
+            copied=$((copied+1))
+
+            digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
+            ref="${DEST_IMAGE}@${digest}"
+
+            echo "==> cosign sign (keyless) --recursive ${ref}"
+            if ! cosign sign --recursive "${ref}"; then
+              echo "::warning title=Keyless sign failed::${ref}"
+              errs=$((errs+1))
+            fi
+
+            echo "==> cosign sign (key) --recursive ${ref}"
+            if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
+              echo "::warning title=Key sign failed::${ref}"
+              errs=$((errs+1))
+            fi
+
+            echo "==> cosign verify (public key) ${ref}"
+            if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
+              echo "::warning title=Verify(pubkey) failed::${ref}"
+              errs=$((errs+1))
+            fi
+
+            echo "==> cosign verify (keyless policy) ${ref}"
+            if ! cosign verify \
+              --certificate-oidc-issuer "${issuer}" \
+              --certificate-identity-regexp "${id_regex}" \
+              "${ref}" -o text; then
+              echo "::warning title=Verify(keyless) failed::${ref}"
+              errs=$((errs+1))
+            else
+              v_ok=$((v_ok+1))
+            fi
+          done < src-tags.txt
+
+          echo "---- Summary ----"
+          echo "Copied      : $copied"
+          echo "Skipped     : $skipped"
+          echo "Verified OK : $v_ok"
+          echo "Errors      : $errs"
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 1b9637e..5581708 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,5 +1,8 @@
 name: Run Tests
 
+permissions:
+  contents: read
+
 on:
   pull_request:
     branches:
@@ -8,15 +11,15 @@ on:
 
 jobs:
   test:
-    runs-on: ubuntu-latest
+    runs-on: amd64-runner
 
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
 
       - name: Set up Go
-        uses: actions/setup-go@v6
+        uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
         with:
-          go-version: '1.25'
+          go-version: 1.25
 
       - name: Build go
         run: go build
diff --git a/37.diff b/37.diff
new file mode 100644
index 0000000..d80429c
--- /dev/null
+++ b/37.diff
@@ -0,0 +1,385 @@
+diff --git a/main.go b/main.go
+index 7a99c4d..61c186f 100644
+--- a/main.go
++++ b/main.go
+@@ -2,7 +2,9 @@ package main
+ 
+ import (
+     "bytes"
++    "context"
+     "encoding/json"
++    "errors"
+     "flag"
+     "fmt"
+     "io"
+@@ -21,6 +23,7 @@ import (
+     "github.com/fosrl/gerbil/proxy"
+     "github.com/fosrl/gerbil/relay"
+     "github.com/vishvananda/netlink"
++    "golang.org/x/sync/errgroup"
+     "golang.zx2c4.com/wireguard/wgctrl"
+     "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+ )
+@@ -217,6 +220,10 @@ func main() {
+     logger.Init()
+     logger.GetLogger().SetLevel(parseLogLevel(logLevel))
+ 
++    // Base context for the application; cancel on SIGINT/SIGTERM
++    ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
++    defer stop()
++
+     // try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
+ if reachableAt != "" && listenAddr == "" { + if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") { +@@ -324,10 +331,16 @@ func main() { + // Ensure the WireGuard peers exist + ensureWireguardPeers(wgconfig.Peers) + +- go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") ++ // Child error group derived from base context ++ group, groupCtx := errgroup.WithContext(ctx) ++ ++ // Periodic bandwidth reporting ++ group.Go(func() error { ++ return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth") ++ }) + + // Start the UDP proxy server +- proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) ++ proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt) + err = proxyRelay.Start() + if err != nil { + logger.Fatal("Failed to start UDP proxy server: %v", err) +@@ -371,18 +384,39 @@ func main() { + http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs) + logger.Info("Starting HTTP server on %s", listenAddr) + +- // Run HTTP server in a goroutine +- go func() { +- if err := http.ListenAndServe(listenAddr, nil); err != nil { +- logger.Error("HTTP server failed: %v", err) ++ // HTTP server with graceful shutdown on context cancel ++ server := &http.Server{ ++ Addr: listenAddr, ++ Handler: nil, ++ } ++ group.Go(func() error { ++ // http.ErrServerClosed is returned on graceful shutdown; not an error for us ++ if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { ++ return err ++ } ++ return nil ++ }) ++ group.Go(func() error { ++ <-groupCtx.Done() ++ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ++ defer cancel() ++ _ = server.Shutdown(shutdownCtx) ++ // Stop background components as the context is canceled ++ if proxySNI != nil { ++ _ = proxySNI.Stop() ++ } ++ if proxyRelay != nil { ++ proxyRelay.Stop() + } +- }() ++ return nil ++ }) + +- // Keep the main goroutine running +- sigCh := make(chan os.Signal, 1) +- signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) +- <-sigCh +- logger.Info("Shutting down servers...") ++ // Wait for all goroutines to finish ++ if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) { ++ logger.Error("Service exited with error: %v", err) ++ } else if errors.Is(err, context.Canceled) { ++ logger.Info("Context cancelled, shutting down") ++ } + } + + func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { +@@ -639,7 +673,7 @@ func ensureMSSClamping() error { + if out, err := addCmd.CombinedOutput(); err != nil { + errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", + chain, err, string(out)) +- logger.Error(errMsg) ++ logger.Error("%s", errMsg) + errors = append(errors, fmt.Errorf("%s", errMsg)) + continue + } +@@ -656,7 +690,7 @@ func ensureMSSClamping() error { + if out, err := checkCmd.CombinedOutput(); err != nil { + errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", + chain, err, string(out)) +- logger.Error(errMsg) ++ logger.Error("%s", errMsg) + errors = append(errors, fmt.Errorf("%s", errMsg)) + continue + } +@@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) { + }) + } + +-func periodicBandwidthCheck(endpoint string) { ++func periodicBandwidthCheck(ctx context.Context, endpoint string) error { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + +- for range ticker.C { +- if err := 
reportPeerBandwidth(endpoint); err != nil { +- logger.Info("Failed to report peer bandwidth: %v", err) ++ for { ++ select { ++ case <-ticker.C: ++ if err := reportPeerBandwidth(endpoint); err != nil { ++ logger.Info("Failed to report peer bandwidth: %v", err) ++ } ++ case <-ctx.Done(): ++ return ctx.Err() + } + } + } +diff --git a/relay/relay.go b/relay/relay.go +index e74ed87..e3fef04 100644 +--- a/relay/relay.go ++++ b/relay/relay.go +@@ -1,6 +1,7 @@ + package relay + + import ( ++ "context" + "bytes" + "encoding/binary" + "encoding/json" +@@ -112,6 +113,8 @@ type UDPProxyServer struct { + connections sync.Map // map[string]*DestinationConn where key is destination "ip:port" + privateKey wgtypes.Key + packetChan chan Packet ++ ctx context.Context ++ cancel context.CancelFunc + + // Session tracking for WireGuard peers + // Key format: "senderIndex:receiverIndex" +@@ -123,14 +126,17 @@ type UDPProxyServer struct { + ReachableAt string + } + +-// NewUDPProxyServer initializes the server with a buffered packet channel. +-func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer { ++// NewUDPProxyServer initializes the server with a buffered packet channel and derived context. ++func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer { ++ ctx, cancel := context.WithCancel(parentCtx) + return &UDPProxyServer{ + addr: addr, + serverURL: serverURL, + privateKey: privateKey, + packetChan: make(chan Packet, 1000), + ReachableAt: reachableAt, ++ ctx: ctx, ++ cancel: cancel, + } + } + +@@ -177,17 +183,51 @@ func (s *UDPProxyServer) Start() error { + } + + func (s *UDPProxyServer) Stop() { +- s.conn.Close() ++ // Signal all background goroutines to stop ++ if s.cancel != nil { ++ s.cancel() ++ } ++ // Close listener to unblock reads ++ if s.conn != nil { ++ _ = s.conn.Close() ++ } ++ // Close all downstream UDP connections ++ s.connections.Range(func(key, value interface{}) bool { ++ if dc, ok := value.(*DestinationConn); ok && dc.conn != nil { ++ _ = dc.conn.Close() ++ } ++ return true ++ }) ++ // Close packet channel to stop workers ++ select { ++ case <-s.ctx.Done(): ++ default: ++ } ++ close(s.packetChan) + } + + // readPackets continuously reads from the UDP socket and pushes packets into the channel. 
+ func (s *UDPProxyServer) readPackets() {
+     for {
++        // Exit promptly if context is canceled
++        select {
++        case <-s.ctx.Done():
++            return
++        default:
++        }
+         buf := bufferPool.Get().([]byte)
+         n, remoteAddr, err := s.conn.ReadFromUDP(buf)
+         if err != nil {
+-            logger.Error("Error reading UDP packet: %v", err)
+-            continue
++            // If we're shutting down, exit
++            select {
++            case <-s.ctx.Done():
++                bufferPool.Put(buf[:1500])
++                return
++            default:
++                logger.Error("Error reading UDP packet: %v", err)
++                bufferPool.Put(buf[:1500])
++                continue
++            }
+         }
+         s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
+     }
+ }
+@@ -588,49 +628,67 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
+ // Add a cleanup method to periodically remove idle connections
+ func (s *UDPProxyServer) cleanupIdleConnections() {
+     ticker := time.NewTicker(5 * time.Minute)
+-    for range ticker.C {
+-        now := time.Now()
+-        s.connections.Range(func(key, value interface{}) bool {
+-            destConn := value.(*DestinationConn)
+-            if now.Sub(destConn.lastUsed) > 10*time.Minute {
+-                destConn.conn.Close()
+-                s.connections.Delete(key)
+-            }
+-            return true
+-        })
++    defer ticker.Stop()
++    for {
++        select {
++        case <-ticker.C:
++            now := time.Now()
++            s.connections.Range(func(key, value interface{}) bool {
++                destConn := value.(*DestinationConn)
++                if now.Sub(destConn.lastUsed) > 10*time.Minute {
++                    destConn.conn.Close()
++                    s.connections.Delete(key)
++                }
++                return true
++            })
++        case <-s.ctx.Done():
++            return
++        }
+     }
+ }
+ 
+ // New method to periodically remove idle sessions
+ func (s *UDPProxyServer) cleanupIdleSessions() {
+     ticker := time.NewTicker(5 * time.Minute)
+-    for range ticker.C {
+-        now := time.Now()
+-        s.wgSessions.Range(func(key, value interface{}) bool {
+-            session := value.(*WireGuardSession)
+-            if now.Sub(session.LastSeen) > 15*time.Minute {
+-                s.wgSessions.Delete(key)
+-                logger.Debug("Removed idle session: %s", key)
+-            }
+-            return true
+-        })
++    defer ticker.Stop()
++    for {
++        select {
++        case <-ticker.C:
++            now := time.Now()
++            s.wgSessions.Range(func(key, value interface{}) bool {
++                session := value.(*WireGuardSession)
++                if now.Sub(session.LastSeen) > 15*time.Minute {
++                    s.wgSessions.Delete(key)
++                    logger.Debug("Removed idle session: %s", key)
++                }
++                return true
++            })
++        case <-s.ctx.Done():
++            return
++        }
+     }
+ }
+ 
+ // New method to periodically remove idle proxy mappings
+ func (s *UDPProxyServer) cleanupIdleProxyMappings() {
+     ticker := time.NewTicker(10 * time.Minute)
+-    for range ticker.C {
+-        now := time.Now()
+-        s.proxyMappings.Range(func(key, value interface{}) bool {
+-            mapping := value.(ProxyMapping)
+-            // Remove mappings that haven't been used in 30 minutes
+-            if now.Sub(mapping.LastUsed) > 30*time.Minute {
+-                s.proxyMappings.Delete(key)
+-                logger.Debug("Removed idle proxy mapping: %s", key)
+-            }
+-            return true
+-        })
++    defer ticker.Stop()
++    for {
++        select {
++        case <-ticker.C:
++            now := time.Now()
++            s.proxyMappings.Range(func(key, value interface{}) bool {
++                mapping := value.(ProxyMapping)
++                // Remove mappings that haven't been used in 30 minutes
++                if now.Sub(mapping.LastUsed) > 30*time.Minute {
++                    s.proxyMappings.Delete(key)
++                    logger.Debug("Removed idle proxy mapping: %s", key)
++                }
++                return true
++            })
++        case <-s.ctx.Done():
++            return
++        }
+     }
+ }
+ 
+@@ -943,23 +1001,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
+ // cleanupIdleCommunicationPatterns periodically removes idle communication patterns
+ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
+     ticker := time.NewTicker(10 * time.Minute)
+-    for range ticker.C {
+-        now := time.Now()
+-        s.commPatterns.Range(func(key, value interface{}) bool {
+-            pattern := value.(*CommunicationPattern)
+-
+-            // Get the most recent activity
+-            lastActivity := pattern.LastFromClient
+-            if pattern.LastFromDest.After(lastActivity) {
+-                lastActivity = pattern.LastFromDest
+-            }
++    defer ticker.Stop()
++    for {
++        select {
++        case <-ticker.C:
++            now := time.Now()
++            s.commPatterns.Range(func(key, value interface{}) bool {
++                pattern := value.(*CommunicationPattern)
++
++                // Get the most recent activity
++                lastActivity := pattern.LastFromClient
++                if pattern.LastFromDest.After(lastActivity) {
++                    lastActivity = pattern.LastFromDest
++                }
+ 
+-            // Remove patterns that haven't had activity in 20 minutes
+-            if now.Sub(lastActivity) > 20*time.Minute {
+-                s.commPatterns.Delete(key)
+-                logger.Debug("Removed idle communication pattern: %s", key)
+-            }
+-            return true
+-        })
++                // Remove patterns that haven't had activity in 20 minutes
++                if now.Sub(lastActivity) > 20*time.Minute {
++                    s.commPatterns.Delete(key)
++                    logger.Debug("Removed idle communication pattern: %s", key)
++                }
++                return true
++            })
++        case <-s.ctx.Done():
++            return
++        }
+     }
+ }
diff --git a/Dockerfile b/Dockerfile
index 8b94de3..d3ccae2 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -16,7 +16,7 @@ COPY . .
 RUN CGO_ENABLED=0 GOOS=linux go build -o /gerbil
 
 # Start a new stage from scratch
-FROM alpine:3.22 AS runner
+FROM alpine:3.23 AS runner
 
 RUN apk add --no-cache iptables iproute2
 
diff --git a/go.mod b/go.mod
index 72f9c0f..a47ae8b 100644
--- a/go.mod
+++ b/go.mod
@@ -5,7 +5,7 @@ go 1.25
 require (
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/vishvananda/netlink v1.3.1
-	golang.org/x/crypto v0.43.0
+	golang.org/x/crypto v0.45.0
 	golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
 )
 
@@ -16,8 +16,8 @@ require (
 	github.com/mdlayher/netlink v1.7.2 // indirect
 	github.com/mdlayher/socket v0.4.1 // indirect
 	github.com/vishvananda/netns v0.0.5 // indirect
-	golang.org/x/net v0.45.0 // indirect
+	golang.org/x/net v0.47.0 // indirect
 	golang.org/x/sync v0.1.0 // indirect
-	golang.org/x/sys v0.37.0 // indirect
+	golang.org/x/sys v0.38.0 // indirect
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect
 )
diff --git a/go.sum b/go.sum
index bd7354b..4b4298f 100644
--- a/go.sum
+++ b/go.sum
@@ -16,16 +16,16 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
 github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
 github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
 github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
-golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
-golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
-golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
-golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
+golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
+golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
+golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
+golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
 golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
-golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
+golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
diff --git a/main.go b/main.go
index 7a99c4d..6352b4b 100644
--- a/main.go
+++ b/main.go
@@ -2,7 +2,9 @@ package main
 
 import (
     "bytes"
+    "context"
     "encoding/json"
+    "errors"
     "flag"
     "fmt"
     "io"
@@ -21,6 +23,7 @@ import (
     "github.com/fosrl/gerbil/proxy"
     "github.com/fosrl/gerbil/relay"
     "github.com/vishvananda/netlink"
+    "golang.org/x/sync/errgroup"
     "golang.zx2c4.com/wireguard/wgctrl"
     "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
@@ -217,6 +220,10 @@ func main() {
     logger.Init()
     logger.GetLogger().SetLevel(parseLogLevel(logLevel))
 
+    // Base context for the application; cancel on SIGINT/SIGTERM
+    ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+    defer stop()
+
     // try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
     if reachableAt != "" && listenAddr == "" {
         if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
@@ -324,10 +331,16 @@ func main() {
     // Ensure the WireGuard peers exist
     ensureWireguardPeers(wgconfig.Peers)
 
-    go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
+    // Child error group derived from base context
+    group, groupCtx := errgroup.WithContext(ctx)
+
+    // Periodic bandwidth reporting
+    group.Go(func() error {
+        return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
+    })
 
     // Start the UDP proxy server
-    proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
+    proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt)
     err = proxyRelay.Start()
     if err != nil {
         logger.Fatal("Failed to start UDP proxy server: %v", err)
@@ -371,18 +384,39 @@ func main() {
     http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
     logger.Info("Starting HTTP server on %s", listenAddr)
 
-    // Run HTTP server in a goroutine
-    go func() {
-        if err := http.ListenAndServe(listenAddr, nil); err != nil {
-            logger.Error("HTTP server failed: %v", err)
+    // HTTP server with graceful shutdown on context cancel
+    server := &http.Server{
+        Addr:    listenAddr,
+        Handler: nil,
+    }
+    group.Go(func() error {
+        // http.ErrServerClosed is returned on graceful shutdown; not an error for us
+        if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+            return err
         }
-    }()
+        return nil
+    })
+    group.Go(func() error {
+        <-groupCtx.Done()
+        shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+        defer cancel()
+        _ = server.Shutdown(shutdownCtx)
+        // Stop background components as the context is canceled
+        if proxySNI != nil {
+            _ = proxySNI.Stop()
+        }
+        if proxyRelay != nil {
+            proxyRelay.Stop()
+        }
+        return nil
+    })
 
-    // Keep the main goroutine running
-    sigCh := make(chan os.Signal, 1)
-    signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
-    <-sigCh
-    logger.Info("Shutting down servers...")
+    // Wait for all goroutines to finish
+    if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
+        logger.Error("Service exited with error: %v", err)
+    } else if errors.Is(err, context.Canceled) {
+        logger.Info("Context cancelled, shutting down")
+    }
 }
 
 func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
@@ -639,7 +673,7 @@ func ensureMSSClamping() error {
         if out, err := addCmd.CombinedOutput(); err != nil {
             errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
                 chain, err, string(out))
-            logger.Error(errMsg)
+            logger.Error("%s", errMsg)
             errors = append(errors, fmt.Errorf("%s", errMsg))
             continue
         }
@@ -656,7 +690,7 @@ func ensureMSSClamping() error {
         if out, err := checkCmd.CombinedOutput(); err != nil {
             errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
                 chain, err, string(out))
-            logger.Error(errMsg)
+            logger.Error("%s", errMsg)
             errors = append(errors, fmt.Errorf("%s", errMsg))
             continue
         }
@@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
     })
 }
 
-func periodicBandwidthCheck(endpoint string) {
+func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
     ticker := time.NewTicker(10 * time.Second)
     defer ticker.Stop()
 
-    for range ticker.C {
-        if err := reportPeerBandwidth(endpoint); err != nil {
-            logger.Info("Failed to report peer bandwidth: %v", err)
+    for {
+        select {
+        case <-ticker.C:
+            if err := reportPeerBandwidth(endpoint); err != nil {
+                logger.Info("Failed to report peer bandwidth: %v", err)
+            }
+        case <-ctx.Done():
+            return ctx.Err()
         }
     }
 }
@@ -1003,8 +1042,13 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
     mu.Lock()
     defer mu.Unlock()
 
+    // Track the set of peers currently present on the device to prune stale readings efficiently
+    currentPeerKeys := make(map[string]struct{}, len(device.Peers))
+
     for _, peer := range device.Peers {
         publicKey := peer.PublicKey.String()
+        currentPeerKeys[publicKey] = struct{}{}
+
         currentReading := PeerReading{
             BytesReceived:    peer.ReceiveBytes,
             BytesTransmitted: peer.TransmitBytes,
@@ -1061,14 +1105,7 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
 
     // Clean up old peers
     for publicKey := range lastReadings {
-        found := false
-        for _, peer := range device.Peers {
-            if peer.PublicKey.String() == publicKey {
-                found = true
-                break
-            }
-        }
-        if !found {
+        if _, exists := currentPeerKeys[publicKey]; !exists {
             delete(lastReadings, publicKey)
         }
     }
diff --git a/relay/relay.go b/relay/relay.go
index 4f2483d..59faa4d 100644
--- a/relay/relay.go
+++ b/relay/relay.go
@@ -2,6 +2,7 @@ package relay
 
 import (
     "bytes"
+    "context"
     "encoding/binary"
     "encoding/json"
     "fmt"
@@ -60,12 +61,41 @@ type DestinationConn struct {
 
 // Type for storing WireGuard handshake information
 type WireGuardSession struct {
+    mu            sync.RWMutex
     ReceiverIndex uint32
     SenderIndex   uint32
     DestAddr      *net.UDPAddr
     LastSeen      time.Time
 }
 
+// GetSenderIndex returns the SenderIndex in a thread-safe manner
+func (s *WireGuardSession) GetSenderIndex() uint32 {
+    s.mu.RLock()
+    defer s.mu.RUnlock()
+    return s.SenderIndex
+}
+
+// GetDestAddr returns the DestAddr in a thread-safe manner
+func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
+    s.mu.RLock()
+    defer s.mu.RUnlock()
+    return s.DestAddr
+}
+
+// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
+func (s *WireGuardSession) GetLastSeen() time.Time {
+    s.mu.RLock()
+    defer s.mu.RUnlock()
+    return s.LastSeen
+}
+
+// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
+func (s *WireGuardSession) UpdateLastSeen() {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+    s.LastSeen = time.Now()
+}
+
 // Type for tracking bidirectional communication patterns to rebuild sessions
 type CommunicationPattern struct {
     FromClient *net.UDPAddr // The client address
@@ -114,6 +144,8 @@ type UDPProxyServer struct {
     connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
     privateKey  wgtypes.Key
     packetChan  chan Packet
+    ctx         context.Context
+    cancel      context.CancelFunc
 
     // Session tracking for WireGuard peers
     // Key format: "senderIndex:receiverIndex"
@@ -125,14 +157,17 @@ type UDPProxyServer struct {
     ReachableAt string
 }
 
-// NewUDPProxyServer initializes the server with a buffered packet channel.
-func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
+// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
+func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
+    ctx, cancel := context.WithCancel(parentCtx)
     return &UDPProxyServer{
         addr:        addr,
         serverURL:   serverURL,
         privateKey:  privateKey,
         packetChan:  make(chan Packet, 1000),
         ReachableAt: reachableAt,
+        ctx:         ctx,
+        cancel:      cancel,
     }
 }
 
@@ -179,17 +214,51 @@ func (s *UDPProxyServer) Start() error {
 }
 
 func (s *UDPProxyServer) Stop() {
-    s.conn.Close()
+    // Signal all background goroutines to stop
+    if s.cancel != nil {
+        s.cancel()
+    }
+    // Close listener to unblock reads
+    if s.conn != nil {
+        _ = s.conn.Close()
+    }
+    // Close all downstream UDP connections
+    s.connections.Range(func(key, value interface{}) bool {
+        if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
+            _ = dc.conn.Close()
+        }
+        return true
+    })
+    // Close the packet channel so worker goroutines draining it can exit.
+    // Stop must be called at most once: closing an already-closed channel
+    // panics. The listener is closed above first, which unblocks
+    // readPackets so it observes ctx.Done() and returns instead of
+    // sending on the closed channel.
+    close(s.packetChan)
 }
 
 // readPackets continuously reads from the UDP socket and pushes packets into the channel.
 func (s *UDPProxyServer) readPackets() {
     for {
+        // Exit promptly if context is canceled
+        select {
+        case <-s.ctx.Done():
+            return
+        default:
+        }
         buf := bufferPool.Get().([]byte)
         n, remoteAddr, err := s.conn.ReadFromUDP(buf)
         if err != nil {
-            logger.Error("Error reading UDP packet: %v", err)
-            continue
+            // If we're shutting down, exit
+            select {
+            case <-s.ctx.Done():
+                bufferPool.Put(buf[:1500])
+                return
+            default:
+                logger.Error("Error reading UDP packet: %v", err)
+                bufferPool.Put(buf[:1500])
+                continue
+            }
         }
         s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
     }
 }
@@ -445,13 +514,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
     // First check for existing sessions to see if we know where to send this packet
     s.wgSessions.Range(func(k, v interface{}) bool {
         session := v.(*WireGuardSession)
-        if session.SenderIndex == receiverIndex {
-            // Found matching session
-            destAddr = session.DestAddr
-
-            // Update last seen time
-            session.LastSeen = time.Now()
-            s.wgSessions.Store(k, session)
+        // Check if session matches (read lock for check)
+        if session.GetSenderIndex() == receiverIndex {
+            // Found matching session - get dest addr and update last seen
+            destAddr = session.GetDestAddr()
+            session.UpdateLastSeen()
             return false // stop iteration
         }
         return true // continue iteration
@@ -591,49 +658,69 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
 // Add a cleanup method to periodically remove idle connections
 func (s *UDPProxyServer) cleanupIdleConnections() {
     ticker := time.NewTicker(5 * time.Minute)
-    for range ticker.C {
-        now := time.Now()
-        s.connections.Range(func(key, value interface{}) bool {
-            destConn := value.(*DestinationConn)
-            if now.Sub(destConn.lastUsed) > 10*time.Minute {
-                destConn.conn.Close()
-                s.connections.Delete(key)
-            }
-            return true
-        })
+    defer ticker.Stop()
+    for {
+        select {
+        case <-ticker.C:
+            now := time.Now()
+            s.connections.Range(func(key, value interface{}) bool {
+                destConn := value.(*DestinationConn)
+                if now.Sub(destConn.lastUsed) > 10*time.Minute {
+                    destConn.conn.Close()
+                    s.connections.Delete(key)
+                }
+                return true
+            })
+        case <-s.ctx.Done():
+            return
+        }
     }
 }
 
 // New method to periodically remove idle sessions
 func (s *UDPProxyServer) cleanupIdleSessions() {
     ticker := time.NewTicker(5 * time.Minute)
-    for range ticker.C {
-        now := time.Now()
-        s.wgSessions.Range(func(key, value interface{}) bool {
-            session := value.(*WireGuardSession)
-            if now.Sub(session.LastSeen) > 15*time.Minute {
-                s.wgSessions.Delete(key)
-                logger.Debug("Removed idle session: %s", key)
-            }
-            return true
-        })
+
+    defer ticker.Stop()
+    for {
+        select {
+        case <-ticker.C:
+            now := time.Now()
+            s.wgSessions.Range(func(key, value interface{}) bool {
+                session := value.(*WireGuardSession)
+                // Use thread-safe method to read LastSeen
+                if now.Sub(session.GetLastSeen()) > 15*time.Minute {
+                    s.wgSessions.Delete(key)
+                    logger.Debug("Removed idle session: %s", key)
+                }
+                return true
+            })
+        case <-s.ctx.Done():
+            return
+        }
     }
 }
 
 // New method to periodically remove idle proxy mappings
 func (s *UDPProxyServer) cleanupIdleProxyMappings() {
     ticker := time.NewTicker(10 * time.Minute)
-    for range ticker.C {
-        now := time.Now()
-        s.proxyMappings.Range(func(key, value interface{}) bool {
-            mapping := value.(ProxyMapping)
-            // Remove mappings that haven't been used in 30 minutes
-            if now.Sub(mapping.LastUsed) > 30*time.Minute {
-                s.proxyMappings.Delete(key)
-                logger.Debug("Removed idle proxy mapping: %s", key)
-            }
-            return true
-        })
+    defer ticker.Stop()
+    for {
+        select {
+        case <-ticker.C:
+            now := time.Now()
+            s.proxyMappings.Range(func(key, value interface{}) bool {
+                mapping := value.(ProxyMapping)
+                // Remove mappings that haven't been used in 30 minutes
+                if now.Sub(mapping.LastUsed) > 30*time.Minute {
+                    s.proxyMappings.Delete(key)
+                    logger.Debug("Removed idle proxy mapping: %s", key)
+                }
+                return true
+            })
+        case <-s.ctx.Done():
+            return
+        }
     }
 }
 
@@ -738,8 +825,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
         keyStr := key.(string)
         session := value.(*WireGuardSession)
 
-        // Check if the session's destination address contains the WG IP
-        if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
+        // Check if the session's destination address contains the WG IP (thread-safe)
+        destAddr := session.GetDestAddr()
+        if destAddr != nil && destAddr.IP.String() == ip {
             keysToDelete = append(keysToDelete, keyStr)
             logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
         }
@@ -929,14 +1017,12 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
 
     // Check if we already have this session
     if _, exists := s.wgSessions.Load(sessionKey); !exists {
-        session := &WireGuardSession{
+        s.wgSessions.Store(sessionKey, &WireGuardSession{
             ReceiverIndex: pattern.DestIndex,
             SenderIndex:   pattern.ClientIndex,
             DestAddr:      pattern.ToDestination,
             LastSeen:      time.Now(),
-        }
-
-        s.wgSessions.Store(sessionKey, session)
+        })
         logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
             sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
     }
@@ -946,23 +1032,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
 
 // cleanupIdleCommunicationPatterns periodically removes idle communication patterns
 func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
     ticker := time.NewTicker(10 * time.Minute)
-    for range ticker.C {
-        now := time.Now()
-        s.commPatterns.Range(func(key, value interface{}) bool {
-            pattern := value.(*CommunicationPattern)
+    defer ticker.Stop()
+    for {
+        select {
+        case <-ticker.C:
+            now := time.Now()
+            s.commPatterns.Range(func(key, value interface{}) bool {
+                pattern := value.(*CommunicationPattern)
 
-            // Get the most recent activity
-            lastActivity := pattern.LastFromClient
-            if pattern.LastFromDest.After(lastActivity) {
-                lastActivity = pattern.LastFromDest
-            }
+                // Get the most recent activity
+                lastActivity := pattern.LastFromClient
+                if pattern.LastFromDest.After(lastActivity) {
+                    lastActivity = pattern.LastFromDest
+                }
 
-            // Remove patterns that haven't had activity in 20 minutes
-            if now.Sub(lastActivity) > 20*time.Minute {
-                s.commPatterns.Delete(key)
-                logger.Debug("Removed idle communication pattern: %s", key)
-            }
-            return true
-        })
+                // Remove patterns that haven't had activity in 20 minutes
+                if now.Sub(lastActivity) > 20*time.Minute {
+                    s.commPatterns.Delete(key)
+                    logger.Debug("Removed idle communication pattern: %s", key)
+                }
+                return true
+            })
+        case <-s.ctx.Done():
+            return
+        }
     }
 }
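
Verifying the published images locally — a minimal sketch, assuming cosign v2 is installed, the maintainers' public key has been saved as cosign.pub, and ghcr.io/OWNER/REPO:TAG is a placeholder for an image published by the workflows above. These are the same two checks the workflows run after signing:

  # Key-based: verifies the signature made with COSIGN_PRIVATE_KEY
  cosign verify --key cosign.pub ghcr.io/OWNER/REPO:TAG -o text

  # Keyless: verifies the OIDC certificate identity pinned to this repository's GitHub Actions
  cosign verify \
    --certificate-oidc-issuer "https://token.actions.githubusercontent.com" \
    --certificate-identity-regexp "^https://github.com/OWNER/REPO/.+" \
    ghcr.io/OWNER/REPO:TAG -o text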