78 Commits

Author SHA1 Message Date
Owen
16aef10cca Merge branch 'main' of github.com:fosrl/gerbil 2025-10-16 13:41:42 -07:00
Owen
19031ebdfd Move to gen the port in the right place 2025-10-16 13:40:01 -07:00
Owen
0eebbc51d5 Deprecate --listen 2025-10-16 10:38:47 -07:00
Owen
d321a8ba7e Dont require proxy protocol from known hosts 2025-10-14 21:05:30 -07:00
Owen Schwartz
3ea86222ca Merge pull request #29 from fosrl/dependabot/go_modules/prod-minor-updates-ce64870c5e
Bump golang.org/x/crypto from 0.42.0 to 0.43.0 in the prod-minor-updates group
2025-10-11 09:41:08 -07:00
dependabot[bot]
c3ebe930d9 Bump golang.org/x/crypto in the prod-minor-updates group
Bumps the prod-minor-updates group with 1 update: [golang.org/x/crypto](https://github.com/golang/crypto).


Updates `golang.org/x/crypto` from 0.42.0 to 0.43.0
- [Commits](https://github.com/golang/crypto/compare/v0.42.0...v0.43.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.43.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-08 21:24:05 +00:00
Owen Schwartz
f2b96f2a38 Merge pull request #28 from SigmaSquadron/push-unypywyqkkrl
Change default port to 3004 to avoid a conflict with Pangolin's integration API.
2025-10-05 17:41:21 -07:00
Owen
9038239bbe Accept proxy protocol from other nodes 2025-09-29 21:56:15 -07:00
miloschwartz
3e64eb9c4f add templates 2025-09-29 16:41:29 -07:00
Owen
92992b8c14 Merge branch 'main' into dev 2025-09-28 16:28:07 -07:00
Owen
4ee9d77532 Rebuild sessions 2025-09-28 15:31:34 -07:00
Owen Schwartz
bd7a5bd4b0 Merge pull request #26 from fosrl/dependabot/github_actions/actions/setup-go-6
Bump actions/setup-go from 5 to 6
2025-09-15 14:43:53 -07:00
Owen Schwartz
1cd49f8ee3 Merge pull request #27 from fosrl/dependabot/go_modules/prod-minor-updates-237ba4726d
Bump golang.org/x/crypto from 0.41.0 to 0.42.0 in the prod-minor-updates group
2025-09-15 14:43:41 -07:00
Fernando Rodrigues
7a919d867b Change default port to 3004 to avoid a conflict with Pangolin's integration API.
Signed-off-by: Fernando Rodrigues <alpha@sigmasquadron.net>
2025-09-14 23:19:21 +10:00
dependabot[bot]
ce50c627a7 Bump golang.org/x/crypto in the prod-minor-updates group
Bumps the prod-minor-updates group with 1 update: [golang.org/x/crypto](https://github.com/golang/crypto).


Updates `golang.org/x/crypto` from 0.41.0 to 0.42.0
- [Commits](https://github.com/golang/crypto/compare/v0.41.0...v0.42.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.42.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-09-08 21:53:42 +00:00
dependabot[bot]
691d5f0271 Bump actions/setup-go from 5 to 6
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 5 to 6.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-09-08 21:52:09 +00:00
Owen Schwartz
56151089e3 Merge pull request #24 from Lokowitz/dockerfile-update
changed docker image from ubuntu to alpine
2025-08-31 09:56:28 -07:00
Lokowitz
af7c1caf98 changed docker image from ubuntu to alpine 2025-08-31 11:19:16 +00:00
Owen Schwartz
dd208ab67c Merge pull request #22 from fosrl/dependabot/go_modules/prod-minor-updates-d1569a22cb
Bump golang.org/x/crypto from 0.36.0 to 0.41.0 in the prod-minor-updates group
2025-08-30 15:14:02 -07:00
Owen Schwartz
8189d41a45 Merge pull request #21 from fosrl/dependabot/github_actions/actions/checkout-5
Bump actions/checkout from 4 to 5
2025-08-30 15:13:54 -07:00
dependabot[bot]
ea3477c8ce Bump golang.org/x/crypto in the prod-minor-updates group
Bumps the prod-minor-updates group with 1 update: [golang.org/x/crypto](https://github.com/golang/crypto).


Updates `golang.org/x/crypto` from 0.36.0 to 0.41.0
- [Commits](https://github.com/golang/crypto/compare/v0.36.0...v0.41.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.41.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-27 21:35:30 +00:00
Owen Schwartz
a8a0f92c9b Merge pull request #23 from fosrl/dev
Add proxy protocol
2025-08-27 14:22:08 -07:00
Owen
7040a9436e Add proxy protocol 2025-08-26 22:26:01 -07:00
Owen
04361242fe Update readme 2025-08-23 12:29:26 -07:00
Owen
554b1d55dc Merge branch 'main' into dev 2025-08-23 12:24:21 -07:00
dependabot[bot]
b03f8911a5 Bump actions/checkout from 4 to 5
Bumps [actions/checkout](https://github.com/actions/checkout) from 4 to 5.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-23 04:37:59 +00:00
Owen Schwartz
47589570c9 Merge pull request #20 from Lokowitz/sync-go-versions
update versions and sync go version in all files
2025-08-22 21:37:20 -07:00
Owen
9f5b8dea26 Merge branch 'hybrid' into dev 2025-08-22 11:56:58 -07:00
Owen
f6a1e1e27c Merge branch 'main' into dev 2025-08-22 11:56:54 -07:00
Owen
f983a8f141 Local proxy port 443 2025-08-22 11:56:29 -07:00
Owen
efce3cb0b2 Sni has no errors now 2025-08-17 10:43:37 -07:00
Marvin
6eeebd81b2 sync go versions 2025-08-17 11:48:39 +00:00
Owen
c970fd5a18 Update to work with multipe endpoints 2025-08-16 22:59:45 -07:00
Owen
09bd02456d Move to post 2025-08-16 22:53:49 -07:00
Owen
c24537af36 Fix url 2025-08-16 22:36:03 -07:00
Owen
9de3f14799 Update default config 2025-08-16 22:35:51 -07:00
Owen Schwartz
0908f75f5f Merge pull request #19 from fosrl/dependabot/docker/minor-updates-80a311fbba
Bump golang from 1.24.3-alpine to 1.25.0-alpine in the minor-updates group
2025-08-15 09:40:54 -07:00
Owen
10958f8c55 Use propper logger 2025-08-14 22:25:38 -07:00
dependabot[bot]
b1840fd5c3 Bump golang in the minor-updates group
Bumps the minor-updates group with 1 update: golang.


Updates `golang` from 1.24.3-alpine to 1.25.0-alpine

---
updated-dependencies:
- dependency-name: golang
  dependency-version: 1.25.0-alpine
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-14 21:55:42 +00:00
Owen
1df5eb19ff Integrate sni proxy 2025-08-13 15:41:58 -07:00
Owen
f71f183886 Add basic proxy 2025-08-12 18:02:34 -07:00
Owen
8922ca9736 Fix some clients stuff for multi pop 2025-08-12 17:26:14 -07:00
Owen
38483f4a26 Allow for updating destinations 2025-07-28 22:41:11 -07:00
Owen
78c768e497 Add mutex 2025-07-28 21:35:57 -07:00
Owen
fc7df8a530 Update readme 2025-07-28 12:43:19 -07:00
Owen
50b42059ac Add new logic to handle changes in newt connection 2025-07-24 20:46:51 -07:00
Owen
825f7fcf60 Add notify 2025-06-21 12:06:58 -04:00
Owen
8c8ec72b40 Merge branch 'dev' into hp-multi-client 2025-06-10 09:39:39 -04:00
Owen
c61b7fc4fb Merge branch 'main' into dev 2025-06-10 09:39:29 -04:00
Owen Schwartz
96e3376147 Merge pull request #12 from Lokowitz/fix-dependabot
fix - dependabot
2025-06-10 09:37:52 -04:00
Owen Schwartz
e47a7c80d1 Merge pull request #11 from Lokowitz/add-test-action
Add test action
2025-06-10 09:37:34 -04:00
Marvin
f1e373f2d8 Update test.yml 2025-06-10 14:01:17 +02:00
Marvin
ef4d0db475 Update dependabot.yml 2025-06-10 13:40:41 +02:00
Marvin
b6b97f5ed3 Create test.yml 2025-06-10 13:36:07 +02:00
Marvin
dff267a42e Update Makefile 2025-06-10 13:34:49 +02:00
Owen Schwartz
bb98db7f5e Merge pull request #10 from Lokowitz/main
Update deps and add dependabot.yml
2025-06-02 09:04:41 -04:00
dependabot[bot]
f1016200b3 Bump golang.org/x/net from 0.21.0 to 0.38.0 in the go_modules group (#5)
Bumps the go_modules group with 1 update: [golang.org/x/net](https://github.com/golang/net).


Updates `golang.org/x/net` from 0.21.0 to 0.38.0
- [Commits](https://github.com/golang/net/compare/v0.21.0...v0.38.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-version: 0.38.0
  dependency-type: indirect
  dependency-group: go_modules
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-01 16:46:30 +02:00
dependabot[bot]
f1ab8094cf Bump the go_modules group with 2 updates (#4)
Bumps the go_modules group with 2 updates: [golang.org/x/crypto](https://github.com/golang/crypto) and [golang.org/x/net](https://github.com/golang/net).


Updates `golang.org/x/crypto` from 0.8.0 to 0.35.0
- [Commits](https://github.com/golang/crypto/compare/v0.8.0...v0.35.0)

Updates `golang.org/x/net` from 0.9.0 to 0.21.0
- [Commits](https://github.com/golang/net/compare/v0.9.0...v0.21.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.35.0
  dependency-type: indirect
  dependency-group: go_modules
- dependency-name: golang.org/x/net
  dependency-version: 0.21.0
  dependency-type: indirect
  dependency-group: go_modules
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-01 16:39:52 +02:00
dependabot[bot]
ad2bc0d397 Bump ubuntu from 22.04 to 24.04 in the major-updates group (#2)
Bumps the major-updates group with 1 update: ubuntu.


Updates `ubuntu` from 22.04 to 24.04

---
updated-dependencies:
- dependency-name: ubuntu
  dependency-version: '24.04'
  dependency-type: direct:production
  dependency-group: major-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-01 16:15:25 +02:00
dependabot[bot]
a78d141ca3 Bump github.com/vishvananda/netlink in the prod-patch-updates group (#3)
Bumps the prod-patch-updates group with 1 update: [github.com/vishvananda/netlink](https://github.com/vishvananda/netlink).


Updates `github.com/vishvananda/netlink` from 1.3.0 to 1.3.1
- [Release notes](https://github.com/vishvananda/netlink/releases)
- [Commits](https://github.com/vishvananda/netlink/compare/v1.3.0...v1.3.1)

---
updated-dependencies:
- dependency-name: github.com/vishvananda/netlink
  dependency-version: 1.3.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: prod-patch-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-01 16:11:22 +02:00
Marvin
10b1ad2a5a Merge pull request #1 from Lokowitz/dependabot/docker/minor-updates-c9322ea29a
Bump golang from 1.23.1-alpine to 1.24.3-alpine in the minor-updates group
2025-06-01 16:10:54 +02:00
dependabot[bot]
8a9f29043a Bump golang in the minor-updates group
Bumps the minor-updates group with 1 update: golang.


Updates `golang` from 1.23.1-alpine to 1.24.3-alpine

---
updated-dependencies:
- dependency-name: golang
  dependency-version: 1.24.3-alpine
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-06-01 14:08:27 +00:00
Marvin
05c9d851f4 Create dependabot.yml 2025-06-01 16:07:01 +02:00
Owen
c9a6b85e1d Attempt to add sender and receiver ids to relaying 2025-04-07 21:45:57 -04:00
Owen
a16021cd86 Put http server into routine 2025-03-25 20:49:28 -04:00
Owen
9506b545f4 Handle encrypted messages 2025-03-15 21:46:40 -04:00
Owen
17b87e6707 Merge branch 'dev' into holepunch 2025-03-04 00:02:04 -05:00
Owen
cba4dc646d Try to setup qemu 2025-03-03 23:58:55 -05:00
Owen
88be6d133d Update upload action to v4 2025-03-03 22:38:49 -05:00
Owen Schwartz
34a80c6411 Merge pull request #8 from fosrl/dev
Add CICD
2025-03-03 22:37:29 -05:00
Owen
6565fdbe62 Fix merge issue 2025-03-03 22:36:58 -05:00
Owen
993f5f86c5 Small adjustments 2025-02-23 20:17:16 -05:00
Owen
093a4c21f2 Big speed increase 2025-02-23 18:43:37 -05:00
Owen
f7c0bb9135 Basic relay working! 2025-02-23 16:49:49 -05:00
Owen
a145b77f79 Remove logging 2025-02-22 13:09:04 -05:00
Owen
7b3f7d2b12 Add holepunch udp server 2025-02-21 22:28:16 -05:00
Milo Schwartz
9c5ddcdfb8 Merge branch 'dev' of https://github.com/fosrl/gerbil into dev 2025-01-29 22:26:02 -05:00
Milo Schwartz
32176c74a0 add cicd 2025-01-29 22:25:33 -05:00
17 changed files with 2653 additions and 65 deletions

View File

@@ -6,4 +6,5 @@ README.md
Makefile
public/
LICENSE
CONTRIBUTING.md
CONTRIBUTING.md
.git

View File

@@ -0,0 +1,47 @@
body:
- type: textarea
attributes:
label: Summary
description: A clear and concise summary of the requested feature.
validations:
required: true
- type: textarea
attributes:
label: Motivation
description: |
Why is this feature important?
Explain the problem this feature would solve or what use case it would enable.
validations:
required: true
- type: textarea
attributes:
label: Proposed Solution
description: |
How would you like to see this feature implemented?
Provide as much detail as possible about the desired behavior, configuration, or changes.
validations:
required: true
- type: textarea
attributes:
label: Alternatives Considered
description: Describe any alternative solutions or workarounds you've thought about.
validations:
required: false
- type: textarea
attributes:
label: Additional Context
description: Add any other context, mockups, or screenshots about the feature request here.
validations:
required: false
- type: markdown
attributes:
value: |
Before submitting, please:
- Check if there is an existing issue for this feature.
- Clearly explain the benefit and use case.
- Be as specific as possible to help contributors evaluate and implement.

51
.github/ISSUE_TEMPLATE/1.bug_report.yml vendored Normal file
View File

@@ -0,0 +1,51 @@
name: Bug Report
description: Create a bug report
labels: []
body:
- type: textarea
attributes:
label: Describe the Bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
attributes:
label: Environment
description: Please fill out the relevant details below for your environment.
value: |
- OS Type & Version: (e.g., Ubuntu 22.04)
- Pangolin Version:
- Gerbil Version:
- Traefik Version:
- Newt Version:
- Olm Version: (if applicable)
validations:
required: true
- type: textarea
attributes:
label: To Reproduce
description: |
Steps to reproduce the behavior, please provide a clear description of how to reproduce the issue, based on the linked minimal reproduction. Screenshots can be provided in the issue body below.
If using code blocks, make sure syntax highlighting is correct and double-check that the rendered preview is not broken.
validations:
required: true
- type: textarea
attributes:
label: Expected Behavior
description: A clear and concise description of what you expected to happen.
validations:
required: true
- type: markdown
attributes:
value: |
Before posting the issue go through the steps you've written down to make sure the steps provided are detailed and clear.
- type: markdown
attributes:
value: |
Contributors should be able to follow the steps provided in order to reproduce the bug.

8
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,8 @@
blank_issues_enabled: false
contact_links:
- name: Need help or have questions?
url: https://github.com/orgs/fosrl/discussions
about: Ask questions, get help, and discuss with other community members
- name: Request a Feature
url: https://github.com/orgs/fosrl/discussions/new?category=feature-requests
about: Feature requests should be opened as discussions so others can upvote and comment

40
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
version: 2
updates:
- package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "daily"
groups:
dev-patch-updates:
dependency-type: "development"
update-types:
- "patch"
dev-minor-updates:
dependency-type: "development"
update-types:
- "minor"
prod-patch-updates:
dependency-type: "production"
update-types:
- "patch"
prod-minor-updates:
dependency-type: "production"
update-types:
- "minor"
- package-ecosystem: "docker"
directory: "/"
schedule:
interval: "daily"
groups:
patch-updates:
update-types:
- "patch"
minor-updates:
update-types:
- "minor"
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"

52
.github/workflows/cicd.yml vendored Normal file
View File

@@ -0,0 +1,52 @@
name: CI/CD Pipeline
on:
push:
tags:
- "*"
jobs:
release:
name: Build and Release
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
- name: Extract tag name
id: get-tag
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
- name: Install Go
uses: actions/setup-go@v6
with:
go-version: 1.25
- name: Build and push Docker images
run: |
TAG=${{ env.TAG }}
make docker-build-release tag=$TAG
- name: Build binaries
run: |
make go-build-release
- name: Upload artifacts from /bin
uses: actions/upload-artifact@v4
with:
name: binaries
path: bin/

28
.github/workflows/test.yml vendored Normal file
View File

@@ -0,0 +1,28 @@
name: Run Tests
on:
pull_request:
branches:
- main
- dev
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: '1.25'
- name: Build go
run: go build
- name: Build Docker image
run: make build
- name: Build binaries
run: make go-build-release

1
.go-version Normal file
View File

@@ -0,0 +1 @@
1.25

View File

@@ -1,4 +1,4 @@
FROM golang:1.23.1-alpine AS builder
FROM golang:1.25-alpine AS builder
# Set the working directory inside the container
WORKDIR /app
@@ -16,18 +16,13 @@ COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o /gerbil
# Start a new stage from scratch
FROM ubuntu:22.04 AS runner
FROM alpine:3.22 AS runner
RUN apt-get update && apt-get install -y iptables iproute2 && rm -rf /var/lib/apt/lists/*
RUN apk add --no-cache iptables iproute2
# Copy the pre-built binary file from the previous stage and the entrypoint script
COPY --from=builder /gerbil /usr/local/bin/
COPY entrypoint.sh /
RUN chmod +x /entrypoint.sh
# Copy the entrypoint script
ENTRYPOINT ["/entrypoint.sh"]
# Command to run the executable
CMD ["gerbil"]

View File

@@ -1,6 +1,14 @@
all: build push
docker-build-release:
@if [ -z "$(tag)" ]; then \
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
exit 1; \
fi
docker buildx build --platform linux/arm64,linux/amd64 -t fosrl/gerbil:latest -f Dockerfile --push .
docker buildx build --platform linux/arm64,linux/amd64 -t fosrl/gerbil:$(tag) -f Dockerfile --push .
build:
docker build -t fosrl/gerbil:latest .
@@ -13,9 +21,9 @@ test:
local:
CGO_ENABLED=0 GOOS=linux go build -o gerbil
release:
go-build-release:
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/gerbil_linux_arm64
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/gerbil_linux_amd64
clean:
rm gerbil
rm gerbil

View File

@@ -4,17 +4,10 @@ Gerbil is a simple [WireGuard](https://www.wireguard.com/) interface management
### Installation and Documentation
Gerbil can be used standalone with your own API, a static JSON file, or with Pangolin and Newt as part of the larger system. See documentation below:
Gerbil works with Pangolin, Newt, and Olm as part of the larger system. See documentation below:
- [Installation Instructions](https://docs.fossorial.io)
- [Full Documentation](https://docs.fossorial.io)
## Preview
<img src="public/screenshots/preview.png" alt="Preview"/>
_Sample output of a Gerbil container connected to Pangolin and terminating various peers._
## Key Functions
### Setup WireGuard
@@ -27,7 +20,25 @@ Gerbil will create the peers defined in the config on the WireGuard interface. T
### Report Bandwidth
Bytes transmitted in and out of each peer are collected every 10 seconds, and incremental usage is reported via the "reportBandwidthTo" endpoint. This can be used to track data usage of each peer on the remote server.
Bytes transmitted in and out of each peer are collected every 10 seconds, and incremental usage is reported via the api endpoint. This can be used to track data usage of each peer on the remote server.
### Handle client relaying
Gerbil listens on port 21820 for incoming UDP hole punch packets to orchestrate NAT hole punching between olm and newt clients. Additionally, it handles relaying data through the gerbil server down to the newt. This is accomplished by scanning each packet for headers and handling them appropriately.
### SNI Proxy
Gerbil includes an SNI (Server Name Indication) proxy that enables intelligent routing of HTTPS traffic between Pangolin nodes. When a TLS connection comes in, the proxy extracts the hostname from the SNI extension and queries Pangolin to determine the correct routing destination. This allows seamless routing of web traffic through the WireGuard mesh network:
- If the hostname is configured for local handling (via local overrides or local SNIs), traffic is routed to the local proxy
- Otherwise, the proxy queries Pangolin's routing API to determine which node should handle the traffic
- Supports caching of routing decisions to improve performance
- Handles connection pooling and graceful shutdown
- Optional PROXY protocol v1 support to preserve original client IP addresses when forwarding to downstream proxies (HAProxy, Nginx, etc.)
The PROXY protocol allows downstream proxies to know the real client IP address instead of seeing the SNI proxy's IP. When enabled with `--proxy-protocol`, the SNI proxy will prepend a PROXY protocol header to each connection containing the original client's IP and port information.
In single node (self hosted) Pangolin deployments this can be bypassed by using port 443:443 to route to Traefik instead of the SNI proxy at 8443.
## CLI Args
@@ -38,19 +49,44 @@ Bytes transmitted in and out of each peer are collected every 10 seconds, and in
Note: You must use either `config` or `remoteConfig` to configure WireGuard.
- `reportBandwidthTo` (optional): Remote HTTP endpoint to send peer bandwidth data
- `reportBandwidthTo` (optional): **DEPRECATED** - Use `remoteConfig` instead. Remote HTTP endpoint to send peer bandwidth data
- `interface` (optional): Name of the WireGuard interface created by Gerbil. Default: `wg0`
- `listen` (optional): Port to listen on for HTTP server. Default: `3003`
- `log-level` (optional): The log level to use. Default: INFO
- `listen` (optional): Port to listen on for HTTP server. Default: `:3004`
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: `INFO`
- `mtu` (optional): MTU of the WireGuard interface. Default: `1280`
- `notify` (optional): URL to notify on peer changes
- `sni-port` (optional): Port for the SNI proxy to listen on. Default: `8443`
- `local-proxy` (optional): Address for local proxy when routing local traffic. Default: `localhost`
- `local-proxy-port` (optional): Port for local proxy when routing local traffic. Default: `443`
- `local-overrides` (optional): Comma-separated list of domain names that should always be routed to the local proxy
- `proxy-protocol` (optional): Enable PROXY protocol v1 for preserving client IP addresses when forwarding to downstream proxies. Default: `false`
## Environment Variables
All CLI arguments can also be provided via environment variables:
- `INTERFACE`: Name of the WireGuard interface
- `CONFIG`: Path to local configuration file
- `REMOTE_CONFIG`: URL of the remote config server
- `LISTEN`: Address to listen on for HTTP server
- `GENERATE_AND_SAVE_KEY_TO`: Path to save generated private key
- `REACHABLE_AT`: Endpoint of the HTTP server to tell remote config about
- `LOG_LEVEL`: Log level (DEBUG, INFO, WARN, ERROR, FATAL)
- `MTU`: MTU of the WireGuard interface
- `NOTIFY_URL`: URL to notify on peer changes
- `SNI_PORT`: Port for the SNI proxy to listen on
- `LOCAL_PROXY`: Address for local proxy when routing local traffic
- `LOCAL_PROXY_PORT`: Port for local proxy when routing local traffic
- `LOCAL_OVERRIDES`: Comma-separated list of domain names that should always be routed to the local proxy
- `PROXY_PROTOCOL`: Enable PROXY protocol v1 for preserving client IP addresses (true/false)
Example:
```bash
./gerbil \
--reachableAt=http://gerbil:3003 \
--reachableAt=http://gerbil:3004 \
--generateAndSaveKeyTo=/var/config/key \
--remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config \
--reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth
--remoteConfig=http://pangolin:3001/api/v1/
```
```yaml
@@ -62,8 +98,7 @@ services:
command:
- --reachableAt=http://gerbil:3003
- --generateAndSaveKeyTo=/var/config/key
- --remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config
- --reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth
- --remoteConfig=http://pangolin:3001/api/v1/
volumes:
- ./config/:/var/config
cap_add:
@@ -71,6 +106,8 @@ services:
- SYS_MODULE
ports:
- 51820:51820/udp
- 21820:21820/udp
- 443:8443/tcp # SNI proxy port
```
## Build

14
go.mod
View File

@@ -1,10 +1,11 @@
module github.com/fosrl/gerbil
go 1.23.1
go 1.25
toolchain go1.23.2
require (
github.com/vishvananda/netlink v1.3.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.43.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
)
@@ -14,10 +15,9 @@ require (
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/crypto v0.8.0 // indirect
golang.org/x/net v0.9.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/net v0.45.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.10.0 // indirect
golang.org/x/sys v0.37.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect
)

25
go.sum
View File

@@ -8,21 +8,24 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=

426
main.go
View File

@@ -10,12 +10,16 @@ import (
"net/http"
"os"
"os/exec"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/fosrl/gerbil/logger"
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -27,6 +31,10 @@ var (
mtuInt int
lastReadings = make(map[string]PeerReading)
mu sync.Mutex
wgMu sync.Mutex // Protects WireGuard operations
notifyURL string
proxyRelay *relay.UDPProxyServer
proxySNI *proxy.SNIProxy
)
type WgConfig struct {
@@ -57,6 +65,31 @@ var (
wgClient *wgctrl.Client
)
// Add this new type at the top with other type definitions
type ClientEndpoint struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
IP string `json:"ip"`
Port int `json:"port"`
Timestamp int64 `json:"timestamp"`
}
type HolePunchMessage struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
}
type ProxyMappingUpdate struct {
OldDestination relay.PeerDestination `json:"oldDestination"`
NewDestination relay.PeerDestination `json:"newDestination"`
}
type UpdateDestinationsRequest struct {
SourceIP string `json:"sourceIp"`
SourcePort int `json:"sourcePort"`
Destinations []relay.PeerDestination `json:"destinations"`
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
@@ -80,22 +113,34 @@ func main() {
wgconfig WgConfig
configFile string
remoteConfigURL string
reportBandwidthTo string
generateAndSaveKeyTo string
reachableAt string
logLevel string
mtu string
sniProxyPort int
localProxyAddr string
localProxyPort int
localOverridesStr string
trustedUpstreamsStr string
proxyProtocol bool
)
interfaceName = os.Getenv("INTERFACE")
configFile = os.Getenv("CONFIG")
remoteConfigURL = os.Getenv("REMOTE_CONFIG")
listenAddr = os.Getenv("LISTEN")
reportBandwidthTo = os.Getenv("REPORT_BANDWIDTH_TO")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reachableAt = os.Getenv("REACHABLE_AT")
logLevel = os.Getenv("LOG_LEVEL")
mtu = os.Getenv("MTU")
notifyURL = os.Getenv("NOTIFY_URL")
sniProxyPortStr := os.Getenv("SNI_PORT")
localProxyAddr = os.Getenv("LOCAL_PROXY")
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS")
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
@@ -104,31 +149,90 @@ func main() {
flag.StringVar(&configFile, "config", "", "Path to local configuration file")
}
if remoteConfigURL == "" {
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration")
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL of the Pangolin server")
}
if listenAddr == "" {
flag.StringVar(&listenAddr, "listen", ":3003", "Address to listen on")
}
if reportBandwidthTo == "" {
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on")
flag.StringVar(&listenAddr, "listen", "", "DEPRECATED (overridden by reachableAt): Address to listen on")
}
// DEPRECATED AND UNSED: reportBandwidthTo
// allow reportBandwidthTo to be passed but dont do anything with it just thow it away
reportBandwidthTo := ""
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "DEPRECATED: Use remoteConfig instead")
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
}
if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if mtu == "" {
flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface")
}
if notifyURL == "" {
flag.StringVar(&notifyURL, "notify", "", "URL to notify on peer changes")
}
if sniProxyPortStr != "" {
if port, err := strconv.Atoi(sniProxyPortStr); err == nil {
sniProxyPort = port
}
}
if sniProxyPortStr == "" {
flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on")
}
if localProxyAddr == "" {
flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address")
}
if localProxyPortStr != "" {
if port, err := strconv.Atoi(localProxyPortStr); err == nil {
localProxyPort = port
}
}
if localProxyPortStr == "" {
flag.IntVar(&localProxyPort, "local-proxy-port", 443, "Local proxy port")
}
if localOverridesStr != "" {
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
}
if trustedUpstreamsStr == "" {
flag.StringVar(&trustedUpstreamsStr, "trusted-upstreams", "", "Comma-separated list of trusted upstream proxy domain names/IPs that can send PROXY protocol")
}
if proxyProtocolStr != "" {
proxyProtocol = strings.ToLower(proxyProtocolStr) == "true"
}
if proxyProtocolStr == "" {
flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP")
}
flag.Parse()
logger.Init()
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
if reachableAt != "" && listenAddr == "" {
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
parts := strings.Split(reachableAt, ":")
if len(parts) == 3 {
port := parts[2]
if strings.Contains(port, "/") {
port = strings.Split(port, "/")[0]
}
listenAddr = ":" + port
}
}
} else if listenAddr == "" {
listenAddr = ":3003"
}
mtuInt, err = strconv.Atoi(mtu)
if err != nil {
logger.Fatal("Failed to parse MTU: %v", err)
@@ -144,6 +248,10 @@ func main() {
logger.Fatal("You must provide either a config file or a remote config URL, not both")
}
// clean up the reomte config URL for backwards compatibility
remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/gerbil/get-config")
remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/")
var key wgtypes.Key
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
if generateAndSaveKeyTo != "" {
@@ -191,8 +299,8 @@ func main() {
} else {
// loop until we get the config
for wgconfig.PrivateKey == "" {
logger.Info("Fetching remote config from %s", remoteConfigURL)
wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt)
logger.Info("Fetching remote config from %s", remoteConfigURL+"/gerbil/get-config")
wgconfig, err = loadRemoteConfig(remoteConfigURL+"/gerbil/get-config", key, reachableAt)
if err != nil {
logger.Error("Failed to load configuration: %v", err)
time.Sleep(5 * time.Second)
@@ -216,13 +324,65 @@ func main() {
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)
if reportBandwidthTo != "" {
go periodicBandwidthCheck(reportBandwidthTo)
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
// Start the UDP proxy server
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
err = proxyRelay.Start()
if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err)
}
defer proxyRelay.Stop()
// TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING
// SO YOU DON'T NEED TO SET THIS SEPARATELY
// Parse local overrides
var localOverrides []string
if localOverridesStr != "" {
localOverrides = strings.Split(localOverridesStr, ",")
for i, domain := range localOverrides {
localOverrides[i] = strings.TrimSpace(domain)
}
logger.Info("Local overrides configured: %v", localOverrides)
}
var trustedUpstreams []string
if trustedUpstreamsStr != "" {
trustedUpstreams = strings.Split(trustedUpstreamsStr, ",")
for i, upstream := range trustedUpstreams {
trustedUpstreams[i] = strings.TrimSpace(upstream)
}
logger.Info("Trusted upstreams configured: %v", trustedUpstreams)
}
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol, trustedUpstreams)
if err != nil {
logger.Fatal("Failed to create proxy: %v", err)
}
if err := proxySNI.Start(); err != nil {
logger.Fatal("Failed to start proxy: %v", err)
}
// Set up HTTP server
http.HandleFunc("/peer", handlePeer)
logger.Info("Starting server on %s", listenAddr)
logger.Fatal("Failed to start server: %v", http.ListenAndServe(listenAddr, nil))
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
http.HandleFunc("/update-destinations", handleUpdateDestinations)
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
logger.Info("Starting HTTP server on %s", listenAddr)
// Run HTTP server in a goroutine
go func() {
if err := http.ListenAndServe(listenAddr, nil); err != nil {
logger.Error("HTTP server failed: %v", err)
}
}()
// Keep the main goroutine running
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
logger.Info("Shutting down servers...")
}
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
@@ -377,6 +537,9 @@ func assignIPAddress(ipAddress string) error {
}
func ensureWireguardPeers(peers []Peer) error {
wgMu.Lock()
defer wgMu.Unlock()
// get the current peers
device, err := wgClient.Device(interfaceName)
if err != nil {
@@ -399,8 +562,8 @@ func ensureWireguardPeers(peers []Peer) error {
}
}
if !found {
err := removePeer(peer)
if err != nil {
// Note: We need to call the internal removal logic without re-acquiring the lock
if err := removePeerInternal(peer); err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
}
@@ -416,8 +579,8 @@ func ensureWireguardPeers(peers []Peer) error {
}
}
if !found {
err := addPeer(configPeer)
if err != nil {
// Note: We need to call the internal addition logic without re-acquiring the lock
if err := addPeerInternal(configPeer); err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
}
@@ -477,7 +640,7 @@ func ensureMSSClamping() error {
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
errors = append(errors, fmt.Errorf(errMsg))
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -494,7 +657,7 @@ func ensureMSSClamping() error {
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
errors = append(errors, fmt.Errorf(errMsg))
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -538,11 +701,20 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) {
return
}
// Notify if notifyURL is set
go notifyPeerChange("add", peer.PublicKey)
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]string{"status": "Peer added successfully"})
}
func addPeer(peer Peer) error {
wgMu.Lock()
defer wgMu.Unlock()
return addPeerInternal(peer)
}
func addPeerInternal(peer Peer) error {
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
@@ -550,12 +722,15 @@ func addPeer(peer Peer) error {
// parse allowed IPs into array of net.IPNet
var allowedIPs []net.IPNet
var wgIPs []string
for _, ipStr := range peer.AllowedIPs {
_, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return fmt.Errorf("failed to parse allowed IP: %v", err)
}
allowedIPs = append(allowedIPs, *ipNet)
// Extract the IP address from the CIDR for relay cleanup
wgIPs = append(wgIPs, ipNet.IP.String())
}
peerConfig := wgtypes.PeerConfig{
@@ -571,6 +746,13 @@ func addPeer(peer Peer) error {
return fmt.Errorf("failed to add peer: %v", err)
}
// Clear relay connections for the peer's WireGuard IPs
if proxyRelay != nil {
for _, wgIP := range wgIPs {
proxyRelay.OnPeerAdded(wgIP)
}
}
logger.Info("Peer %s added successfully", peer.PublicKey)
return nil
@@ -589,16 +771,42 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
return
}
// Notify if notifyURL is set
go notifyPeerChange("remove", publicKey)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "Peer removed successfully"})
}
func removePeer(publicKey string) error {
wgMu.Lock()
defer wgMu.Unlock()
return removePeerInternal(publicKey)
}
func removePeerInternal(publicKey string) error {
pubKey, err := wgtypes.ParseKey(publicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
}
// Get current peer info before removing to clear relay connections
var wgIPs []string
if proxyRelay != nil {
device, err := wgClient.Device(interfaceName)
if err == nil {
for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey {
// Extract WireGuard IPs from this peer's allowed IPs
for _, allowedIP := range peer.AllowedIPs {
wgIPs = append(wgIPs, allowedIP.IP.String())
}
break
}
}
}
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Remove: true,
@@ -612,11 +820,163 @@ func removePeer(publicKey string) error {
return fmt.Errorf("failed to remove peer: %v", err)
}
// Clear relay connections for the peer's WireGuard IPs
if proxyRelay != nil {
for _, wgIP := range wgIPs {
proxyRelay.OnPeerRemoved(wgIP)
}
}
logger.Info("Peer %s removed successfully", publicKey)
return nil
}
func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
logger.Error("Invalid method: %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var update ProxyMappingUpdate
if err := json.NewDecoder(r.Body).Decode(&update); err != nil {
logger.Error("Failed to decode request body: %v", err)
http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest)
return
}
// Validate the update request
if update.OldDestination.DestinationIP == "" || update.NewDestination.DestinationIP == "" {
logger.Error("Both old and new destination IP addresses are required")
http.Error(w, "Both old and new destination IP addresses are required", http.StatusBadRequest)
return
}
if update.OldDestination.DestinationPort <= 0 || update.NewDestination.DestinationPort <= 0 {
logger.Error("Both old and new destination ports must be positive integers")
http.Error(w, "Both old and new destination ports must be positive integers", http.StatusBadRequest)
return
}
// Update the proxy mappings in the relay server
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
return
}
updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d",
updatedCount,
update.OldDestination.DestinationIP, update.OldDestination.DestinationPort,
update.NewDestination.DestinationIP, update.NewDestination.DestinationPort)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Proxy mappings updated successfully",
"updatedCount": updatedCount,
"oldDestination": update.OldDestination,
"newDestination": update.NewDestination,
})
}
func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
logger.Error("Invalid method: %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var request UpdateDestinationsRequest
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
logger.Error("Failed to decode request body: %v", err)
http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest)
return
}
// Validate the request
if request.SourceIP == "" {
logger.Error("Source IP address is required")
http.Error(w, "Source IP address is required", http.StatusBadRequest)
return
}
if request.SourcePort <= 0 {
logger.Error("Source port must be a positive integer")
http.Error(w, "Source port must be a positive integer", http.StatusBadRequest)
return
}
if len(request.Destinations) == 0 {
logger.Error("At least one destination is required")
http.Error(w, "At least one destination is required", http.StatusBadRequest)
return
}
// Validate each destination
for i, dest := range request.Destinations {
if dest.DestinationIP == "" {
logger.Error("Destination IP is required for destination %d", i)
http.Error(w, fmt.Sprintf("Destination IP is required for destination %d", i), http.StatusBadRequest)
return
}
if dest.DestinationPort <= 0 {
logger.Error("Destination port must be a positive integer for destination %d", i)
http.Error(w, fmt.Sprintf("Destination port must be a positive integer for destination %d", i), http.StatusBadRequest)
return
}
}
// Update the proxy mappings in the relay server
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
return
}
proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
request.SourceIP, request.SourcePort, len(request.Destinations))
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Destinations updated successfully",
"sourceIP": request.SourceIP,
"sourcePort": request.SourcePort,
"destinationCount": len(request.Destinations),
"destinations": request.Destinations,
})
}
// UpdateLocalSNIsRequest represents the JSON payload for updating local SNIs
type UpdateLocalSNIsRequest struct {
FullDomains []string `json:"fullDomains"`
}
func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
logger.Error("Invalid method: %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req UpdateLocalSNIsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
return
}
proxySNI.UpdateLocalSNIs(req.FullDomains)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Local SNIs updated successfully",
})
}
func periodicBandwidthCheck(endpoint string) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
@@ -629,7 +989,10 @@ func periodicBandwidthCheck(endpoint string) {
}
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
wgMu.Lock()
device, err := wgClient.Device(interfaceName)
wgMu.Unlock()
if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err)
}
@@ -736,3 +1099,28 @@ func reportPeerBandwidth(apiURL string) error {
return nil
}
// notifyPeerChange sends a POST request to notifyURL with the action and public key.
func notifyPeerChange(action, publicKey string) {
if notifyURL == "" {
return
}
payload := map[string]string{
"action": action,
"publicKey": publicKey,
}
data, err := json.Marshal(payload)
if err != nil {
logger.Warn("Failed to marshal notify payload: %v", err)
return
}
resp, err := http.Post(notifyURL, "application/json", bytes.NewBuffer(data))
if err != nil {
logger.Warn("Failed to notify peer change: %v", err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.Warn("Notify server returned non-OK: %s", resp.Status)
}
}

845
proxy/proxy.go Normal file
View File

@@ -0,0 +1,845 @@
package proxy
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"hash/fnv"
"io"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/fosrl/gerbil/logger"
"github.com/patrickmn/go-cache"
)
// RouteRecord represents a routing configuration
type RouteRecord struct {
Hostname string
TargetHost string
TargetPort int
}
// RouteAPIResponse represents the response from the route API
type RouteAPIResponse struct {
Endpoints []string `json:"endpoints"`
}
// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header
type ProxyProtocolInfo struct {
Protocol string // TCP4 or TCP6
SrcIP string
DestIP string
SrcPort int
DestPort int
OriginalConn net.Conn // The original connection after PROXY protocol parsing
}
// SNIProxy represents the main proxy server
type SNIProxy struct {
port int
cache *cache.Cache
listener net.Listener
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
localProxyAddr string
localProxyPort int
remoteConfigURL string
publicKey string
proxyProtocol bool // Enable PROXY protocol v1
// New fields for fast local SNI lookup
localSNIs map[string]struct{}
localSNIsLock sync.RWMutex
// Local overrides for domains that should always use local proxy
localOverrides map[string]struct{}
// Track active tunnels by SNI
activeTunnels map[string]*activeTunnel
activeTunnelsLock sync.Mutex
// Trusted upstream proxies that can send PROXY protocol
trustedUpstreams map[string]struct{}
}
type activeTunnel struct {
conns []net.Conn
}
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
type readOnlyConn struct {
reader io.Reader
}
func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection
func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) {
// Check if the connection comes from a trusted upstream
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
}
// Resolve the remote IP to hostname to check if it's trusted
// For simplicity, we'll check the IP directly in trusted upstreams
// In production, you might want to do reverse DNS lookup
if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted {
// Not from trusted upstream, return original connection
return nil, conn, nil
}
// Set read timeout for PROXY protocol parsing
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
}
// Read the first line (PROXY protocol header)
buffer := make([]byte, 512) // PROXY protocol header should be much smaller
n, err := conn.Read(buffer)
if err != nil {
// If we can't read from trusted upstream, treat as regular connection
logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err)
// Clear read timeout before returning
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
logger.Debug("Failed to clear read deadline: %v", clearErr)
}
return nil, conn, nil
}
// Find the end of the first line (CRLF)
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
if headerEnd == -1 {
// No PROXY protocol header found, treat as regular TLS connection
// Return the connection with the buffered data prepended
logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost)
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err)
}
// Create a reader that includes the buffered data + original connection
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
return nil, wrappedConn, nil
}
headerLine := string(buffer[:headerEnd])
remainingData := buffer[headerEnd+2 : n]
// Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort"
parts := strings.Fields(headerLine)
if len(parts) != 6 || parts[0] != "PROXY" {
// Check for PROXY UNKNOWN
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
// PROXY UNKNOWN - use original connection info
return nil, conn, nil
}
// Invalid PROXY protocol, but might be regular TLS - treat as such
logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine)
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err)
}
// Return the connection with all buffered data prepended
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
return nil, wrappedConn, nil
}
protocol := parts[1]
srcIP := parts[2]
destIP := parts[3]
srcPort, err := strconv.Atoi(parts[4])
if err != nil {
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
}
destPort, err := strconv.Atoi(parts[5])
if err != nil {
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
}
// Create a new reader that includes remaining data + original connection
var newReader io.Reader
if len(remainingData) > 0 {
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
} else {
newReader = conn
}
// Create a wrapper connection that reads from the combined reader
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
proxyInfo := &ProxyProtocolInfo{
Protocol: protocol,
SrcIP: srcIP,
DestIP: destIP,
SrcPort: srcPort,
DestPort: destPort,
OriginalConn: wrappedConn,
}
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err)
}
return proxyInfo, wrappedConn, nil
}
// proxyProtocolConn wraps a connection to read from a custom reader
type proxyProtocolConn struct {
net.Conn
reader io.Reader
}
func (c *proxyProtocolConn) Read(b []byte) (int, error) {
return c.reader.Read(b)
}
// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo
func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string {
targetTCP, ok := targetAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
// Use the original client information from the PROXY protocol
var targetIP string
var protocol string
// Parse source IP to determine protocol family
srcIP := net.ParseIP(proxyInfo.SrcIP)
if srcIP == nil {
return "PROXY UNKNOWN\r\n"
}
if srcIP.To4() != nil {
// Source is IPv4, use TCP4 protocol
protocol = "TCP4"
if targetTCP.IP.To4() != nil {
// Target is also IPv4, use as-is
targetIP = targetTCP.IP.String()
} else {
// Target is IPv6, but we need IPv4 for consistent protocol family
if targetTCP.IP.IsLoopback() {
targetIP = "127.0.0.1"
} else {
targetIP = "127.0.0.1" // Safe fallback
}
}
} else {
// Source is IPv6, use TCP6 protocol
protocol = "TCP6"
if targetTCP.IP.To4() != nil {
// Target is IPv4, convert to IPv6 representation
targetIP = "::ffff:" + targetTCP.IP.String()
} else {
// Target is also IPv6, use as-is
targetIP = targetTCP.IP.String()
}
}
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
protocol,
proxyInfo.SrcIP,
targetIP,
proxyInfo.SrcPort,
targetTCP.Port)
}
// buildProxyProtocolHeader creates a PROXY protocol v1 header
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
clientTCP, ok := clientAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
targetTCP, ok := targetAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
// Determine protocol family based on client IP and normalize target IP accordingly
var protocol string
var targetIP string
if clientTCP.IP.To4() != nil {
// Client is IPv4, use TCP4 protocol
protocol = "TCP4"
if targetTCP.IP.To4() != nil {
// Target is also IPv4, use as-is
targetIP = targetTCP.IP.String()
} else {
// Target is IPv6, but we need IPv4 for consistent protocol family
// Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1
if targetTCP.IP.IsLoopback() {
targetIP = "127.0.0.1"
} else {
// For non-loopback IPv6 targets, we could try to extract embedded IPv4
// or fall back to a sensible IPv4 address based on the target
targetIP = "127.0.0.1" // Safe fallback
}
}
} else {
// Client is IPv6, use TCP6 protocol
protocol = "TCP6"
if targetTCP.IP.To4() != nil {
// Target is IPv4, convert to IPv6 representation
targetIP = "::ffff:" + targetTCP.IP.String()
} else {
// Target is also IPv6, use as-is
targetIP = targetTCP.IP.String()
}
}
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
protocol,
clientTCP.IP.String(),
targetIP,
clientTCP.Port,
targetTCP.Port)
}
// NewSNIProxy creates a new SNI proxy instance
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) {
ctx, cancel := context.WithCancel(context.Background())
// Create local overrides map
overridesMap := make(map[string]struct{})
for _, domain := range localOverrides {
if domain != "" {
overridesMap[domain] = struct{}{}
}
}
// Create trusted upstreams map
trustedMap := make(map[string]struct{})
for _, upstream := range trustedUpstreams {
if upstream != "" {
// Add both the domain and potentially resolved IPs
trustedMap[upstream] = struct{}{}
// Try to resolve the domain to IPs and add them too
if ips, err := net.LookupIP(upstream); err == nil {
for _, ip := range ips {
trustedMap[ip.String()] = struct{}{}
}
}
}
}
proxy := &SNIProxy{
port: port,
cache: cache.New(3*time.Second, 10*time.Minute),
ctx: ctx,
cancel: cancel,
localProxyAddr: localProxyAddr,
localProxyPort: localProxyPort,
remoteConfigURL: remoteConfigURL,
publicKey: publicKey,
proxyProtocol: proxyProtocol,
localSNIs: make(map[string]struct{}),
localOverrides: overridesMap,
activeTunnels: make(map[string]*activeTunnel),
trustedUpstreams: trustedMap,
}
return proxy, nil
}
// Start begins listening for connections
func (p *SNIProxy) Start() error {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port))
if err != nil {
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
}
p.listener = listener
logger.Debug("SNI Proxy listening on port %d", p.port)
// Accept connections in a goroutine
go p.acceptConnections()
return nil
}
// Stop gracefully shuts down the proxy
func (p *SNIProxy) Stop() error {
log.Println("Stopping SNI Proxy...")
p.cancel()
if p.listener != nil {
p.listener.Close()
}
// Wait for all goroutines to finish with timeout
done := make(chan struct{})
go func() {
p.wg.Wait()
close(done)
}()
select {
case <-done:
log.Println("All connections closed gracefully")
case <-time.After(30 * time.Second):
log.Println("Timeout waiting for connections to close")
}
log.Println("SNI Proxy stopped")
return nil
}
// acceptConnections handles incoming connections
func (p *SNIProxy) acceptConnections() {
for {
conn, err := p.listener.Accept()
if err != nil {
select {
case <-p.ctx.Done():
return
default:
logger.Debug("Accept error: %v", err)
continue
}
}
p.wg.Add(1)
go p.handleConnection(conn)
}
}
// readClientHello reads and parses the TLS ClientHello message
func (p *SNIProxy) readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
var hello *tls.ClientHelloInfo
err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
hello = new(tls.ClientHelloInfo)
*hello = *argHello
return nil, nil
},
}).Handshake()
if hello == nil {
return nil, err
}
return hello, nil
}
// peekClientHello reads the ClientHello while preserving the data for forwarding
func (p *SNIProxy) peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
peekedBytes := new(bytes.Buffer)
hello, err := p.readClientHello(io.TeeReader(reader, peekedBytes))
if err != nil {
return nil, nil, err
}
return hello, io.MultiReader(peekedBytes, reader), nil
}
// extractSNI extracts the SNI hostname from the TLS ClientHello
func (p *SNIProxy) extractSNI(conn net.Conn) (string, io.Reader, error) {
clientHello, clientReader, err := p.peekClientHello(conn)
if err != nil {
return "", nil, fmt.Errorf("failed to peek ClientHello: %w", err)
}
if clientHello.ServerName == "" {
return "", clientReader, fmt.Errorf("no SNI hostname found in ClientHello")
}
return clientHello.ServerName, clientReader, nil
}
// handleConnection processes a single client connection
func (p *SNIProxy) handleConnection(clientConn net.Conn) {
defer p.wg.Done()
defer clientConn.Close()
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
// Check for PROXY protocol from trusted upstream
var proxyInfo *ProxyProtocolInfo
var actualClientConn net.Conn = clientConn
if len(p.trustedUpstreams) > 0 {
var err error
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
if err != nil {
logger.Debug("Failed to parse PROXY protocol: %v", err)
return
}
if proxyInfo != nil {
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
} else {
// No PROXY protocol detected, but connection is from trusted upstream
// This is fine - treat as regular connection
logger.Debug("No PROXY protocol detected from trusted upstream, treating as regular connection")
}
}
// Set read timeout for SNI extraction
if err := actualClientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
logger.Debug("Failed to set read deadline: %v", err)
return
}
// Extract SNI hostname
hostname, clientReader, err := p.extractSNI(actualClientConn)
if err != nil {
logger.Debug("SNI extraction failed: %v", err)
return
}
if hostname == "" {
log.Println("No SNI hostname found")
return
}
logger.Debug("SNI hostname detected: %s", hostname)
// Remove read timeout for normal operation
if err := actualClientConn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err)
return
}
// Get routing information - use original client address if available from PROXY protocol
var clientAddrStr string
if proxyInfo != nil {
clientAddrStr = fmt.Sprintf("%s:%d", proxyInfo.SrcIP, proxyInfo.SrcPort)
} else {
clientAddrStr = clientConn.RemoteAddr().String()
}
route, err := p.getRoute(hostname, clientAddrStr)
if err != nil {
logger.Debug("Failed to get route for %s: %v", hostname, err)
return
}
if route == nil {
logger.Debug("No route found for hostname: %s", hostname)
return
}
logger.Debug("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort)
// Connect to target server
targetConn, err := net.DialTimeout("tcp",
fmt.Sprintf("%s:%d", route.TargetHost, route.TargetPort),
10*time.Second)
if err != nil {
logger.Debug("Failed to connect to target %s:%d: %v",
route.TargetHost, route.TargetPort, err)
return
}
defer targetConn.Close()
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
// Send PROXY protocol header if enabled
if p.proxyProtocol {
var proxyHeader string
if proxyInfo != nil {
// Use original client info from PROXY protocol
proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr())
} else {
// Use direct client connection info
proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
}
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
logger.Debug("Failed to send PROXY protocol header: %v", err)
return
}
}
// Track this tunnel by SNI
p.activeTunnelsLock.Lock()
tunnel, ok := p.activeTunnels[hostname]
if !ok {
tunnel = &activeTunnel{}
p.activeTunnels[hostname] = tunnel
}
tunnel.conns = append(tunnel.conns, actualClientConn)
p.activeTunnelsLock.Unlock()
defer func() {
// Remove this conn from active tunnels
p.activeTunnelsLock.Lock()
if tunnel, ok := p.activeTunnels[hostname]; ok {
newConns := make([]net.Conn, 0, len(tunnel.conns))
for _, c := range tunnel.conns {
if c != actualClientConn {
newConns = append(newConns, c)
}
}
if len(newConns) == 0 {
delete(p.activeTunnels, hostname)
} else {
tunnel.conns = newConns
}
}
p.activeTunnelsLock.Unlock()
}()
// Start bidirectional data transfer
p.pipe(actualClientConn, targetConn, clientReader)
}
// getRoute retrieves routing information for a hostname
func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
// Check local overrides first
if _, isOverride := p.localOverrides[hostname]; isOverride {
logger.Debug("Local override matched for hostname: %s", hostname)
return &RouteRecord{
Hostname: hostname,
TargetHost: p.localProxyAddr,
TargetPort: p.localProxyPort,
}, nil
}
// Fast path: check if hostname is in localSNIs
p.localSNIsLock.RLock()
_, isLocal := p.localSNIs[hostname]
p.localSNIsLock.RUnlock()
if isLocal {
return &RouteRecord{
Hostname: hostname,
TargetHost: p.localProxyAddr,
TargetPort: p.localProxyPort,
}, nil
}
// Check cache first
if cached, found := p.cache.Get(hostname); found {
if cached == nil {
return nil, nil // Cached negative result
}
logger.Debug("Cache hit for hostname: %s", hostname)
return cached.(*RouteRecord), nil
}
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
// Query API with timeout
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
defer cancel()
// Construct API URL (without hostname in path)
apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname", p.remoteConfigURL)
// Create request body with hostname and public key
requestBody := map[string]string{
"hostname": hostname,
"publicKey": p.publicKey,
}
jsonBody, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
// Make HTTP request
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("API request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// Cache negative result for shorter time (1 minute)
p.cache.Set(hostname, nil, 1*time.Minute)
return nil, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
}
// Parse response
var apiResponse RouteAPIResponse
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode API response: %w", err)
}
endpoints := apiResponse.Endpoints
// Default target configuration
targetHost := p.localProxyAddr
targetPort := p.localProxyPort
// If no endpoints returned, use local node
if len(endpoints) == 0 {
logger.Debug("No endpoints returned for hostname: %s, using local node", hostname)
} else {
// Select endpoint using consistent hashing for stickiness
selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints)
targetHost = selectedEndpoint
targetPort = 443 // Default HTTPS port
logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr)
}
route := &RouteRecord{
Hostname: hostname,
TargetHost: targetHost,
TargetPort: targetPort,
}
// Cache the result
p.cache.Set(hostname, route, cache.DefaultExpiration)
logger.Debug("Cached route for hostname: %s", hostname)
return route, nil
}
// selectStickyEndpoint selects an endpoint using consistent hashing to ensure
// the same client always routes to the same endpoint for load balancing
func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string {
if len(endpoints) == 0 {
return p.localProxyAddr
}
if len(endpoints) == 1 {
return endpoints[0]
}
// Use FNV hash for consistent selection based on client address
hash := fnv.New32a()
hash.Write([]byte(clientAddr))
index := hash.Sum32() % uint32(len(endpoints))
return endpoints[index]
}
// pipe handles bidirectional data transfer between connections
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
var wg sync.WaitGroup
wg.Add(2)
// Copy data from client to target (using the buffered reader)
go func() {
defer wg.Done()
defer func() {
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
}
}()
// Use a large buffer for better performance
buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(targetConn, clientReader, buf)
if err != nil && err != io.EOF {
logger.Debug("Copy client->target error: %v", err)
}
}()
// Copy data from target to client
go func() {
defer wg.Done()
defer func() {
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
}
}()
// Use a large buffer for better performance
buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(clientConn, targetConn, buf)
if err != nil && err != io.EOF {
logger.Debug("Copy target->client error: %v", err)
}
}()
wg.Wait()
}
// GetCacheStats returns cache statistics
func (p *SNIProxy) GetCacheStats() (int, int) {
return p.cache.ItemCount(), len(p.cache.Items())
}
// ClearCache clears all cached entries
func (p *SNIProxy) ClearCache() {
p.cache.Flush()
log.Println("Cache cleared")
}
// UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains
func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
newSNIs := make(map[string]struct{})
for _, domain := range fullDomains {
newSNIs[domain] = struct{}{}
// Invalidate any cached route for this domain
p.cache.Delete(domain)
}
// Update localSNIs
p.localSNIsLock.Lock()
removed := make([]string, 0)
for sni := range p.localSNIs {
if _, stillLocal := newSNIs[sni]; !stillLocal {
removed = append(removed, sni)
}
}
p.localSNIs = newSNIs
p.localSNIsLock.Unlock()
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
// Terminate tunnels for removed SNIs
if len(removed) > 0 {
p.activeTunnelsLock.Lock()
for _, sni := range removed {
if tunnels, ok := p.activeTunnels[sni]; ok {
for _, conn := range tunnels.conns {
conn.Close()
}
delete(p.activeTunnels, sni)
logger.Debug("Closed tunnels for SNI target change: %s", sni)
}
}
p.activeTunnelsLock.Unlock()
}
}

119
proxy/proxy_test.go Normal file
View File

@@ -0,0 +1,119 @@
package proxy
import (
"net"
"testing"
)
func TestBuildProxyProtocolHeader(t *testing.T) {
tests := []struct {
name string
clientAddr string
targetAddr string
expected string
}{
{
name: "IPv4 client and target",
clientAddr: "192.168.1.100:12345",
targetAddr: "10.0.0.1:443",
expected: "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n",
},
{
name: "IPv6 client and target",
clientAddr: "[2001:db8::1]:12345",
targetAddr: "[2001:db8::2]:443",
expected: "PROXY TCP6 2001:db8::1 2001:db8::2 12345 443\r\n",
},
{
name: "IPv4 client with IPv6 loopback target",
clientAddr: "192.168.1.100:12345",
targetAddr: "[::1]:443",
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
},
{
name: "IPv4 client with IPv6 target",
clientAddr: "192.168.1.100:12345",
targetAddr: "[2001:db8::2]:443",
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
},
{
name: "IPv6 client with IPv4 target",
clientAddr: "[2001:db8::1]:12345",
targetAddr: "10.0.0.1:443",
expected: "PROXY TCP6 2001:db8::1 ::ffff:10.0.0.1 12345 443\r\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientTCP, err := net.ResolveTCPAddr("tcp", tt.clientAddr)
if err != nil {
t.Fatalf("Failed to resolve client address: %v", err)
}
targetTCP, err := net.ResolveTCPAddr("tcp", tt.targetAddr)
if err != nil {
t.Fatalf("Failed to resolve target address: %v", err)
}
result := buildProxyProtocolHeader(clientTCP, targetTCP)
if result != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, result)
}
})
}
}
func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
// Test with non-TCP address type
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}
targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443}
result := buildProxyProtocolHeader(clientAddr, targetAddr)
expected := "PROXY UNKNOWN\r\n"
if result != expected {
t.Errorf("Expected %q, got %q", expected, result)
}
}
func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
proxy, err := NewSNIProxy(8443, "", "", "127.0.0.1", 443, nil, true, nil)
if err != nil {
t.Fatalf("Failed to create SNI proxy: %v", err)
}
// Test IPv4 case
proxyInfo := &ProxyProtocolInfo{
Protocol: "TCP4",
SrcIP: "10.0.0.1",
DestIP: "192.168.1.100",
SrcPort: 12345,
DestPort: 443,
}
targetAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
header := proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
expected := "PROXY TCP4 10.0.0.1 127.0.0.1 12345 8080\r\n"
if header != expected {
t.Errorf("Expected header '%s', got '%s'", expected, header)
}
// Test IPv6 case
proxyInfo = &ProxyProtocolInfo{
Protocol: "TCP6",
SrcIP: "2001:db8::1",
DestIP: "2001:db8::2",
SrcPort: 12345,
DestPort: 443,
}
targetAddr, _ = net.ResolveTCPAddr("tcp6", "[::1]:8080")
header = proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
expected = "PROXY TCP6 2001:db8::1 ::1 12345 8080\r\n"
if header != expected {
t.Errorf("Expected header '%s', got '%s'", expected, header)
}
}

965
relay/relay.go Normal file
View File

@@ -0,0 +1,965 @@
package relay
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
"github.com/fosrl/gerbil/logger"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type EncryptedHolePunchMessage struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}
type HolePunchMessage struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
Token string `json:"token"`
}
type ClientEndpoint struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
Token string `json:"token"`
IP string `json:"ip"`
Port int `json:"port"`
Timestamp int64 `json:"timestamp"`
ReachableAt string `json:"reachableAt"`
PublicKey string `json:"publicKey"`
}
// Updated to support multiple destination peers
type ProxyMapping struct {
Destinations []PeerDestination `json:"destinations"`
LastUsed time.Time `json:"-"` // Not serialized, used for cleanup
}
type PeerDestination struct {
DestinationIP string `json:"destinationIP"`
DestinationPort int `json:"destinationPort"`
}
type DestinationConn struct {
conn *net.UDPConn
lastUsed time.Time
}
// Type for storing WireGuard handshake information
type WireGuardSession struct {
ReceiverIndex uint32
SenderIndex uint32
DestAddr *net.UDPAddr
LastSeen time.Time
}
// Type for tracking bidirectional communication patterns to rebuild sessions
type CommunicationPattern struct {
FromClient *net.UDPAddr // The client address
ToDestination *net.UDPAddr // The destination address
ClientIndex uint32 // The receiver index seen from client
DestIndex uint32 // The receiver index seen from destination
LastFromClient time.Time // Last packet from client to destination
LastFromDest time.Time // Last packet from destination to client
PacketCount int // Number of packets observed
}
type InitialMappings struct {
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
}
// Packet is a simple struct to hold the packet data and sender info.
type Packet struct {
data []byte
remoteAddr *net.UDPAddr
n int
}
// WireGuard message types
const (
WireGuardMessageTypeHandshakeInitiation = 1
WireGuardMessageTypeHandshakeResponse = 2
WireGuardMessageTypeCookieReply = 3
WireGuardMessageTypeTransportData = 4
)
// --- End Types ---
// bufferPool allows reusing buffers to reduce allocations.
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 1500)
},
}
// UDPProxyServer has a channel for incoming packets.
type UDPProxyServer struct {
addr string
serverURL string
conn *net.UDPConn
proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port"
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
privateKey wgtypes.Key
packetChan chan Packet
// Session tracking for WireGuard peers
// Key format: "senderIndex:receiverIndex"
wgSessions sync.Map
// Communication pattern tracking for rebuilding sessions
// Key format: "clientIP:clientPort-destIP:destPort"
commPatterns sync.Map
// ReachableAt is the URL where this server can be reached
ReachableAt string
}
// NewUDPProxyServer initializes the server with a buffered packet channel.
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
return &UDPProxyServer{
addr: addr,
serverURL: serverURL,
privateKey: privateKey,
packetChan: make(chan Packet, 1000),
ReachableAt: reachableAt,
}
}
// Start sets up the UDP listener, worker pool, and begins reading packets.
func (s *UDPProxyServer) Start() error {
// Fetch initial mappings.
if err := s.fetchInitialMappings(); err != nil {
return fmt.Errorf("failed to fetch initial mappings: %v", err)
}
udpAddr, err := net.ResolveUDPAddr("udp", s.addr)
if err != nil {
return err
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
s.conn = conn
logger.Info("UDP server listening on %s", s.addr)
// Start a fixed number of worker goroutines.
workerCount := 10 // TODO: Make this configurable or pick it better!
for i := 0; i < workerCount; i++ {
go s.packetWorker()
}
// Start the goroutine that reads packets from the UDP socket.
go s.readPackets()
// Start the idle connection cleanup routine.
go s.cleanupIdleConnections()
// Start the session cleanup routine
go s.cleanupIdleSessions()
// Start the proxy mapping cleanup routine
go s.cleanupIdleProxyMappings()
// Start the communication pattern cleanup routine
go s.cleanupIdleCommunicationPatterns()
return nil
}
func (s *UDPProxyServer) Stop() {
s.conn.Close()
}
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
func (s *UDPProxyServer) readPackets() {
for {
buf := bufferPool.Get().([]byte)
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
if err != nil {
logger.Error("Error reading UDP packet: %v", err)
continue
}
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
}
}
// packetWorker processes incoming packets from the channel.
func (s *UDPProxyServer) packetWorker() {
for packet := range s.packetChan {
// Determine packet type by inspecting the first byte.
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
// Process as a WireGuard packet.
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
} else {
// Process as an encrypted hole punch message
var encMsg EncryptedHolePunchMessage
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
logger.Error("Error unmarshaling encrypted message: %v", err)
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
}
if encMsg.EphemeralPublicKey == "" {
logger.Error("Received malformed message without ephemeral key")
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
}
// This appears to be an encrypted message
decryptedData, err := s.decryptMessage(encMsg)
if err != nil {
logger.Error("Failed to decrypt message: %v", err)
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
}
// Process the decrypted hole punch message
var msg HolePunchMessage
if err := json.Unmarshal(decryptedData, &msg); err != nil {
logger.Error("Error unmarshaling decrypted message: %v", err)
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
}
endpoint := ClientEndpoint{
NewtID: msg.NewtID,
OlmID: msg.OlmID,
Token: msg.Token,
IP: packet.remoteAddr.IP.String(),
Port: packet.remoteAddr.Port,
Timestamp: time.Now().Unix(),
ReachableAt: s.ReachableAt,
PublicKey: s.privateKey.PublicKey().String(),
}
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
s.notifyServer(endpoint)
s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
}
// Return the buffer to the pool for reuse.
bufferPool.Put(packet.data[:1500])
}
}
// decryptMessage decrypts the message using the server's private key
func (s *UDPProxyServer) decryptMessage(encMsg EncryptedHolePunchMessage) ([]byte, error) {
// Parse the ephemeral public key
ephPubKey, err := wgtypes.ParseKey(encMsg.EphemeralPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse ephemeral public key: %v", err)
}
// Use X25519 for key exchange instead of ScalarMult
sharedSecret, err := curve25519.X25519(s.privateKey[:], ephPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create the AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Verify nonce size
if len(encMsg.Nonce) != aead.NonceSize() {
return nil, fmt.Errorf("invalid nonce size")
}
// Decrypt the ciphertext
plaintext, err := aead.Open(nil, encMsg.Nonce, encMsg.Ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt message: %v", err)
}
return plaintext, nil
}
func (s *UDPProxyServer) fetchInitialMappings() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.privateKey.PublicKey().String())))
resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body)
if err != nil {
return fmt.Errorf("failed to fetch mappings: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("server returned non-OK status: %d, body: %s",
resp.StatusCode, string(body))
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %v", err)
}
logger.Info("Received initial mappings: %s", string(data))
var initialMappings InitialMappings
if err := json.Unmarshal(data, &initialMappings); err != nil {
return fmt.Errorf("failed to unmarshal initial mappings: %v", err)
}
// Store mappings in our sync.Map.
for key, mapping := range initialMappings.Mappings {
// Initialize LastUsed timestamp for initial mappings
mapping.LastUsed = time.Now()
s.proxyMappings.Store(key, mapping)
}
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
return nil
}
// Extract WireGuard message indices
func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
if len(packet) < 12 {
return 0, 0, false
}
messageType := packet[0]
if messageType == WireGuardMessageTypeHandshakeInitiation {
// Handshake initiation: extract sender index at offset 4
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
return 0, senderIndex, true
} else if messageType == WireGuardMessageTypeHandshakeResponse {
// Handshake response: extract sender index at offset 4 and receiver index at offset 8
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
receiverIndex := binary.LittleEndian.Uint32(packet[8:12])
return receiverIndex, senderIndex, true
} else if messageType == WireGuardMessageTypeTransportData {
// Transport data: extract receiver index at offset 4
receiverIndex := binary.LittleEndian.Uint32(packet[4:8])
return receiverIndex, 0, true
}
return 0, 0, false
}
// Updated to handle multi-peer WireGuard communication
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
if len(packet) == 0 {
logger.Error("Received empty packet")
return
}
messageType := packet[0]
receiverIndex, senderIndex, ok := extractWireGuardIndices(packet)
if !ok {
logger.Error("Failed to extract WireGuard indices")
return
}
key := remoteAddr.String()
mappingObj, ok := s.proxyMappings.Load(key)
if !ok {
logger.Error("No proxy mapping found for %s", key)
return
}
proxyMapping := mappingObj.(ProxyMapping)
// Update the last used timestamp and store it back
proxyMapping.LastUsed = time.Now()
s.proxyMappings.Store(key, proxyMapping)
// Handle different WireGuard message types
switch messageType {
case WireGuardMessageTypeHandshakeInitiation:
// Initial handshake: forward to all peers
logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)
for _, dest := range proxyMapping.Destinations {
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
}
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
continue
}
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward handshake initiation: %v", err)
}
}
case WireGuardMessageTypeHandshakeResponse:
// Received handshake response: establish session mapping
logger.Debug("Received handshake response with receiver index %d and sender index %d from %s",
receiverIndex, senderIndex, remoteAddr)
// Create a session key for the peer that sent the initial handshake
sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex)
// Store the session information
s.wgSessions.Store(sessionKey, &WireGuardSession{
ReceiverIndex: receiverIndex,
SenderIndex: senderIndex,
DestAddr: remoteAddr,
LastSeen: time.Now(),
})
// Forward the response to the original sender
for _, dest := range proxyMapping.Destinations {
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
}
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
continue
}
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward handshake response: %v", err)
}
}
case WireGuardMessageTypeTransportData:
// Data packet: forward only to the established session peer
// logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr)
// Look up the session based on the receiver index
var destAddr *net.UDPAddr
// First check for existing sessions to see if we know where to send this packet
s.wgSessions.Range(func(k, v interface{}) bool {
session := v.(*WireGuardSession)
if session.SenderIndex == receiverIndex {
// Found matching session
destAddr = session.DestAddr
// Update last seen time
session.LastSeen = time.Now()
s.wgSessions.Store(k, session)
return false // stop iteration
}
return true // continue iteration
})
if destAddr != nil {
// We found a specific peer to forward to
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
return
}
// Track communication pattern for session rebuilding
s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true)
_, err = conn.Write(packet)
if err != nil {
logger.Debug("Failed to forward transport data: %v", err)
}
} else {
// No known session, fall back to forwarding to all peers
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
for _, dest := range proxyMapping.Destinations {
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
}
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
continue
}
// Track communication pattern for session rebuilding
s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true)
_, err = conn.Write(packet)
if err != nil {
logger.Debug("Failed to forward transport data: %v", err)
}
}
}
default:
// Other packet types (like cookie reply)
logger.Debug("Forwarding WireGuard packet type %d from %s", messageType, remoteAddr)
// Forward to all peers
for _, dest := range proxyMapping.Destinations {
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
}
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
continue
}
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward WireGuard packet: %v", err)
}
}
}
}
func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) {
key := destAddr.String() + "-" + remoteAddr.String()
// Check if we have an existing connection
if conn, ok := s.connections.Load(key); ok {
destConn := conn.(*DestinationConn)
destConn.lastUsed = time.Now()
return destConn.conn, nil
}
// Create new connection
newConn, err := net.DialUDP("udp", nil, destAddr)
if err != nil {
return nil, fmt.Errorf("failed to create UDP connection: %v", err)
}
// Store the new connection
s.connections.Store(key, &DestinationConn{
conn: newConn,
lastUsed: time.Now(),
})
// Start a goroutine to handle responses
go s.handleResponses(newConn, destAddr, remoteAddr)
return newConn, nil
}
func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) {
buffer := make([]byte, 1500)
for {
n, err := conn.Read(buffer)
if err != nil {
logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
return
}
// Process the response to track sessions if it's a WireGuard packet
if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
receiverIndex, senderIndex, ok := extractWireGuardIndices(buffer[:n])
if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse {
// Store the session mapping for the handshake response
sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex)
s.wgSessions.Store(sessionKey, &WireGuardSession{
ReceiverIndex: receiverIndex,
SenderIndex: senderIndex,
DestAddr: destAddr,
LastSeen: time.Now(),
})
logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
} else if ok && buffer[0] == WireGuardMessageTypeTransportData {
// Track communication pattern for session rebuilding (reverse direction)
s.trackCommunicationPattern(destAddr, remoteAddr, receiverIndex, false)
}
}
// Forward the response back through the main listener
_, err = s.conn.WriteToUDP(buffer[:n], remoteAddr)
if err != nil {
logger.Error("Failed to forward response: %v", err)
}
}
}
// Add a cleanup method to periodically remove idle connections
func (s *UDPProxyServer) cleanupIdleConnections() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
now := time.Now()
s.connections.Range(func(key, value interface{}) bool {
destConn := value.(*DestinationConn)
if now.Sub(destConn.lastUsed) > 10*time.Minute {
destConn.conn.Close()
s.connections.Delete(key)
}
return true
})
}
}
// New method to periodically remove idle sessions
func (s *UDPProxyServer) cleanupIdleSessions() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession)
if now.Sub(session.LastSeen) > 15*time.Minute {
s.wgSessions.Delete(key)
logger.Debug("Removed idle session: %s", key)
}
return true
})
}
}
// New method to periodically remove idle proxy mappings
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
ticker := time.NewTicker(10 * time.Minute)
for range ticker.C {
now := time.Now()
s.proxyMappings.Range(func(key, value interface{}) bool {
mapping := value.(ProxyMapping)
// Remove mappings that haven't been used in 30 minutes
if now.Sub(mapping.LastUsed) > 30*time.Minute {
s.proxyMappings.Delete(key)
logger.Debug("Removed idle proxy mapping: %s", key)
}
return true
})
}
}
func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
logger.Debug("notifyServer called with endpoint: IP=%s, Port=%d", endpoint.IP, endpoint.Port)
jsonData, err := json.Marshal(endpoint)
if err != nil {
logger.Error("Failed to marshal endpoint data: %v", err)
return
}
resp, err := http.Post(s.serverURL+"/gerbil/update-hole-punch", "application/json", bytes.NewBuffer(jsonData))
if err != nil {
logger.Error("Failed to notify server: %v", err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Server returned non-OK status: %d, body: %s",
resp.StatusCode, string(body))
return
}
// Parse the proxy mapping response
var mapping ProxyMapping
if err := json.NewDecoder(resp.Body).Decode(&mapping); err != nil {
logger.Error("Failed to decode proxy mapping: %v", err)
return
}
logger.Debug("Received proxy mapping from server: %v", mapping)
// Store the mapping with current timestamp
key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port)
logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port)
mapping.LastUsed = time.Now()
s.proxyMappings.Store(key, mapping)
logger.Debug("Stored proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed)
}
// Updated to support multiple destinations
func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, destinations []PeerDestination) {
key := fmt.Sprintf("%s:%d", sourceIP, sourcePort)
mapping := ProxyMapping{
Destinations: destinations,
LastUsed: time.Now(),
}
s.proxyMappings.Store(key, mapping)
}
// OnPeerAdded clears connections and sessions for a specific WireGuard IP to allow re-establishment
func (s *UDPProxyServer) OnPeerAdded(wgIP string) {
logger.Info("Clearing connections for added peer with WG IP: %s", wgIP)
s.clearConnectionsForWGIP(wgIP)
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
// s.clearProxyMappingsForWGIP(wgIP)
}
// OnPeerRemoved clears connections and sessions for a specific WireGuard IP
func (s *UDPProxyServer) OnPeerRemoved(wgIP string) {
logger.Info("Clearing connections for removed peer with WG IP: %s", wgIP)
s.clearConnectionsForWGIP(wgIP)
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
// s.clearProxyMappingsForWGIP(wgIP)
}
// clearConnectionsForWGIP removes all connections associated with a specific WireGuard IP
func (s *UDPProxyServer) clearConnectionsForWGIP(wgIP string) {
var keysToDelete []string
s.connections.Range(func(key, value interface{}) bool {
keyStr := key.(string)
destConn := value.(*DestinationConn)
// Connection keys are in format "destAddr-remoteAddr"
// Check if either destination or remote address contains the WG IP
if containsIP(keyStr, wgIP) {
keysToDelete = append(keysToDelete, keyStr)
destConn.conn.Close()
logger.Debug("Closing connection for WG IP %s: %s", wgIP, keyStr)
}
return true
})
// Delete the connections
for _, key := range keysToDelete {
s.connections.Delete(key)
}
logger.Info("Cleared %d connections for WG IP: %s", len(keysToDelete), wgIP)
}
// clearSessionsForWGIP removes all WireGuard sessions associated with a specific WireGuard IP
func (s *UDPProxyServer) clearSessionsForIP(ip string) {
var keysToDelete []string
s.wgSessions.Range(func(key, value interface{}) bool {
keyStr := key.(string)
session := value.(*WireGuardSession)
// Check if the session's destination address contains the WG IP
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
keysToDelete = append(keysToDelete, keyStr)
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
}
return true
})
// Delete the sessions
for _, key := range keysToDelete {
s.wgSessions.Delete(key)
}
logger.Info("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip)
}
// // clearProxyMappingsForWGIP removes all proxy mappings that have destinations pointing to a specific WireGuard IP
// func (s *UDPProxyServer) clearProxyMappingsForWGIP(wgIP string) {
// var keysToDelete []string
// s.proxyMappings.Range(func(key, value interface{}) bool {
// keyStr := key.(string)
// mapping := value.(ProxyMapping)
// // Check if any destination in the mapping contains the WG IP
// for _, dest := range mapping.Destinations {
// if dest.DestinationIP == wgIP {
// keysToDelete = append(keysToDelete, keyStr)
// logger.Debug("Marking proxy mapping for deletion for WG IP %s: %s -> %s:%d", wgIP, keyStr, dest.DestinationIP, dest.DestinationPort)
// break // Found one destination, no need to check others in this mapping
// }
// }
// return true
// })
// // Delete the proxy mappings
// for _, key := range keysToDelete {
// s.proxyMappings.Delete(key)
// logger.Debug("Deleted proxy mapping: %s", key)
// }
// logger.Info("Cleared %d proxy mappings for WG IP: %s", len(keysToDelete), wgIP)
// }
// containsIP checks if a connection key string contains the specified IP address
func containsIP(connectionKey, ip string) bool {
// Connection keys are in format "destIP:destPort-remoteIP:remotePort"
// Check if the IP appears at the beginning (destination) or after the dash (remote)
ipWithColon := ip + ":"
// Check if connection key starts with the IP (destination address)
if len(connectionKey) >= len(ipWithColon) && connectionKey[:len(ipWithColon)] == ipWithColon {
return true
}
// Check if connection key contains the IP after a dash (remote address)
dashIndex := -1
for i := 0; i < len(connectionKey); i++ {
if connectionKey[i] == '-' {
dashIndex = i
break
}
}
if dashIndex != -1 && dashIndex+1 < len(connectionKey) {
remainingPart := connectionKey[dashIndex+1:]
if len(remainingPart) >= len(ip)+1 && remainingPart[:len(ip)+1] == ipWithColon {
return true
}
}
return false
}
// UpdateDestinationInMappings updates all proxy mappings that contain the old destination with the new destination
// Returns the number of mappings that were updated
func (s *UDPProxyServer) UpdateDestinationInMappings(oldDest, newDest PeerDestination) int {
updatedCount := 0
s.proxyMappings.Range(func(key, value interface{}) bool {
keyStr := key.(string)
mapping := value.(ProxyMapping)
updated := false
// Check each destination in the mapping
for i, dest := range mapping.Destinations {
if dest.DestinationIP == oldDest.DestinationIP && dest.DestinationPort == oldDest.DestinationPort {
// Update this destination
mapping.Destinations[i] = newDest
updated = true
logger.Debug("Updated destination in mapping %s: %s:%d -> %s:%d",
keyStr, oldDest.DestinationIP, oldDest.DestinationPort,
newDest.DestinationIP, newDest.DestinationPort)
}
}
// If we updated any destinations, store the updated mapping back
if updated {
mapping.LastUsed = time.Now()
s.proxyMappings.Store(keyStr, mapping)
updatedCount++
}
return true // continue iteration
})
if updatedCount > 0 {
logger.Info("Updated %d proxy mappings from %s:%d to %s:%d",
updatedCount, oldDest.DestinationIP, oldDest.DestinationPort,
newDest.DestinationIP, newDest.DestinationPort)
}
return updatedCount
}
// trackCommunicationPattern tracks bidirectional communication patterns to rebuild sessions
func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr, receiverIndex uint32, fromClient bool) {
var clientAddr, destAddr *net.UDPAddr
var clientIndex, destIndex uint32
if fromClient {
clientAddr = fromAddr
destAddr = toAddr
clientIndex = receiverIndex
destIndex = 0 // We don't know the destination index yet
} else {
clientAddr = toAddr
destAddr = fromAddr
clientIndex = 0 // We don't know the client index yet
destIndex = receiverIndex
}
patternKey := fmt.Sprintf("%s-%s", clientAddr.String(), destAddr.String())
now := time.Now()
if existingPattern, ok := s.commPatterns.Load(patternKey); ok {
pattern := existingPattern.(*CommunicationPattern)
// Update the pattern
if fromClient {
pattern.LastFromClient = now
if pattern.ClientIndex == 0 {
pattern.ClientIndex = clientIndex
}
} else {
pattern.LastFromDest = now
if pattern.DestIndex == 0 {
pattern.DestIndex = destIndex
}
}
pattern.PacketCount++
s.commPatterns.Store(patternKey, pattern)
// Check if we have bidirectional communication and can rebuild a session
s.tryRebuildSession(pattern)
} else {
// Create new pattern
pattern := &CommunicationPattern{
FromClient: clientAddr,
ToDestination: destAddr,
ClientIndex: clientIndex,
DestIndex: destIndex,
PacketCount: 1,
}
if fromClient {
pattern.LastFromClient = now
} else {
pattern.LastFromDest = now
}
s.commPatterns.Store(patternKey, pattern)
}
}
// tryRebuildSession attempts to rebuild a WireGuard session from communication patterns
func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
// Check if we have bidirectional communication within a reasonable time window
timeDiff := pattern.LastFromClient.Sub(pattern.LastFromDest)
if timeDiff < 0 {
timeDiff = -timeDiff
}
// Only rebuild if we have recent bidirectional communication and both indices
if timeDiff < 30*time.Second && pattern.ClientIndex != 0 && pattern.DestIndex != 0 && pattern.PacketCount >= 4 {
// Create session mapping: client's index maps to destination
sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex)
// Check if we already have this session
if _, exists := s.wgSessions.Load(sessionKey); !exists {
session := &WireGuardSession{
ReceiverIndex: pattern.DestIndex,
SenderIndex: pattern.ClientIndex,
DestAddr: pattern.ToDestination,
LastSeen: time.Now(),
}
s.wgSessions.Store(sessionKey, session)
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
}
}
}
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
ticker := time.NewTicker(10 * time.Minute)
for range ticker.C {
now := time.Now()
s.commPatterns.Range(func(key, value interface{}) bool {
pattern := value.(*CommunicationPattern)
// Get the most recent activity
lastActivity := pattern.LastFromClient
if pattern.LastFromDest.After(lastActivity) {
lastActivity = pattern.LastFromDest
}
// Remove patterns that haven't had activity in 20 minutes
if now.Sub(lastActivity) > 20*time.Minute {
s.commPatterns.Delete(key)
logger.Debug("Removed idle communication pattern: %s", key)
}
return true
})
}
}