mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-08 05:56:40 +00:00
Compare commits
60 Commits
1.0.0-beta
...
1.2.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a8a0f92c9b | ||
|
|
7040a9436e | ||
|
|
04361242fe | ||
|
|
554b1d55dc | ||
|
|
47589570c9 | ||
|
|
9f5b8dea26 | ||
|
|
f6a1e1e27c | ||
|
|
f983a8f141 | ||
|
|
efce3cb0b2 | ||
|
|
6eeebd81b2 | ||
|
|
c970fd5a18 | ||
|
|
09bd02456d | ||
|
|
c24537af36 | ||
|
|
9de3f14799 | ||
|
|
0908f75f5f | ||
|
|
10958f8c55 | ||
|
|
b1840fd5c3 | ||
|
|
1df5eb19ff | ||
|
|
f71f183886 | ||
|
|
8922ca9736 | ||
|
|
38483f4a26 | ||
|
|
78c768e497 | ||
|
|
fc7df8a530 | ||
|
|
50b42059ac | ||
|
|
825f7fcf60 | ||
|
|
8c8ec72b40 | ||
|
|
c61b7fc4fb | ||
|
|
96e3376147 | ||
|
|
e47a7c80d1 | ||
|
|
f1e373f2d8 | ||
|
|
ef4d0db475 | ||
|
|
b6b97f5ed3 | ||
|
|
dff267a42e | ||
|
|
bb98db7f5e | ||
|
|
f1016200b3 | ||
|
|
f1ab8094cf | ||
|
|
ad2bc0d397 | ||
|
|
a78d141ca3 | ||
|
|
10b1ad2a5a | ||
|
|
8a9f29043a | ||
|
|
05c9d851f4 | ||
|
|
c9a6b85e1d | ||
|
|
a16021cd86 | ||
|
|
9506b545f4 | ||
|
|
17b87e6707 | ||
|
|
cba4dc646d | ||
|
|
88be6d133d | ||
|
|
34a80c6411 | ||
|
|
6565fdbe62 | ||
|
|
993f5f86c5 | ||
|
|
093a4c21f2 | ||
|
|
f7c0bb9135 | ||
|
|
a145b77f79 | ||
|
|
7b3f7d2b12 | ||
|
|
9c5ddcdfb8 | ||
|
|
32176c74a0 | ||
|
|
aa4f4ebfab | ||
|
|
bab8630756 | ||
|
|
24e993ee41 | ||
|
|
5d4faaff65 |
@@ -6,4 +6,5 @@ README.md
|
||||
Makefile
|
||||
public/
|
||||
LICENSE
|
||||
CONTRIBUTING.md
|
||||
CONTRIBUTING.md
|
||||
.git
|
||||
|
||||
40
.github/dependabot.yml
vendored
Normal file
40
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "gomod"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
groups:
|
||||
dev-patch-updates:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "patch"
|
||||
dev-minor-updates:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
prod-patch-updates:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "patch"
|
||||
prod-minor-updates:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
|
||||
- package-ecosystem: "docker"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
groups:
|
||||
patch-updates:
|
||||
update-types:
|
||||
- "patch"
|
||||
minor-updates:
|
||||
update-types:
|
||||
- "minor"
|
||||
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
52
.github/workflows/cicd.yml
vendored
Normal file
52
.github/workflows/cicd.yml
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
name: CI/CD Pipeline
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
|
||||
- name: Extract tag name
|
||||
id: get-tag
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.25
|
||||
|
||||
- name: Build and push Docker images
|
||||
run: |
|
||||
TAG=${{ env.TAG }}
|
||||
make docker-build-release tag=$TAG
|
||||
|
||||
- name: Build binaries
|
||||
run: |
|
||||
make go-build-release
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
|
||||
28
.github/workflows/test.yml
vendored
Normal file
28
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Run Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
|
||||
- name: Build go
|
||||
run: go build
|
||||
|
||||
- name: Build Docker image
|
||||
run: make build
|
||||
|
||||
- name: Build binaries
|
||||
run: make go-build-release
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1 +1,3 @@
|
||||
gerbil
|
||||
gerbil
|
||||
.DS_Store
|
||||
bin/
|
||||
1
.go-version
Normal file
1
.go-version
Normal file
@@ -0,0 +1 @@
|
||||
1.25
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.23.1-alpine AS builder
|
||||
FROM golang:1.25-alpine AS builder
|
||||
|
||||
# Set the working directory inside the container
|
||||
WORKDIR /app
|
||||
@@ -16,7 +16,9 @@ COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /gerbil
|
||||
|
||||
# Start a new stage from scratch
|
||||
FROM ubuntu:22.04 AS runner
|
||||
FROM ubuntu:24.04 AS runner
|
||||
|
||||
RUN apt-get update && apt-get install -y iptables iproute2 && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy the pre-built binary file from the previous stage and the entrypoint script
|
||||
COPY --from=builder /gerbil /usr/local/bin/
|
||||
|
||||
14
Makefile
14
Makefile
@@ -1,6 +1,14 @@
|
||||
|
||||
all: build push
|
||||
|
||||
docker-build-release:
|
||||
@if [ -z "$(tag)" ]; then \
|
||||
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
|
||||
exit 1; \
|
||||
fi
|
||||
docker buildx build --platform linux/arm64,linux/amd64 -t fosrl/gerbil:latest -f Dockerfile --push .
|
||||
docker buildx build --platform linux/arm64,linux/amd64 -t fosrl/gerbil:$(tag) -f Dockerfile --push .
|
||||
|
||||
build:
|
||||
docker build -t fosrl/gerbil:latest .
|
||||
|
||||
@@ -13,5 +21,9 @@ test:
|
||||
local:
|
||||
CGO_ENABLED=0 GOOS=linux go build -o gerbil
|
||||
|
||||
go-build-release:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/gerbil_linux_arm64
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/gerbil_linux_amd64
|
||||
|
||||
clean:
|
||||
rm gerbil
|
||||
rm gerbil
|
||||
|
||||
67
README.md
67
README.md
@@ -4,17 +4,10 @@ Gerbil is a simple [WireGuard](https://www.wireguard.com/) interface management
|
||||
|
||||
### Installation and Documentation
|
||||
|
||||
Gerbil can be used standalone with your own API, a static JSON file, or with Pangolin and Newt as part of the larger system. See documentation below:
|
||||
Gerbil works with Pangolin, Newt, and Olm as part of the larger system. See documentation below:
|
||||
|
||||
- [Installation Instructions](https://docs.fossorial.io)
|
||||
- [Full Documentation](https://docs.fossorial.io)
|
||||
|
||||
## Preview
|
||||
|
||||
<img src="public/screenshots/preview.png" alt="Preview"/>
|
||||
|
||||
_Sample output of a Gerbil container connected to Pangolin and terminating various peers._
|
||||
|
||||
## Key Functions
|
||||
|
||||
### Setup WireGuard
|
||||
@@ -29,6 +22,24 @@ Gerbil will create the peers defined in the config on the WireGuard interface. T
|
||||
|
||||
Bytes transmitted in and out of each peer are collected every 10 seconds, and incremental usage is reported via the "reportBandwidthTo" endpoint. This can be used to track data usage of each peer on the remote server.
|
||||
|
||||
### Handle client relaying
|
||||
|
||||
Gerbil listens on port 21820 for incoming UDP hole punch packets to orchestrate NAT hole punching between olm and newt clients. Additionally, it handles relaying data through the gerbil server down to the newt. This is accomplished by scanning each packet for headers and handling them appropriately.
|
||||
|
||||
### SNI Proxy
|
||||
|
||||
Gerbil includes an SNI (Server Name Indication) proxy that enables intelligent routing of HTTPS traffic between Pangolin nodes. When a TLS connection comes in, the proxy extracts the hostname from the SNI extension and queries Pangolin to determine the correct routing destination. This allows seamless routing of web traffic through the WireGuard mesh network:
|
||||
|
||||
- If the hostname is configured for local handling (via local overrides or local SNIs), traffic is routed to the local proxy
|
||||
- Otherwise, the proxy queries Pangolin's routing API to determine which node should handle the traffic
|
||||
- Supports caching of routing decisions to improve performance
|
||||
- Handles connection pooling and graceful shutdown
|
||||
- Optional PROXY protocol v1 support to preserve original client IP addresses when forwarding to downstream proxies (HAProxy, Nginx, etc.)
|
||||
|
||||
The PROXY protocol allows downstream proxies to know the real client IP address instead of seeing the SNI proxy's IP. When enabled with `--proxy-protocol`, the SNI proxy will prepend a PROXY protocol header to each connection containing the original client's IP and port information.
|
||||
|
||||
In single node (self hosted) Pangolin deployments this can be bypassed by using port 443:443 to route to Traefik instead of the SNI proxy at 8443.
|
||||
|
||||
## CLI Args
|
||||
|
||||
- `reachableAt`: How should the remote server reach Gerbil's API?
|
||||
@@ -38,10 +49,36 @@ Bytes transmitted in and out of each peer are collected every 10 seconds, and in
|
||||
|
||||
Note: You must use either `config` or `remoteConfig` to configure WireGuard.
|
||||
|
||||
- `reportBandwidthTo` (optional): Remote HTTP endpoint to send peer bandwidth data
|
||||
- `reportBandwidthTo` (optional): **DEPRECATED** - Use `remoteConfig` instead. Remote HTTP endpoint to send peer bandwidth data
|
||||
- `interface` (optional): Name of the WireGuard interface created by Gerbil. Default: `wg0`
|
||||
- `listen` (optional): Port to listen on for HTTP server. Default: `3003`
|
||||
- `log-level` (optional): The log level to use. Default: INFO
|
||||
- `listen` (optional): Port to listen on for HTTP server. Default: `:3003`
|
||||
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: `INFO`
|
||||
- `mtu` (optional): MTU of the WireGuard interface. Default: `1280`
|
||||
- `notify` (optional): URL to notify on peer changes
|
||||
- `sni-port` (optional): Port for the SNI proxy to listen on. Default: `8443`
|
||||
- `local-proxy` (optional): Address for local proxy when routing local traffic. Default: `localhost`
|
||||
- `local-proxy-port` (optional): Port for local proxy when routing local traffic. Default: `443`
|
||||
- `local-overrides` (optional): Comma-separated list of domain names that should always be routed to the local proxy
|
||||
- `proxy-protocol` (optional): Enable PROXY protocol v1 for preserving client IP addresses when forwarding to downstream proxies. Default: `false`
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All CLI arguments can also be provided via environment variables:
|
||||
|
||||
- `INTERFACE`: Name of the WireGuard interface
|
||||
- `CONFIG`: Path to local configuration file
|
||||
- `REMOTE_CONFIG`: URL of the remote config server
|
||||
- `LISTEN`: Address to listen on for HTTP server
|
||||
- `GENERATE_AND_SAVE_KEY_TO`: Path to save generated private key
|
||||
- `REACHABLE_AT`: Endpoint of the HTTP server to tell remote config about
|
||||
- `LOG_LEVEL`: Log level (DEBUG, INFO, WARN, ERROR, FATAL)
|
||||
- `MTU`: MTU of the WireGuard interface
|
||||
- `NOTIFY_URL`: URL to notify on peer changes
|
||||
- `SNI_PORT`: Port for the SNI proxy to listen on
|
||||
- `LOCAL_PROXY`: Address for local proxy when routing local traffic
|
||||
- `LOCAL_PROXY_PORT`: Port for local proxy when routing local traffic
|
||||
- `LOCAL_OVERRIDES`: Comma-separated list of domain names that should always be routed to the local proxy
|
||||
- `PROXY_PROTOCOL`: Enable PROXY protocol v1 for preserving client IP addresses (true/false)
|
||||
|
||||
Example:
|
||||
|
||||
@@ -49,8 +86,7 @@ Example:
|
||||
./gerbil \
|
||||
--reachableAt=http://gerbil:3003 \
|
||||
--generateAndSaveKeyTo=/var/config/key \
|
||||
--remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config \
|
||||
--reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth
|
||||
--remoteConfig=http://pangolin:3001/api/v1/
|
||||
```
|
||||
|
||||
```yaml
|
||||
@@ -62,8 +98,7 @@ services:
|
||||
command:
|
||||
- --reachableAt=http://gerbil:3003
|
||||
- --generateAndSaveKeyTo=/var/config/key
|
||||
- --remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config
|
||||
- --reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth
|
||||
- --remoteConfig=http://pangolin:3001/api/v1/
|
||||
volumes:
|
||||
- ./config/:/var/config
|
||||
cap_add:
|
||||
@@ -71,6 +106,8 @@ services:
|
||||
- SYS_MODULE
|
||||
ports:
|
||||
- 51820:51820/udp
|
||||
- 21820:21820/udp
|
||||
- 443:8443/tcp # SNI proxy port
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
14
go.mod
14
go.mod
@@ -1,10 +1,10 @@
|
||||
module github.com/fosrl/gerbil
|
||||
|
||||
go 1.23.1
|
||||
go 1.25
|
||||
|
||||
toolchain go1.23.2
|
||||
require (
|
||||
github.com/vishvananda/netlink v1.3.0
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
)
|
||||
|
||||
@@ -14,10 +14,10 @@ require (
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/crypto v0.8.0 // indirect
|
||||
golang.org/x/net v0.9.0 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.10.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect
|
||||
)
|
||||
|
||||
25
go.sum
25
go.sum
@@ -8,21 +8,24 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
||||
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
|
||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
|
||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
||||
|
||||
491
main.go
491
main.go
@@ -9,12 +9,17 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"github.com/fosrl/gerbil/proxy"
|
||||
"github.com/fosrl/gerbil/relay"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -26,6 +31,10 @@ var (
|
||||
mtuInt int
|
||||
lastReadings = make(map[string]PeerReading)
|
||||
mu sync.Mutex
|
||||
wgMu sync.Mutex // Protects WireGuard operations
|
||||
notifyURL string
|
||||
proxyRelay *relay.UDPProxyServer
|
||||
proxySNI *proxy.SNIProxy
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
@@ -56,6 +65,31 @@ var (
|
||||
wgClient *wgctrl.Client
|
||||
)
|
||||
|
||||
// Add this new type at the top with other type definitions
|
||||
type ClientEndpoint struct {
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type HolePunchMessage struct {
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
}
|
||||
|
||||
type ProxyMappingUpdate struct {
|
||||
OldDestination relay.PeerDestination `json:"oldDestination"`
|
||||
NewDestination relay.PeerDestination `json:"newDestination"`
|
||||
}
|
||||
|
||||
type UpdateDestinationsRequest struct {
|
||||
SourceIP string `json:"sourceIp"`
|
||||
SourcePort int `json:"sourcePort"`
|
||||
Destinations []relay.PeerDestination `json:"destinations"`
|
||||
}
|
||||
|
||||
func parseLogLevel(level string) logger.LogLevel {
|
||||
switch strings.ToUpper(level) {
|
||||
case "DEBUG":
|
||||
@@ -79,22 +113,32 @@ func main() {
|
||||
wgconfig WgConfig
|
||||
configFile string
|
||||
remoteConfigURL string
|
||||
reportBandwidthTo string
|
||||
generateAndSaveKeyTo string
|
||||
reachableAt string
|
||||
logLevel string
|
||||
mtu string
|
||||
sniProxyPort int
|
||||
localProxyAddr string
|
||||
localProxyPort int
|
||||
localOverridesStr string
|
||||
proxyProtocol bool
|
||||
)
|
||||
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
configFile = os.Getenv("CONFIG")
|
||||
remoteConfigURL = os.Getenv("REMOTE_CONFIG")
|
||||
listenAddr = os.Getenv("LISTEN")
|
||||
reportBandwidthTo = os.Getenv("REPORT_BANDWIDTH_TO")
|
||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||
reachableAt = os.Getenv("REACHABLE_AT")
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
mtu = os.Getenv("MTU")
|
||||
notifyURL = os.Getenv("NOTIFY_URL")
|
||||
|
||||
sniProxyPortStr := os.Getenv("SNI_PORT")
|
||||
localProxyAddr = os.Getenv("LOCAL_PROXY")
|
||||
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
|
||||
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
|
||||
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
|
||||
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||
@@ -103,14 +147,16 @@ func main() {
|
||||
flag.StringVar(&configFile, "config", "", "Path to local configuration file")
|
||||
}
|
||||
if remoteConfigURL == "" {
|
||||
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration")
|
||||
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL of the Pangolin server")
|
||||
}
|
||||
if listenAddr == "" {
|
||||
flag.StringVar(&listenAddr, "listen", ":3003", "Address to listen on")
|
||||
}
|
||||
if reportBandwidthTo == "" {
|
||||
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on")
|
||||
}
|
||||
// DEPRECATED AND UNSED: reportBandwidthTo
|
||||
// allow reportBandwidthTo to be passed but dont do anything with it just thow it away
|
||||
reportBandwidthTo := ""
|
||||
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "DEPRECATED: Use remoteConfig instead")
|
||||
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
}
|
||||
@@ -123,6 +169,42 @@ func main() {
|
||||
if mtu == "" {
|
||||
flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface")
|
||||
}
|
||||
if notifyURL == "" {
|
||||
flag.StringVar(¬ifyURL, "notify", "", "URL to notify on peer changes")
|
||||
}
|
||||
|
||||
if sniProxyPortStr != "" {
|
||||
if port, err := strconv.Atoi(sniProxyPortStr); err == nil {
|
||||
sniProxyPort = port
|
||||
}
|
||||
}
|
||||
if sniProxyPortStr == "" {
|
||||
flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on")
|
||||
}
|
||||
|
||||
if localProxyAddr == "" {
|
||||
flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address")
|
||||
}
|
||||
|
||||
if localProxyPortStr != "" {
|
||||
if port, err := strconv.Atoi(localProxyPortStr); err == nil {
|
||||
localProxyPort = port
|
||||
}
|
||||
}
|
||||
if localProxyPortStr == "" {
|
||||
flag.IntVar(&localProxyPort, "local-proxy-port", 443, "Local proxy port")
|
||||
}
|
||||
if localOverridesStr != "" {
|
||||
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
|
||||
}
|
||||
|
||||
if proxyProtocolStr != "" {
|
||||
proxyProtocol = strings.ToLower(proxyProtocolStr) == "true"
|
||||
}
|
||||
if proxyProtocolStr == "" {
|
||||
flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP")
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
logger.Init()
|
||||
@@ -143,6 +225,10 @@ func main() {
|
||||
logger.Fatal("You must provide either a config file or a remote config URL, not both")
|
||||
}
|
||||
|
||||
// clean up the reomte config URL for backwards compatibility
|
||||
remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/gerbil/get-config")
|
||||
remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/")
|
||||
|
||||
var key wgtypes.Key
|
||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||
if generateAndSaveKeyTo != "" {
|
||||
@@ -188,11 +274,17 @@ func main() {
|
||||
wgconfig.PrivateKey = key.String()
|
||||
}
|
||||
} else {
|
||||
wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to load configuration: %v", err)
|
||||
// loop until we get the config
|
||||
for wgconfig.PrivateKey == "" {
|
||||
logger.Info("Fetching remote config from %s", remoteConfigURL+"/gerbil/get-config")
|
||||
wgconfig, err = loadRemoteConfig(remoteConfigURL+"/gerbil/get-config", key, reachableAt)
|
||||
if err != nil {
|
||||
logger.Error("Failed to load configuration: %v", err)
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
wgconfig.PrivateKey = key.String()
|
||||
}
|
||||
wgconfig.PrivateKey = key.String()
|
||||
}
|
||||
|
||||
wgClient, err = wgctrl.New()
|
||||
@@ -209,13 +301,56 @@ func main() {
|
||||
// Ensure the WireGuard peers exist
|
||||
ensureWireguardPeers(wgconfig.Peers)
|
||||
|
||||
if reportBandwidthTo != "" {
|
||||
go periodicBandwidthCheck(reportBandwidthTo)
|
||||
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
|
||||
|
||||
// Start the UDP proxy server
|
||||
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
|
||||
err = proxyRelay.Start()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to start UDP proxy server: %v", err)
|
||||
}
|
||||
defer proxyRelay.Stop()
|
||||
|
||||
// TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING
|
||||
// SO YOU DON'T NEED TO SET THIS SEPARATELY
|
||||
// Parse local overrides
|
||||
var localOverrides []string
|
||||
if localOverridesStr != "" {
|
||||
localOverrides = strings.Split(localOverridesStr, ",")
|
||||
for i, domain := range localOverrides {
|
||||
localOverrides[i] = strings.TrimSpace(domain)
|
||||
}
|
||||
logger.Info("Local overrides configured: %v", localOverrides)
|
||||
}
|
||||
|
||||
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create proxy: %v", err)
|
||||
}
|
||||
|
||||
if err := proxySNI.Start(); err != nil {
|
||||
logger.Fatal("Failed to start proxy: %v", err)
|
||||
}
|
||||
|
||||
// Set up HTTP server
|
||||
http.HandleFunc("/peer", handlePeer)
|
||||
logger.Info("Starting server on %s", listenAddr)
|
||||
logger.Fatal("Failed to start server: %v", http.ListenAndServe(listenAddr, nil))
|
||||
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
|
||||
http.HandleFunc("/update-destinations", handleUpdateDestinations)
|
||||
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
|
||||
logger.Info("Starting HTTP server on %s", listenAddr)
|
||||
|
||||
// Run HTTP server in a goroutine
|
||||
go func() {
|
||||
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
logger.Error("HTTP server failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Keep the main goroutine running
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
logger.Info("Shutting down servers...")
|
||||
}
|
||||
|
||||
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
|
||||
@@ -338,6 +473,10 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
if err := ensureMSSClamping(); err != nil {
|
||||
logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
||||
|
||||
return nil
|
||||
@@ -366,6 +505,9 @@ func assignIPAddress(ipAddress string) error {
|
||||
}
|
||||
|
||||
func ensureWireguardPeers(peers []Peer) error {
|
||||
wgMu.Lock()
|
||||
defer wgMu.Unlock()
|
||||
|
||||
// get the current peers
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err != nil {
|
||||
@@ -388,8 +530,8 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := removePeer(peer)
|
||||
if err != nil {
|
||||
// Note: We need to call the internal removal logic without re-acquiring the lock
|
||||
if err := removePeerInternal(peer); err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -405,8 +547,8 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := addPeer(configPeer)
|
||||
if err != nil {
|
||||
// Note: We need to call the internal addition logic without re-acquiring the lock
|
||||
if err := addPeerInternal(configPeer); err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -415,6 +557,94 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureMSSClamping() error {
|
||||
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
||||
mssValue := mtuInt - 40
|
||||
|
||||
// Rules to be managed - just the chains, we'll construct the full command separately
|
||||
chains := []string{"INPUT", "OUTPUT", "FORWARD"}
|
||||
|
||||
// First, try to delete any existing rules
|
||||
for _, chain := range chains {
|
||||
deleteCmd := exec.Command("/usr/sbin/iptables",
|
||||
"-t", "mangle",
|
||||
"-D", chain,
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
|
||||
|
||||
// Try deletion multiple times to handle multiple existing rules
|
||||
for i := 0; i < 3; i++ {
|
||||
out, err := deleteCmd.CombinedOutput()
|
||||
if err != nil {
|
||||
// Convert exit status 1 to string for better logging
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
|
||||
chain, exitErr.String(), string(out))
|
||||
}
|
||||
break // No more rules to delete
|
||||
}
|
||||
logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Then add the new rules
|
||||
var errors []error
|
||||
for _, chain := range chains {
|
||||
addCmd := exec.Command("/usr/sbin/iptables",
|
||||
"-t", "mangle",
|
||||
"-A", chain,
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
logger.Info("Adding MSS clamping rule for chain %s", chain)
|
||||
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify the rule was added
|
||||
checkCmd := exec.Command("/usr/sbin/iptables",
|
||||
"-t", "mangle",
|
||||
"-C", chain,
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
|
||||
}
|
||||
|
||||
// If we encountered any errors, return them combined
|
||||
if len(errors) > 0 {
|
||||
var errMsgs []string
|
||||
for _, err := range errors {
|
||||
errMsgs = append(errMsgs, err.Error())
|
||||
}
|
||||
return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
|
||||
strings.Join(errMsgs, "\n"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
@@ -439,11 +669,20 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Notify if notifyURL is set
|
||||
go notifyPeerChange("add", peer.PublicKey)
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "Peer added successfully"})
|
||||
}
|
||||
|
||||
func addPeer(peer Peer) error {
|
||||
wgMu.Lock()
|
||||
defer wgMu.Unlock()
|
||||
return addPeerInternal(peer)
|
||||
}
|
||||
|
||||
func addPeerInternal(peer Peer) error {
|
||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
@@ -451,12 +690,15 @@ func addPeer(peer Peer) error {
|
||||
|
||||
// parse allowed IPs into array of net.IPNet
|
||||
var allowedIPs []net.IPNet
|
||||
var wgIPs []string
|
||||
for _, ipStr := range peer.AllowedIPs {
|
||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
||||
}
|
||||
allowedIPs = append(allowedIPs, *ipNet)
|
||||
// Extract the IP address from the CIDR for relay cleanup
|
||||
wgIPs = append(wgIPs, ipNet.IP.String())
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
@@ -472,6 +714,13 @@ func addPeer(peer Peer) error {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
|
||||
// Clear relay connections for the peer's WireGuard IPs
|
||||
if proxyRelay != nil {
|
||||
for _, wgIP := range wgIPs {
|
||||
proxyRelay.OnPeerAdded(wgIP)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Peer %s added successfully", peer.PublicKey)
|
||||
|
||||
return nil
|
||||
@@ -490,16 +739,42 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Notify if notifyURL is set
|
||||
go notifyPeerChange("remove", publicKey)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "Peer removed successfully"})
|
||||
}
|
||||
|
||||
func removePeer(publicKey string) error {
|
||||
wgMu.Lock()
|
||||
defer wgMu.Unlock()
|
||||
return removePeerInternal(publicKey)
|
||||
}
|
||||
|
||||
func removePeerInternal(publicKey string) error {
|
||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
// Get current peer info before removing to clear relay connections
|
||||
var wgIPs []string
|
||||
if proxyRelay != nil {
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err == nil {
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
// Extract WireGuard IPs from this peer's allowed IPs
|
||||
for _, allowedIP := range peer.AllowedIPs {
|
||||
wgIPs = append(wgIPs, allowedIP.IP.String())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
Remove: true,
|
||||
@@ -513,11 +788,163 @@ func removePeer(publicKey string) error {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
|
||||
// Clear relay connections for the peer's WireGuard IPs
|
||||
if proxyRelay != nil {
|
||||
for _, wgIP := range wgIPs {
|
||||
proxyRelay.OnPeerRemoved(wgIP)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Peer %s removed successfully", publicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
logger.Error("Invalid method: %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var update ProxyMappingUpdate
|
||||
if err := json.NewDecoder(r.Body).Decode(&update); err != nil {
|
||||
logger.Error("Failed to decode request body: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the update request
|
||||
if update.OldDestination.DestinationIP == "" || update.NewDestination.DestinationIP == "" {
|
||||
logger.Error("Both old and new destination IP addresses are required")
|
||||
http.Error(w, "Both old and new destination IP addresses are required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if update.OldDestination.DestinationPort <= 0 || update.NewDestination.DestinationPort <= 0 {
|
||||
logger.Error("Both old and new destination ports must be positive integers")
|
||||
http.Error(w, "Both old and new destination ports must be positive integers", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the proxy mappings in the relay server
|
||||
if proxyRelay == nil {
|
||||
logger.Error("Proxy server is not available")
|
||||
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
|
||||
|
||||
logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d",
|
||||
updatedCount,
|
||||
update.OldDestination.DestinationIP, update.OldDestination.DestinationPort,
|
||||
update.NewDestination.DestinationIP, update.NewDestination.DestinationPort)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "Proxy mappings updated successfully",
|
||||
"updatedCount": updatedCount,
|
||||
"oldDestination": update.OldDestination,
|
||||
"newDestination": update.NewDestination,
|
||||
})
|
||||
}
|
||||
|
||||
func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
logger.Error("Invalid method: %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var request UpdateDestinationsRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
logger.Error("Failed to decode request body: %v", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the request
|
||||
if request.SourceIP == "" {
|
||||
logger.Error("Source IP address is required")
|
||||
http.Error(w, "Source IP address is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if request.SourcePort <= 0 {
|
||||
logger.Error("Source port must be a positive integer")
|
||||
http.Error(w, "Source port must be a positive integer", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(request.Destinations) == 0 {
|
||||
logger.Error("At least one destination is required")
|
||||
http.Error(w, "At least one destination is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate each destination
|
||||
for i, dest := range request.Destinations {
|
||||
if dest.DestinationIP == "" {
|
||||
logger.Error("Destination IP is required for destination %d", i)
|
||||
http.Error(w, fmt.Sprintf("Destination IP is required for destination %d", i), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if dest.DestinationPort <= 0 {
|
||||
logger.Error("Destination port must be a positive integer for destination %d", i)
|
||||
http.Error(w, fmt.Sprintf("Destination port must be a positive integer for destination %d", i), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update the proxy mappings in the relay server
|
||||
if proxyRelay == nil {
|
||||
logger.Error("Proxy server is not available")
|
||||
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
|
||||
|
||||
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
|
||||
request.SourceIP, request.SourcePort, len(request.Destinations))
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "Destinations updated successfully",
|
||||
"sourceIP": request.SourceIP,
|
||||
"sourcePort": request.SourcePort,
|
||||
"destinationCount": len(request.Destinations),
|
||||
"destinations": request.Destinations,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateLocalSNIsRequest represents the JSON payload for updating local SNIs
|
||||
type UpdateLocalSNIsRequest struct {
|
||||
FullDomains []string `json:"fullDomains"`
|
||||
}
|
||||
|
||||
func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
logger.Error("Invalid method: %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateLocalSNIsRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
proxySNI.UpdateLocalSNIs(req.FullDomains)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "Local SNIs updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
func periodicBandwidthCheck(endpoint string) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -530,7 +957,10 @@ func periodicBandwidthCheck(endpoint string) {
|
||||
}
|
||||
|
||||
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
wgMu.Lock()
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
wgMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
@@ -637,3 +1067,28 @@ func reportPeerBandwidth(apiURL string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// notifyPeerChange sends a POST request to notifyURL with the action and public key.
|
||||
func notifyPeerChange(action, publicKey string) {
|
||||
if notifyURL == "" {
|
||||
return
|
||||
}
|
||||
payload := map[string]string{
|
||||
"action": action,
|
||||
"publicKey": publicKey,
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to marshal notify payload: %v", err)
|
||||
return
|
||||
}
|
||||
resp, err := http.Post(notifyURL, "application/json", bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to notify peer change: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Warn("Notify server returned non-OK: %s", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
591
proxy/proxy.go
Normal file
591
proxy/proxy.go
Normal file
@@ -0,0 +1,591 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// RouteRecord represents a routing configuration
|
||||
type RouteRecord struct {
|
||||
Hostname string
|
||||
TargetHost string
|
||||
TargetPort int
|
||||
}
|
||||
|
||||
// RouteAPIResponse represents the response from the route API
|
||||
type RouteAPIResponse struct {
|
||||
Endpoints []string `json:"endpoints"`
|
||||
}
|
||||
|
||||
// SNIProxy represents the main proxy server
|
||||
type SNIProxy struct {
|
||||
port int
|
||||
cache *cache.Cache
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
localProxyAddr string
|
||||
localProxyPort int
|
||||
remoteConfigURL string
|
||||
publicKey string
|
||||
proxyProtocol bool // Enable PROXY protocol v1
|
||||
|
||||
// New fields for fast local SNI lookup
|
||||
localSNIs map[string]struct{}
|
||||
localSNIsLock sync.RWMutex
|
||||
|
||||
// Local overrides for domains that should always use local proxy
|
||||
localOverrides map[string]struct{}
|
||||
|
||||
// Track active tunnels by SNI
|
||||
activeTunnels map[string]*activeTunnel
|
||||
activeTunnelsLock sync.Mutex
|
||||
}
|
||||
|
||||
type activeTunnel struct {
|
||||
conns []net.Conn
|
||||
}
|
||||
|
||||
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
|
||||
type readOnlyConn struct {
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) }
|
||||
func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
|
||||
func (conn readOnlyConn) Close() error { return nil }
|
||||
func (conn readOnlyConn) LocalAddr() net.Addr { return nil }
|
||||
func (conn readOnlyConn) RemoteAddr() net.Addr { return nil }
|
||||
func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
// buildProxyProtocolHeader creates a PROXY protocol v1 header
|
||||
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
|
||||
clientTCP, ok := clientAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
// Determine protocol family based on client IP and normalize target IP accordingly
|
||||
var protocol string
|
||||
var targetIP string
|
||||
|
||||
if clientTCP.IP.To4() != nil {
|
||||
// Client is IPv4, use TCP4 protocol
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is also IPv4, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is IPv6, but we need IPv4 for consistent protocol family
|
||||
// Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1
|
||||
if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
// For non-loopback IPv6 targets, we could try to extract embedded IPv4
|
||||
// or fall back to a sensible IPv4 address based on the target
|
||||
targetIP = "127.0.0.1" // Safe fallback
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Client is IPv6, use TCP6 protocol
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is IPv4, convert to IPv6 representation
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is also IPv6, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol,
|
||||
clientTCP.IP.String(),
|
||||
targetIP,
|
||||
clientTCP.Port,
|
||||
targetTCP.Port)
|
||||
}
|
||||
|
||||
// NewSNIProxy creates a new SNI proxy instance
|
||||
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool) (*SNIProxy, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create local overrides map
|
||||
overridesMap := make(map[string]struct{})
|
||||
for _, domain := range localOverrides {
|
||||
if domain != "" {
|
||||
overridesMap[domain] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
proxy := &SNIProxy{
|
||||
port: port,
|
||||
cache: cache.New(3*time.Second, 10*time.Minute),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
localProxyAddr: localProxyAddr,
|
||||
localProxyPort: localProxyPort,
|
||||
remoteConfigURL: remoteConfigURL,
|
||||
publicKey: publicKey,
|
||||
proxyProtocol: proxyProtocol,
|
||||
localSNIs: make(map[string]struct{}),
|
||||
localOverrides: overridesMap,
|
||||
activeTunnels: make(map[string]*activeTunnel),
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// Start begins listening for connections
|
||||
func (p *SNIProxy) Start() error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
|
||||
}
|
||||
|
||||
p.listener = listener
|
||||
logger.Debug("SNI Proxy listening on port %d", p.port)
|
||||
|
||||
// Accept connections in a goroutine
|
||||
go p.acceptConnections()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the proxy
|
||||
func (p *SNIProxy) Stop() error {
|
||||
log.Println("Stopping SNI Proxy...")
|
||||
|
||||
p.cancel()
|
||||
|
||||
if p.listener != nil {
|
||||
p.listener.Close()
|
||||
}
|
||||
|
||||
// Wait for all goroutines to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
p.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Println("All connections closed gracefully")
|
||||
case <-time.After(30 * time.Second):
|
||||
log.Println("Timeout waiting for connections to close")
|
||||
}
|
||||
|
||||
log.Println("SNI Proxy stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptConnections handles incoming connections
|
||||
func (p *SNIProxy) acceptConnections() {
|
||||
for {
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
logger.Debug("Accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
p.wg.Add(1)
|
||||
go p.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// readClientHello reads and parses the TLS ClientHello message
|
||||
func (p *SNIProxy) readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
|
||||
var hello *tls.ClientHelloInfo
|
||||
err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
hello = new(tls.ClientHelloInfo)
|
||||
*hello = *argHello
|
||||
return nil, nil
|
||||
},
|
||||
}).Handshake()
|
||||
if hello == nil {
|
||||
return nil, err
|
||||
}
|
||||
return hello, nil
|
||||
}
|
||||
|
||||
// peekClientHello reads the ClientHello while preserving the data for forwarding
|
||||
func (p *SNIProxy) peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
|
||||
peekedBytes := new(bytes.Buffer)
|
||||
hello, err := p.readClientHello(io.TeeReader(reader, peekedBytes))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return hello, io.MultiReader(peekedBytes, reader), nil
|
||||
}
|
||||
|
||||
// extractSNI extracts the SNI hostname from the TLS ClientHello
|
||||
func (p *SNIProxy) extractSNI(conn net.Conn) (string, io.Reader, error) {
|
||||
clientHello, clientReader, err := p.peekClientHello(conn)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to peek ClientHello: %w", err)
|
||||
}
|
||||
|
||||
if clientHello.ServerName == "" {
|
||||
return "", clientReader, fmt.Errorf("no SNI hostname found in ClientHello")
|
||||
}
|
||||
|
||||
return clientHello.ServerName, clientReader, nil
|
||||
}
|
||||
|
||||
// handleConnection processes a single client connection
|
||||
func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
defer p.wg.Done()
|
||||
defer clientConn.Close()
|
||||
|
||||
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
||||
|
||||
// Set read timeout for SNI extraction
|
||||
if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
logger.Debug("Failed to set read deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract SNI hostname
|
||||
hostname, clientReader, err := p.extractSNI(clientConn)
|
||||
if err != nil {
|
||||
logger.Debug("SNI extraction failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
log.Println("No SNI hostname found")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("SNI hostname detected: %s", hostname)
|
||||
|
||||
// Remove read timeout for normal operation
|
||||
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get routing information
|
||||
route, err := p.getRoute(hostname, clientConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
logger.Debug("Failed to get route for %s: %v", hostname, err)
|
||||
return
|
||||
}
|
||||
|
||||
if route == nil {
|
||||
logger.Debug("No route found for hostname: %s", hostname)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort)
|
||||
|
||||
// Connect to target server
|
||||
targetConn, err := net.DialTimeout("tcp",
|
||||
fmt.Sprintf("%s:%d", route.TargetHost, route.TargetPort),
|
||||
10*time.Second)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to connect to target %s:%d: %v",
|
||||
route.TargetHost, route.TargetPort, err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
|
||||
|
||||
// Send PROXY protocol header if enabled
|
||||
if p.proxyProtocol {
|
||||
proxyHeader := buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
||||
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
|
||||
|
||||
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
|
||||
logger.Debug("Failed to send PROXY protocol header: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Track this tunnel by SNI
|
||||
p.activeTunnelsLock.Lock()
|
||||
tunnel, ok := p.activeTunnels[hostname]
|
||||
if !ok {
|
||||
tunnel = &activeTunnel{}
|
||||
p.activeTunnels[hostname] = tunnel
|
||||
}
|
||||
tunnel.conns = append(tunnel.conns, clientConn)
|
||||
p.activeTunnelsLock.Unlock()
|
||||
|
||||
defer func() {
|
||||
// Remove this conn from active tunnels
|
||||
p.activeTunnelsLock.Lock()
|
||||
if tunnel, ok := p.activeTunnels[hostname]; ok {
|
||||
newConns := make([]net.Conn, 0, len(tunnel.conns))
|
||||
for _, c := range tunnel.conns {
|
||||
if c != clientConn {
|
||||
newConns = append(newConns, c)
|
||||
}
|
||||
}
|
||||
if len(newConns) == 0 {
|
||||
delete(p.activeTunnels, hostname)
|
||||
} else {
|
||||
tunnel.conns = newConns
|
||||
}
|
||||
}
|
||||
p.activeTunnelsLock.Unlock()
|
||||
}()
|
||||
|
||||
// Start bidirectional data transfer
|
||||
p.pipe(clientConn, targetConn, clientReader)
|
||||
}
|
||||
|
||||
// getRoute retrieves routing information for a hostname
|
||||
func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
||||
// Check local overrides first
|
||||
if _, isOverride := p.localOverrides[hostname]; isOverride {
|
||||
logger.Debug("Local override matched for hostname: %s", hostname)
|
||||
return &RouteRecord{
|
||||
Hostname: hostname,
|
||||
TargetHost: p.localProxyAddr,
|
||||
TargetPort: p.localProxyPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Fast path: check if hostname is in localSNIs
|
||||
p.localSNIsLock.RLock()
|
||||
_, isLocal := p.localSNIs[hostname]
|
||||
p.localSNIsLock.RUnlock()
|
||||
if isLocal {
|
||||
return &RouteRecord{
|
||||
Hostname: hostname,
|
||||
TargetHost: p.localProxyAddr,
|
||||
TargetPort: p.localProxyPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
if cached, found := p.cache.Get(hostname); found {
|
||||
if cached == nil {
|
||||
return nil, nil // Cached negative result
|
||||
}
|
||||
logger.Debug("Cache hit for hostname: %s", hostname)
|
||||
return cached.(*RouteRecord), nil
|
||||
}
|
||||
|
||||
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
|
||||
|
||||
// Query API with timeout
|
||||
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Construct API URL (without hostname in path)
|
||||
apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname", p.remoteConfigURL)
|
||||
|
||||
// Create request body with hostname and public key
|
||||
requestBody := map[string]string{
|
||||
"hostname": hostname,
|
||||
"publicKey": p.publicKey,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Make HTTP request
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// Cache negative result for shorter time (1 minute)
|
||||
p.cache.Set(hostname, nil, 1*time.Minute)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var apiResponse RouteAPIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode API response: %w", err)
|
||||
}
|
||||
|
||||
endpoints := apiResponse.Endpoints
|
||||
|
||||
// Default target configuration
|
||||
targetHost := p.localProxyAddr
|
||||
targetPort := p.localProxyPort
|
||||
|
||||
// If no endpoints returned, use local node
|
||||
if len(endpoints) == 0 {
|
||||
logger.Debug("No endpoints returned for hostname: %s, using local node", hostname)
|
||||
} else {
|
||||
// Select endpoint using consistent hashing for stickiness
|
||||
selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints)
|
||||
targetHost = selectedEndpoint
|
||||
targetPort = 443 // Default HTTPS port
|
||||
logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr)
|
||||
}
|
||||
|
||||
route := &RouteRecord{
|
||||
Hostname: hostname,
|
||||
TargetHost: targetHost,
|
||||
TargetPort: targetPort,
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
p.cache.Set(hostname, route, cache.DefaultExpiration)
|
||||
logger.Debug("Cached route for hostname: %s", hostname)
|
||||
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// selectStickyEndpoint selects an endpoint using consistent hashing to ensure
|
||||
// the same client always routes to the same endpoint for load balancing
|
||||
func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string {
|
||||
if len(endpoints) == 0 {
|
||||
return p.localProxyAddr
|
||||
}
|
||||
if len(endpoints) == 1 {
|
||||
return endpoints[0]
|
||||
}
|
||||
|
||||
// Use FNV hash for consistent selection based on client address
|
||||
hash := fnv.New32a()
|
||||
hash.Write([]byte(clientAddr))
|
||||
index := hash.Sum32() % uint32(len(endpoints))
|
||||
|
||||
return endpoints[index]
|
||||
}
|
||||
|
||||
// pipe handles bidirectional data transfer between connections
|
||||
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Copy data from client to target (using the buffered reader)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
|
||||
tcpConn.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// Use a large buffer for better performance
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
||||
if err != nil && err != io.EOF {
|
||||
logger.Debug("Copy client->target error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Copy data from target to client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
|
||||
tcpConn.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// Use a large buffer for better performance
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
||||
if err != nil && err != io.EOF {
|
||||
logger.Debug("Copy target->client error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// GetCacheStats returns cache statistics
|
||||
func (p *SNIProxy) GetCacheStats() (int, int) {
|
||||
return p.cache.ItemCount(), len(p.cache.Items())
|
||||
}
|
||||
|
||||
// ClearCache clears all cached entries
|
||||
func (p *SNIProxy) ClearCache() {
|
||||
p.cache.Flush()
|
||||
log.Println("Cache cleared")
|
||||
}
|
||||
|
||||
// UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains
|
||||
func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
|
||||
newSNIs := make(map[string]struct{})
|
||||
for _, domain := range fullDomains {
|
||||
newSNIs[domain] = struct{}{}
|
||||
// Invalidate any cached route for this domain
|
||||
p.cache.Delete(domain)
|
||||
}
|
||||
|
||||
// Update localSNIs
|
||||
p.localSNIsLock.Lock()
|
||||
removed := make([]string, 0)
|
||||
for sni := range p.localSNIs {
|
||||
if _, stillLocal := newSNIs[sni]; !stillLocal {
|
||||
removed = append(removed, sni)
|
||||
}
|
||||
}
|
||||
p.localSNIs = newSNIs
|
||||
p.localSNIsLock.Unlock()
|
||||
|
||||
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
|
||||
|
||||
// Terminate tunnels for removed SNIs
|
||||
if len(removed) > 0 {
|
||||
p.activeTunnelsLock.Lock()
|
||||
for _, sni := range removed {
|
||||
if tunnels, ok := p.activeTunnels[sni]; ok {
|
||||
for _, conn := range tunnels.conns {
|
||||
conn.Close()
|
||||
}
|
||||
delete(p.activeTunnels, sni)
|
||||
logger.Debug("Closed tunnels for SNI target change: %s", sni)
|
||||
}
|
||||
}
|
||||
p.activeTunnelsLock.Unlock()
|
||||
}
|
||||
}
|
||||
78
proxy/proxy_test.go
Normal file
78
proxy/proxy_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildProxyProtocolHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientAddr string
|
||||
targetAddr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 client and target",
|
||||
clientAddr: "192.168.1.100:12345",
|
||||
targetAddr: "10.0.0.1:443",
|
||||
expected: "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv6 client and target",
|
||||
clientAddr: "[2001:db8::1]:12345",
|
||||
targetAddr: "[2001:db8::2]:443",
|
||||
expected: "PROXY TCP6 2001:db8::1 2001:db8::2 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv4 client with IPv6 loopback target",
|
||||
clientAddr: "192.168.1.100:12345",
|
||||
targetAddr: "[::1]:443",
|
||||
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv4 client with IPv6 target",
|
||||
clientAddr: "192.168.1.100:12345",
|
||||
targetAddr: "[2001:db8::2]:443",
|
||||
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv6 client with IPv4 target",
|
||||
clientAddr: "[2001:db8::1]:12345",
|
||||
targetAddr: "10.0.0.1:443",
|
||||
expected: "PROXY TCP6 2001:db8::1 ::ffff:10.0.0.1 12345 443\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientTCP, err := net.ResolveTCPAddr("tcp", tt.clientAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve client address: %v", err)
|
||||
}
|
||||
|
||||
targetTCP, err := net.ResolveTCPAddr("tcp", tt.targetAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve target address: %v", err)
|
||||
}
|
||||
|
||||
result := buildProxyProtocolHeader(clientTCP, targetTCP)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
|
||||
// Test with non-TCP address type
|
||||
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}
|
||||
targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443}
|
||||
|
||||
result := buildProxyProtocolHeader(clientAddr, targetAddr)
|
||||
expected := "PROXY UNKNOWN\r\n"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected %q, got %q", expected, result)
|
||||
}
|
||||
}
|
||||
825
relay/relay.go
Normal file
825
relay/relay.go
Normal file
@@ -0,0 +1,825 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type EncryptedHolePunchMessage struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}
|
||||
|
||||
type HolePunchMessage struct {
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type ClientEndpoint struct {
|
||||
OlmID string `json:"olmId"`
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
ReachableAt string `json:"reachableAt"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
// Updated to support multiple destination peers
|
||||
type ProxyMapping struct {
|
||||
Destinations []PeerDestination `json:"destinations"`
|
||||
LastUsed time.Time `json:"-"` // Not serialized, used for cleanup
|
||||
}
|
||||
|
||||
type PeerDestination struct {
|
||||
DestinationIP string `json:"destinationIP"`
|
||||
DestinationPort int `json:"destinationPort"`
|
||||
}
|
||||
|
||||
type DestinationConn struct {
|
||||
conn *net.UDPConn
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
// Type for storing WireGuard handshake information
|
||||
type WireGuardSession struct {
|
||||
ReceiverIndex uint32
|
||||
SenderIndex uint32
|
||||
DestAddr *net.UDPAddr
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
type InitialMappings struct {
|
||||
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
|
||||
}
|
||||
|
||||
// Packet is a simple struct to hold the packet data and sender info.
|
||||
type Packet struct {
|
||||
data []byte
|
||||
remoteAddr *net.UDPAddr
|
||||
n int
|
||||
}
|
||||
|
||||
// WireGuard message types
|
||||
const (
|
||||
WireGuardMessageTypeHandshakeInitiation = 1
|
||||
WireGuardMessageTypeHandshakeResponse = 2
|
||||
WireGuardMessageTypeCookieReply = 3
|
||||
WireGuardMessageTypeTransportData = 4
|
||||
)
|
||||
|
||||
// --- End Types ---
|
||||
|
||||
// bufferPool allows reusing buffers to reduce allocations.
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 1500)
|
||||
},
|
||||
}
|
||||
|
||||
// UDPProxyServer has a channel for incoming packets.
|
||||
type UDPProxyServer struct {
|
||||
addr string
|
||||
serverURL string
|
||||
conn *net.UDPConn
|
||||
proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port"
|
||||
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
||||
privateKey wgtypes.Key
|
||||
packetChan chan Packet
|
||||
|
||||
// Session tracking for WireGuard peers
|
||||
// Key format: "senderIndex:receiverIndex"
|
||||
wgSessions sync.Map
|
||||
// ReachableAt is the URL where this server can be reached
|
||||
ReachableAt string
|
||||
}
|
||||
|
||||
// NewUDPProxyServer initializes the server with a buffered packet channel.
|
||||
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
return &UDPProxyServer{
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 1000),
|
||||
ReachableAt: reachableAt,
|
||||
}
|
||||
}
|
||||
|
||||
// Start sets up the UDP listener, worker pool, and begins reading packets.
|
||||
func (s *UDPProxyServer) Start() error {
|
||||
// Fetch initial mappings.
|
||||
if err := s.fetchInitialMappings(); err != nil {
|
||||
return fmt.Errorf("failed to fetch initial mappings: %v", err)
|
||||
}
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", s.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.conn = conn
|
||||
logger.Info("UDP server listening on %s", s.addr)
|
||||
|
||||
// Start a fixed number of worker goroutines.
|
||||
workerCount := 10 // TODO: Make this configurable or pick it better!
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go s.packetWorker()
|
||||
}
|
||||
|
||||
// Start the goroutine that reads packets from the UDP socket.
|
||||
go s.readPackets()
|
||||
|
||||
// Start the idle connection cleanup routine.
|
||||
go s.cleanupIdleConnections()
|
||||
|
||||
// Start the session cleanup routine
|
||||
go s.cleanupIdleSessions()
|
||||
|
||||
// Start the proxy mapping cleanup routine
|
||||
go s.cleanupIdleProxyMappings()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) Stop() {
|
||||
s.conn.Close()
|
||||
}
|
||||
|
||||
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
|
||||
func (s *UDPProxyServer) readPackets() {
|
||||
for {
|
||||
buf := bufferPool.Get().([]byte)
|
||||
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
continue
|
||||
}
|
||||
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
|
||||
}
|
||||
}
|
||||
|
||||
// packetWorker processes incoming packets from the channel.
|
||||
func (s *UDPProxyServer) packetWorker() {
|
||||
for packet := range s.packetChan {
|
||||
// Determine packet type by inspecting the first byte.
|
||||
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
|
||||
// Process as a WireGuard packet.
|
||||
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
||||
} else {
|
||||
// Process as an encrypted hole punch message
|
||||
var encMsg EncryptedHolePunchMessage
|
||||
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
||||
logger.Error("Error unmarshaling encrypted message: %v", err)
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
}
|
||||
|
||||
if encMsg.EphemeralPublicKey == "" {
|
||||
logger.Error("Received malformed message without ephemeral key")
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
}
|
||||
|
||||
// This appears to be an encrypted message
|
||||
decryptedData, err := s.decryptMessage(encMsg)
|
||||
if err != nil {
|
||||
logger.Error("Failed to decrypt message: %v", err)
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
}
|
||||
|
||||
// Process the decrypted hole punch message
|
||||
var msg HolePunchMessage
|
||||
if err := json.Unmarshal(decryptedData, &msg); err != nil {
|
||||
logger.Error("Error unmarshaling decrypted message: %v", err)
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
}
|
||||
|
||||
endpoint := ClientEndpoint{
|
||||
NewtID: msg.NewtID,
|
||||
OlmID: msg.OlmID,
|
||||
Token: msg.Token,
|
||||
IP: packet.remoteAddr.IP.String(),
|
||||
Port: packet.remoteAddr.Port,
|
||||
Timestamp: time.Now().Unix(),
|
||||
ReachableAt: s.ReachableAt,
|
||||
PublicKey: s.privateKey.PublicKey().String(),
|
||||
}
|
||||
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
|
||||
s.notifyServer(endpoint)
|
||||
s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
|
||||
}
|
||||
// Return the buffer to the pool for reuse.
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
}
|
||||
}
|
||||
|
||||
// decryptMessage decrypts the message using the server's private key
|
||||
func (s *UDPProxyServer) decryptMessage(encMsg EncryptedHolePunchMessage) ([]byte, error) {
|
||||
// Parse the ephemeral public key
|
||||
ephPubKey, err := wgtypes.ParseKey(encMsg.EphemeralPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ephemeral public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange instead of ScalarMult
|
||||
sharedSecret, err := curve25519.X25519(s.privateKey[:], ephPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create the AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Verify nonce size
|
||||
if len(encMsg.Nonce) != aead.NonceSize() {
|
||||
return nil, fmt.Errorf("invalid nonce size")
|
||||
}
|
||||
|
||||
// Decrypt the ciphertext
|
||||
plaintext, err := aead.Open(nil, encMsg.Nonce, encMsg.Ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt message: %v", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) fetchInitialMappings() error {
|
||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.privateKey.PublicKey().String())))
|
||||
resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch mappings: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("server returned non-OK status: %d, body: %s",
|
||||
resp.StatusCode, string(body))
|
||||
}
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response body: %v", err)
|
||||
}
|
||||
logger.Info("Received initial mappings: %s", string(data))
|
||||
var initialMappings InitialMappings
|
||||
if err := json.Unmarshal(data, &initialMappings); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal initial mappings: %v", err)
|
||||
}
|
||||
// Store mappings in our sync.Map.
|
||||
for key, mapping := range initialMappings.Mappings {
|
||||
// Initialize LastUsed timestamp for initial mappings
|
||||
mapping.LastUsed = time.Now()
|
||||
s.proxyMappings.Store(key, mapping)
|
||||
}
|
||||
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract WireGuard message indices
|
||||
func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
|
||||
if len(packet) < 12 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
messageType := packet[0]
|
||||
if messageType == WireGuardMessageTypeHandshakeInitiation {
|
||||
// Handshake initiation: extract sender index at offset 4
|
||||
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
|
||||
return 0, senderIndex, true
|
||||
} else if messageType == WireGuardMessageTypeHandshakeResponse {
|
||||
// Handshake response: extract sender index at offset 4 and receiver index at offset 8
|
||||
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
|
||||
receiverIndex := binary.LittleEndian.Uint32(packet[8:12])
|
||||
return receiverIndex, senderIndex, true
|
||||
} else if messageType == WireGuardMessageTypeTransportData {
|
||||
// Transport data: extract receiver index at offset 4
|
||||
receiverIndex := binary.LittleEndian.Uint32(packet[4:8])
|
||||
return receiverIndex, 0, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// Updated to handle multi-peer WireGuard communication
|
||||
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
||||
if len(packet) == 0 {
|
||||
logger.Error("Received empty packet")
|
||||
return
|
||||
}
|
||||
|
||||
messageType := packet[0]
|
||||
receiverIndex, senderIndex, ok := extractWireGuardIndices(packet)
|
||||
|
||||
if !ok {
|
||||
logger.Error("Failed to extract WireGuard indices")
|
||||
return
|
||||
}
|
||||
|
||||
key := remoteAddr.String()
|
||||
mappingObj, ok := s.proxyMappings.Load(key)
|
||||
if !ok {
|
||||
logger.Error("No proxy mapping found for %s", key)
|
||||
return
|
||||
}
|
||||
|
||||
proxyMapping := mappingObj.(ProxyMapping)
|
||||
// Update the last used timestamp and store it back
|
||||
proxyMapping.LastUsed = time.Now()
|
||||
s.proxyMappings.Store(key, proxyMapping)
|
||||
|
||||
// Handle different WireGuard message types
|
||||
switch messageType {
|
||||
case WireGuardMessageTypeHandshakeInitiation:
|
||||
// Initial handshake: forward to all peers
|
||||
logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)
|
||||
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to get/create connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward handshake initiation: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
case WireGuardMessageTypeHandshakeResponse:
|
||||
// Received handshake response: establish session mapping
|
||||
logger.Debug("Received handshake response with receiver index %d and sender index %d from %s",
|
||||
receiverIndex, senderIndex, remoteAddr)
|
||||
|
||||
// Create a session key for the peer that sent the initial handshake
|
||||
sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex)
|
||||
|
||||
// Store the session information
|
||||
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||
ReceiverIndex: receiverIndex,
|
||||
SenderIndex: senderIndex,
|
||||
DestAddr: remoteAddr,
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
|
||||
// Forward the response to the original sender
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to get/create connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward handshake response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
case WireGuardMessageTypeTransportData:
|
||||
// Data packet: forward only to the established session peer
|
||||
// logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr)
|
||||
|
||||
// Look up the session based on the receiver index
|
||||
var destAddr *net.UDPAddr
|
||||
|
||||
// First check for existing sessions to see if we know where to send this packet
|
||||
s.wgSessions.Range(func(k, v interface{}) bool {
|
||||
session := v.(*WireGuardSession)
|
||||
if session.SenderIndex == receiverIndex {
|
||||
// Found matching session
|
||||
destAddr = session.DestAddr
|
||||
|
||||
// Update last seen time
|
||||
session.LastSeen = time.Now()
|
||||
s.wgSessions.Store(k, session)
|
||||
return false // stop iteration
|
||||
}
|
||||
return true // continue iteration
|
||||
})
|
||||
|
||||
if destAddr != nil {
|
||||
// We found a specific peer to forward to
|
||||
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to get/create connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to forward transport data: %v", err)
|
||||
}
|
||||
} else {
|
||||
// No known session, fall back to forwarding to all peers
|
||||
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to get/create connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to forward transport data: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// Other packet types (like cookie reply)
|
||||
logger.Debug("Forwarding WireGuard packet type %d from %s", messageType, remoteAddr)
|
||||
|
||||
// Forward to all peers
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve destination address: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to get/create connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward WireGuard packet: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
key := destAddr.String() + "-" + remoteAddr.String()
|
||||
|
||||
// Check if we have an existing connection
|
||||
if conn, ok := s.connections.Load(key); ok {
|
||||
destConn := conn.(*DestinationConn)
|
||||
destConn.lastUsed = time.Now()
|
||||
return destConn.conn, nil
|
||||
}
|
||||
|
||||
// Create new connection
|
||||
newConn, err := net.DialUDP("udp", nil, destAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
// Store the new connection
|
||||
s.connections.Store(key, &DestinationConn{
|
||||
conn: newConn,
|
||||
lastUsed: time.Now(),
|
||||
})
|
||||
|
||||
// Start a goroutine to handle responses
|
||||
go s.handleResponses(newConn, destAddr, remoteAddr)
|
||||
|
||||
return newConn, nil
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) {
|
||||
buffer := make([]byte, 1500)
|
||||
for {
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the response to track sessions if it's a WireGuard packet
|
||||
if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
|
||||
receiverIndex, senderIndex, ok := extractWireGuardIndices(buffer[:n])
|
||||
if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse {
|
||||
// Store the session mapping for the handshake response
|
||||
sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex)
|
||||
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||
ReceiverIndex: receiverIndex,
|
||||
SenderIndex: senderIndex,
|
||||
DestAddr: destAddr,
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Forward the response back through the main listener
|
||||
_, err = s.conn.WriteToUDP(buffer[:n], remoteAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward response: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a cleanup method to periodically remove idle connections
|
||||
func (s *UDPProxyServer) cleanupIdleConnections() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
destConn := value.(*DestinationConn)
|
||||
if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
destConn.conn.Close()
|
||||
s.connections.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle sessions
|
||||
func (s *UDPProxyServer) cleanupIdleSessions() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
session := value.(*WireGuardSession)
|
||||
if now.Sub(session.LastSeen) > 15*time.Minute {
|
||||
s.wgSessions.Delete(key)
|
||||
logger.Debug("Removed idle session: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle proxy mappings
|
||||
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
mapping := value.(ProxyMapping)
|
||||
// Remove mappings that haven't been used in 30 minutes
|
||||
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
s.proxyMappings.Delete(key)
|
||||
logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
|
||||
logger.Debug("notifyServer called with endpoint: IP=%s, Port=%d", endpoint.IP, endpoint.Port)
|
||||
|
||||
jsonData, err := json.Marshal(endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to marshal endpoint data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.Post(s.serverURL+"/gerbil/update-hole-punch", "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
logger.Error("Failed to notify server: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
logger.Error("Server returned non-OK status: %d, body: %s",
|
||||
resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the proxy mapping response
|
||||
var mapping ProxyMapping
|
||||
if err := json.NewDecoder(resp.Body).Decode(&mapping); err != nil {
|
||||
logger.Error("Failed to decode proxy mapping: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("Received proxy mapping from server: %v", mapping)
|
||||
|
||||
// Store the mapping with current timestamp
|
||||
key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port)
|
||||
logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port)
|
||||
mapping.LastUsed = time.Now()
|
||||
s.proxyMappings.Store(key, mapping)
|
||||
|
||||
logger.Debug("Stored proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed)
|
||||
}
|
||||
|
||||
// Updated to support multiple destinations
|
||||
func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, destinations []PeerDestination) {
|
||||
key := fmt.Sprintf("%s:%d", sourceIP, sourcePort)
|
||||
mapping := ProxyMapping{
|
||||
Destinations: destinations,
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
s.proxyMappings.Store(key, mapping)
|
||||
}
|
||||
|
||||
// OnPeerAdded clears connections and sessions for a specific WireGuard IP to allow re-establishment
|
||||
func (s *UDPProxyServer) OnPeerAdded(wgIP string) {
|
||||
logger.Info("Clearing connections for added peer with WG IP: %s", wgIP)
|
||||
s.clearConnectionsForWGIP(wgIP)
|
||||
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
|
||||
// s.clearProxyMappingsForWGIP(wgIP)
|
||||
}
|
||||
|
||||
// OnPeerRemoved clears connections and sessions for a specific WireGuard IP
|
||||
func (s *UDPProxyServer) OnPeerRemoved(wgIP string) {
|
||||
logger.Info("Clearing connections for removed peer with WG IP: %s", wgIP)
|
||||
s.clearConnectionsForWGIP(wgIP)
|
||||
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
|
||||
// s.clearProxyMappingsForWGIP(wgIP)
|
||||
}
|
||||
|
||||
// clearConnectionsForWGIP removes all connections associated with a specific WireGuard IP
|
||||
func (s *UDPProxyServer) clearConnectionsForWGIP(wgIP string) {
|
||||
var keysToDelete []string
|
||||
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
keyStr := key.(string)
|
||||
destConn := value.(*DestinationConn)
|
||||
|
||||
// Connection keys are in format "destAddr-remoteAddr"
|
||||
// Check if either destination or remote address contains the WG IP
|
||||
if containsIP(keyStr, wgIP) {
|
||||
keysToDelete = append(keysToDelete, keyStr)
|
||||
destConn.conn.Close()
|
||||
logger.Debug("Closing connection for WG IP %s: %s", wgIP, keyStr)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Delete the connections
|
||||
for _, key := range keysToDelete {
|
||||
s.connections.Delete(key)
|
||||
}
|
||||
|
||||
logger.Info("Cleared %d connections for WG IP: %s", len(keysToDelete), wgIP)
|
||||
}
|
||||
|
||||
// clearSessionsForWGIP removes all WireGuard sessions associated with a specific WireGuard IP
|
||||
func (s *UDPProxyServer) clearSessionsForIP(ip string) {
|
||||
var keysToDelete []string
|
||||
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
keyStr := key.(string)
|
||||
session := value.(*WireGuardSession)
|
||||
|
||||
// Check if the session's destination address contains the WG IP
|
||||
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
|
||||
keysToDelete = append(keysToDelete, keyStr)
|
||||
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Delete the sessions
|
||||
for _, key := range keysToDelete {
|
||||
s.wgSessions.Delete(key)
|
||||
}
|
||||
|
||||
logger.Info("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip)
|
||||
}
|
||||
|
||||
// // clearProxyMappingsForWGIP removes all proxy mappings that have destinations pointing to a specific WireGuard IP
|
||||
// func (s *UDPProxyServer) clearProxyMappingsForWGIP(wgIP string) {
|
||||
// var keysToDelete []string
|
||||
|
||||
// s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
// keyStr := key.(string)
|
||||
// mapping := value.(ProxyMapping)
|
||||
|
||||
// // Check if any destination in the mapping contains the WG IP
|
||||
// for _, dest := range mapping.Destinations {
|
||||
// if dest.DestinationIP == wgIP {
|
||||
// keysToDelete = append(keysToDelete, keyStr)
|
||||
// logger.Debug("Marking proxy mapping for deletion for WG IP %s: %s -> %s:%d", wgIP, keyStr, dest.DestinationIP, dest.DestinationPort)
|
||||
// break // Found one destination, no need to check others in this mapping
|
||||
// }
|
||||
// }
|
||||
// return true
|
||||
// })
|
||||
|
||||
// // Delete the proxy mappings
|
||||
// for _, key := range keysToDelete {
|
||||
// s.proxyMappings.Delete(key)
|
||||
// logger.Debug("Deleted proxy mapping: %s", key)
|
||||
// }
|
||||
|
||||
// logger.Info("Cleared %d proxy mappings for WG IP: %s", len(keysToDelete), wgIP)
|
||||
// }
|
||||
|
||||
// containsIP checks if a connection key string contains the specified IP address
|
||||
func containsIP(connectionKey, ip string) bool {
|
||||
// Connection keys are in format "destIP:destPort-remoteIP:remotePort"
|
||||
// Check if the IP appears at the beginning (destination) or after the dash (remote)
|
||||
ipWithColon := ip + ":"
|
||||
|
||||
// Check if connection key starts with the IP (destination address)
|
||||
if len(connectionKey) >= len(ipWithColon) && connectionKey[:len(ipWithColon)] == ipWithColon {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if connection key contains the IP after a dash (remote address)
|
||||
dashIndex := -1
|
||||
for i := 0; i < len(connectionKey); i++ {
|
||||
if connectionKey[i] == '-' {
|
||||
dashIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if dashIndex != -1 && dashIndex+1 < len(connectionKey) {
|
||||
remainingPart := connectionKey[dashIndex+1:]
|
||||
if len(remainingPart) >= len(ip)+1 && remainingPart[:len(ip)+1] == ipWithColon {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateDestinationInMappings updates all proxy mappings that contain the old destination with the new destination
|
||||
// Returns the number of mappings that were updated
|
||||
func (s *UDPProxyServer) UpdateDestinationInMappings(oldDest, newDest PeerDestination) int {
|
||||
updatedCount := 0
|
||||
|
||||
s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
keyStr := key.(string)
|
||||
mapping := value.(ProxyMapping)
|
||||
updated := false
|
||||
|
||||
// Check each destination in the mapping
|
||||
for i, dest := range mapping.Destinations {
|
||||
if dest.DestinationIP == oldDest.DestinationIP && dest.DestinationPort == oldDest.DestinationPort {
|
||||
// Update this destination
|
||||
mapping.Destinations[i] = newDest
|
||||
updated = true
|
||||
logger.Debug("Updated destination in mapping %s: %s:%d -> %s:%d",
|
||||
keyStr, oldDest.DestinationIP, oldDest.DestinationPort,
|
||||
newDest.DestinationIP, newDest.DestinationPort)
|
||||
}
|
||||
}
|
||||
|
||||
// If we updated any destinations, store the updated mapping back
|
||||
if updated {
|
||||
mapping.LastUsed = time.Now()
|
||||
s.proxyMappings.Store(keyStr, mapping)
|
||||
updatedCount++
|
||||
}
|
||||
|
||||
return true // continue iteration
|
||||
})
|
||||
|
||||
if updatedCount > 0 {
|
||||
logger.Info("Updated %d proxy mappings from %s:%d to %s:%d",
|
||||
updatedCount, oldDest.DestinationIP, oldDest.DestinationPort,
|
||||
newDest.DestinationIP, newDest.DestinationPort)
|
||||
}
|
||||
|
||||
return updatedCount
|
||||
}
|
||||
Reference in New Issue
Block a user