mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Compare commits
113 Commits
package
...
1.2.0-rc.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
29aa68ecf7 | ||
|
|
50a97b19d1 | ||
|
|
229ce7504f | ||
|
|
b4f3619aff | ||
|
|
e77a4fbd66 | ||
|
|
f8f368a981 | ||
|
|
153b986100 | ||
|
|
1c47c0981c | ||
|
|
defd85e118 | ||
|
|
ec1085f5f7 | ||
|
|
d13cc179e8 | ||
|
|
a39e6d4f2b | ||
|
|
4875835024 | ||
|
|
f5a74c36f8 | ||
|
|
c71828f5a1 | ||
|
|
dc83af6c2e | ||
|
|
35544e1081 | ||
|
|
2ddb4a5645 | ||
|
|
c25fb02f1e | ||
|
|
28583c9507 | ||
|
|
ba41602e4b | ||
|
|
3e24a77625 | ||
|
|
4b8b281d5b | ||
|
|
a07a714d93 | ||
|
|
58ce93f6c3 | ||
|
|
293e507000 | ||
|
|
c948208493 | ||
|
|
2106734aa4 | ||
|
|
51162d6be6 | ||
|
|
45ef6e5279 | ||
|
|
3b2ffe006a | ||
|
|
a497f0873f | ||
|
|
6e4ec246ef | ||
|
|
7270b840cf | ||
|
|
fb007e09a9 | ||
|
|
9ce6450351 | ||
|
|
672fff0ad9 | ||
|
|
22474d92ef | ||
|
|
0e4a657700 | ||
|
|
e24ee0e68b | ||
|
|
cea9ab0932 | ||
|
|
229dc6afce | ||
|
|
e2fe7d53f8 | ||
|
|
e8f1fb507c | ||
|
|
7e410cde28 | ||
|
|
afe0d338be | ||
|
|
a18b367e60 | ||
|
|
91e44e112e | ||
|
|
a38d1ef8a8 | ||
|
|
a32e91de24 | ||
|
|
92b551fa4b | ||
|
|
53c1fa117a | ||
|
|
50525aaf8d | ||
|
|
d8ced86d19 | ||
|
|
2718d15825 | ||
|
|
fff234bdd5 | ||
|
|
0802673048 | ||
|
|
d54b7e3f14 | ||
|
|
650084132b | ||
|
|
534631fb27 | ||
|
|
9d34c818d7 | ||
|
|
2436a5be15 | ||
|
|
430f2bf7fa | ||
|
|
16362f285d | ||
|
|
34c7f89804 | ||
|
|
ead8fab70a | ||
|
|
50008f3c12 | ||
|
|
24b5122cc1 | ||
|
|
9099b246dc | ||
|
|
30ff3c06eb | ||
|
|
d02ca20c06 | ||
|
|
7afe842a95 | ||
|
|
47d628af73 | ||
|
|
0f1e51f391 | ||
|
|
5d6024ac59 | ||
|
|
6c7ee31330 | ||
|
|
b38357875e | ||
|
|
c230c7be28 | ||
|
|
d7cd746cc9 | ||
|
|
7941479994 | ||
|
|
e3623fd756 | ||
|
|
68c2744ebe | ||
|
|
a9d8d0e5c6 | ||
|
|
7f94fbc1e4 | ||
|
|
542d7e5d61 | ||
|
|
f93f73f541 | ||
|
|
930bf7e0f2 | ||
|
|
d7345c7dbd | ||
|
|
3e2cb70d58 | ||
|
|
d4c5292e8f | ||
|
|
45047343c4 | ||
|
|
c09fb312e8 | ||
|
|
8dfb4b2b20 | ||
|
|
6b17cb08c0 | ||
|
|
aa866493aa | ||
|
|
7b28137cf6 | ||
|
|
a8383f5612 | ||
|
|
2fc385155e | ||
|
|
b7271b77b6 | ||
|
|
1ef6b7ada6 | ||
|
|
ea454d0528 | ||
|
|
a6670ccab3 | ||
|
|
f226e8f7f3 | ||
|
|
75890ca5a6 | ||
|
|
e6cf631dbc | ||
|
|
b87f90c211 | ||
|
|
3e0cefa3dc | ||
|
|
2fe3359ae8 | ||
|
|
0aa8f07be3 | ||
|
|
36d47a7331 | ||
|
|
10fa5acb0b | ||
|
|
e3a679609f | ||
|
|
b7a04dc511 |
@@ -1,9 +1,9 @@
|
||||
.gitignore
|
||||
.dockerignore
|
||||
olm
|
||||
*.json
|
||||
README.md
|
||||
Makefile
|
||||
public/
|
||||
LICENSE
|
||||
CONTRIBUTING.md
|
||||
CONTRIBUTING.md
|
||||
bin/
|
||||
4
.github/workflows/cicd.yml
vendored
4
.github/workflows/cicd.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
@@ -54,7 +54,7 @@ jobs:
|
||||
make go-build-release
|
||||
|
||||
- name: Upload artifacts from /bin
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: binaries
|
||||
path: bin/
|
||||
|
||||
8
.github/workflows/test.yml
vendored
8
.github/workflows/test.yml
vendored
@@ -11,7 +11,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Clone fosrl/newt
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
repository: fosrl/newt
|
||||
path: ../newt
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
|
||||
@@ -16,7 +16,7 @@ COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /olm
|
||||
|
||||
# Start a new stage from scratch
|
||||
FROM alpine:3.22 AS runner
|
||||
FROM alpine:3.23 AS runner
|
||||
|
||||
RUN apk --no-cache add ca-certificates
|
||||
|
||||
|
||||
280
README.md
280
README.md
@@ -23,15 +23,21 @@ When Olm receives WireGuard control messages, it will use the information encode
|
||||
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
|
||||
- `id`: Olm ID generated by Pangolin to identify the olm.
|
||||
- `secret`: A unique secret (not shared and kept private) used to authenticate the olm ID with the websocket in order to receive commands.
|
||||
- `org` (optional): Organization ID to connect to.
|
||||
- `user-token` (optional): User authentication token.
|
||||
- `mtu` (optional): MTU for the internal WG interface. Default: 1280
|
||||
- `dns` (optional): DNS server to use to resolve the endpoint. Default: 8.8.8.8
|
||||
- `upstream-dns` (optional): Upstream DNS server(s), comma-separated. Default: 8.8.8.8:53
|
||||
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO
|
||||
- `ping-interval` (optional): Interval for pinging the server. Default: 3s
|
||||
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
|
||||
- `interface` (optional): Name of the WireGuard interface. Default: olm
|
||||
- `enable-http` (optional): Enable HTTP server for receiving connection requests. Default: false
|
||||
- `enable-api` (optional): Enable API server for receiving connection requests. Default: false
|
||||
- `http-addr` (optional): HTTP server address (e.g., ':9452'). Default: :9452
|
||||
- `holepunch` (optional): Enable hole punching. Default: false
|
||||
- `socket-path` (optional): Unix socket path (or named pipe on Windows). Default: /var/run/olm.sock (Linux/macOS) or olm (Windows)
|
||||
- `disable-holepunch` (optional): Disable hole punching. Default: false
|
||||
- `override-dns` (optional): Override system DNS settings. Default: false
|
||||
- `disable-relay` (optional): Disable relay connections. Default: false
|
||||
|
||||
## Environment Variables
|
||||
|
||||
@@ -40,14 +46,21 @@ All CLI arguments can also be set via environment variables:
|
||||
- `PANGOLIN_ENDPOINT`: Equivalent to `--endpoint`
|
||||
- `OLM_ID`: Equivalent to `--id`
|
||||
- `OLM_SECRET`: Equivalent to `--secret`
|
||||
- `ORG`: Equivalent to `--org`
|
||||
- `USER_TOKEN`: Equivalent to `--user-token`
|
||||
- `MTU`: Equivalent to `--mtu`
|
||||
- `DNS`: Equivalent to `--dns`
|
||||
- `UPSTREAM_DNS`: Equivalent to `--upstream-dns`
|
||||
- `LOG_LEVEL`: Equivalent to `--log-level`
|
||||
- `INTERFACE`: Equivalent to `--interface`
|
||||
- `ENABLE_API`: Set to "true" to enable API server (equivalent to `--enable-api`)
|
||||
- `HTTP_ADDR`: Equivalent to `--http-addr`
|
||||
- `SOCKET_PATH`: Equivalent to `--socket-path`
|
||||
- `PING_INTERVAL`: Equivalent to `--ping-interval`
|
||||
- `PING_TIMEOUT`: Equivalent to `--ping-timeout`
|
||||
- `HOLEPUNCH`: Set to "true" to enable hole punching (equivalent to `--holepunch`)
|
||||
- `DISABLE_HOLEPUNCH`: Set to "true" to disable hole punching (equivalent to `--disable-holepunch`)
|
||||
- `OVERRIDE_DNS`: Set to "true" to override system DNS settings (equivalent to `--override-dns`)
|
||||
- `DISABLE_RELAY`: Set to "true" to disable relay connections (equivalent to `--disable-relay`)
|
||||
- `CONFIG_FILE`: Set to the location of a JSON file to load secret values
|
||||
|
||||
Examples:
|
||||
@@ -108,11 +121,26 @@ $ cat ~/.config/olm-client/config.json
|
||||
"id": "spmzu8rbpzj1qq6",
|
||||
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
|
||||
"endpoint": "https://app.pangolin.net",
|
||||
"org": "",
|
||||
"userToken": "",
|
||||
"mtu": 1280,
|
||||
"dns": "8.8.8.8",
|
||||
"upstreamDNS": ["8.8.8.8:53"],
|
||||
"interface": "olm",
|
||||
"logLevel": "INFO",
|
||||
"enableApi": false,
|
||||
"httpAddr": "",
|
||||
"socketPath": "/var/run/olm.sock",
|
||||
"pingInterval": "3s",
|
||||
"pingTimeout": "5s",
|
||||
"disableHolepunch": false,
|
||||
"overrideDNS": false,
|
||||
"disableRelay": false,
|
||||
"tlsClientCert": ""
|
||||
}
|
||||
```
|
||||
|
||||
This file is also written to when newt first starts up. So you do not need to run every time with --id and secret if you have run it once!
|
||||
This file is also written to when olm first starts up. So you do not need to run every time with --id and secret if you have run it once!
|
||||
|
||||
Default locations:
|
||||
|
||||
@@ -122,7 +150,7 @@ Default locations:
|
||||
|
||||
## Hole Punching
|
||||
|
||||
In the default mode, olm "relays" traffic through Gerbil in the cloud to get down to newt. This is a little more reliable. Support for NAT hole punching is also EXPERIMENTAL right now using the `--holepunch` flag. This will attempt to orchestrate a NAT hole punch between the two sites so that traffic flows directly. This will save data costs and speed. If it fails it should fall back to relaying.
|
||||
In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to newt. If you want to disable hole punching, use the `--disable-holepunch` flag. Hole punching attempts to orchestrate a NAT hole punch between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil.
|
||||
|
||||
Right now, basic NAT hole punching is supported. We plan to add:
|
||||
|
||||
@@ -182,26 +210,75 @@ You can view the Windows Event Log using Event Viewer or PowerShell:
|
||||
Get-EventLog -LogName Application -Source "OlmWireguardService" -Newest 10
|
||||
```
|
||||
|
||||
## HTTP Endpoints
|
||||
## HTTP API
|
||||
|
||||
Olm can be controlled with an embedded http server when using `--enable-http`. This allows you to start it as a daemon and trigger it with the following endpoints:
|
||||
Olm can be controlled with an embedded HTTP server when using `--enable-http`. This allows you to start it as a daemon and trigger it with the following endpoints. The API can listen on either a TCP address or a Unix socket/Windows named pipe.
|
||||
|
||||
### Socket vs TCP
|
||||
|
||||
By default, when `--enable-http` is used, Olm listens on a TCP address (configured via `--http-addr`, default `:9452`). Alternatively, Olm can listen on a Unix socket (Linux/macOS) or Windows named pipe for local-only communication with better security.
|
||||
|
||||
**Unix Socket (Linux/macOS):**
|
||||
- Socket path example: `/var/run/olm/olm.sock`
|
||||
- The directory is created automatically if it doesn't exist
|
||||
- Socket permissions are set to `0666` to allow access
|
||||
- Existing socket files are automatically removed on startup
|
||||
- Socket file is cleaned up when Olm stops
|
||||
|
||||
**Windows Named Pipe:**
|
||||
- Pipe path example: `\\.\pipe\olm`
|
||||
- If the path doesn't start with `\`, it's automatically prefixed with `\\.\pipe\`
|
||||
- Security descriptor grants full access to Everyone and the current owner
|
||||
- Named pipes are automatically cleaned up by Windows
|
||||
|
||||
**Connecting to the Socket:**
|
||||
|
||||
```bash
|
||||
# Linux/macOS - using curl with Unix socket
|
||||
curl --unix-socket /var/run/olm/olm.sock http://localhost/status
|
||||
|
||||
---
|
||||
|
||||
### POST /connect
|
||||
Initiates a new connection request.
|
||||
Initiates a new connection request to a Pangolin server.
|
||||
|
||||
**Request Body:**
|
||||
```json
|
||||
{
|
||||
"id": "string",
|
||||
"secret": "string",
|
||||
"endpoint": "string"
|
||||
"secret": "string",
|
||||
"endpoint": "string",
|
||||
"userToken": "string",
|
||||
"mtu": 1280,
|
||||
"dns": "8.8.8.8",
|
||||
"dnsProxyIP": "string",
|
||||
"upstreamDNS": ["8.8.8.8:53", "1.1.1.1:53"],
|
||||
"interfaceName": "olm",
|
||||
"holepunch": false,
|
||||
"tlsClientCert": "string",
|
||||
"pingInterval": "3s",
|
||||
"pingTimeout": "5s",
|
||||
"orgId": "string"
|
||||
}
|
||||
```
|
||||
|
||||
**Required Fields:**
|
||||
- `id`: Connection identifier
|
||||
- `secret`: Authentication secret
|
||||
- `endpoint`: Target endpoint URL
|
||||
- `id`: Olm ID generated by Pangolin
|
||||
- `secret`: Authentication secret for the Olm ID
|
||||
- `endpoint`: Target Pangolin endpoint URL
|
||||
|
||||
**Optional Fields:**
|
||||
- `userToken`: User authentication token
|
||||
- `mtu`: MTU for the internal WireGuard interface (default: 1280)
|
||||
- `dns`: DNS server to use for resolving the endpoint
|
||||
- `dnsProxyIP`: DNS proxy IP address
|
||||
- `upstreamDNS`: Array of upstream DNS servers
|
||||
- `interfaceName`: Name of the WireGuard interface (default: olm)
|
||||
- `holepunch`: Enable NAT hole punching (default: false)
|
||||
- `tlsClientCert`: TLS client certificate
|
||||
- `pingInterval`: Interval for pinging the server (default: 3s)
|
||||
- `pingTimeout`: Timeout for each ping (default: 5s)
|
||||
- `orgId`: Organization ID to connect to
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `202 Accepted`
|
||||
@@ -216,9 +293,12 @@ Initiates a new connection request.
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-POST requests
|
||||
- `400 Bad Request` - Invalid JSON or missing required fields
|
||||
- `409 Conflict` - Already connected to a server (disconnect first)
|
||||
|
||||
---
|
||||
|
||||
### GET /status
|
||||
Returns the current connection status and peer information.
|
||||
Returns the current connection status, registration state, and peer information.
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
@@ -226,52 +306,162 @@ Returns the current connection status and peer information.
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "connected",
|
||||
"connected": true,
|
||||
"tunnelIP": "100.89.128.3/20",
|
||||
"version": "version_replaceme",
|
||||
"registered": true,
|
||||
"terminated": false,
|
||||
"version": "1.0.0",
|
||||
"agent": "olm",
|
||||
"orgId": "org_123",
|
||||
"peers": {
|
||||
"10": {
|
||||
"siteId": 10,
|
||||
"name": "Site A",
|
||||
"connected": true,
|
||||
"rtt": 145338339,
|
||||
"lastSeen": "2025-08-13T14:39:17.208334428-07:00",
|
||||
"endpoint": "p.fosrl.io:21820",
|
||||
"isRelay": true
|
||||
"isRelay": true,
|
||||
"peerAddress": "100.89.128.5",
|
||||
"holepunchConnected": false
|
||||
},
|
||||
"8": {
|
||||
"siteId": 8,
|
||||
"name": "Site B",
|
||||
"connected": false,
|
||||
"rtt": 0,
|
||||
"lastSeen": "2025-08-13T14:39:19.663823645-07:00",
|
||||
"endpoint": "p.fosrl.io:21820",
|
||||
"isRelay": true
|
||||
"isRelay": true,
|
||||
"peerAddress": "100.89.128.10",
|
||||
"holepunchConnected": false
|
||||
}
|
||||
},
|
||||
"networkSettings": {
|
||||
"tunnelIP": "100.89.128.3/20"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Fields:**
|
||||
- `status`: Overall connection status ("connected" or "disconnected")
|
||||
- `connected`: Boolean connection state
|
||||
- `tunnelIP`: IP address and subnet of the tunnel (when connected)
|
||||
- `connected`: Boolean indicating if connected to Pangolin
|
||||
- `registered`: Boolean indicating if registered with the server
|
||||
- `terminated`: Boolean indicating if the connection was terminated
|
||||
- `version`: Olm version string
|
||||
- `agent`: Agent identifier
|
||||
- `orgId`: Current organization ID
|
||||
- `peers`: Map of peer statuses by site ID
|
||||
- `siteId`: Peer site identifier
|
||||
- `name`: Site name
|
||||
- `connected`: Boolean peer connection state
|
||||
- `rtt`: Peer round-trip time (integer, nanoseconds)
|
||||
- `lastSeen`: Last time peer was seen (RFC3339 timestamp)
|
||||
- `endpoint`: Peer endpoint address
|
||||
- `isRelay`: Whether the peer is relayed (true) or direct (false)
|
||||
- `peerAddress`: Peer's IP address in the tunnel
|
||||
- `holepunchConnected`: Whether holepunch connection is established
|
||||
- `networkSettings`: Current network configuration including tunnel IP
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-GET requests
|
||||
|
||||
---
|
||||
|
||||
### POST /disconnect
|
||||
Disconnects from the current Pangolin server and tears down the WireGuard tunnel.
|
||||
|
||||
**Request Body:** None required
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "disconnect initiated"
|
||||
}
|
||||
```
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-POST requests
|
||||
- `409 Conflict` - Not currently connected to a server
|
||||
|
||||
---
|
||||
|
||||
### POST /switch-org
|
||||
Switches to a different organization while maintaining the connection.
|
||||
|
||||
**Request Body:**
|
||||
```json
|
||||
{
|
||||
"orgId": "string"
|
||||
}
|
||||
```
|
||||
|
||||
**Required Fields:**
|
||||
- `orgId`: The organization ID to switch to
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "org switch request accepted"
|
||||
}
|
||||
```
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-POST requests
|
||||
- `400 Bad Request` - Invalid JSON or missing orgId field
|
||||
- `500 Internal Server Error` - Org switch failed
|
||||
|
||||
---
|
||||
|
||||
### POST /exit
|
||||
Initiates a graceful shutdown of the Olm process.
|
||||
|
||||
**Request Body:** None required
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "shutdown initiated"
|
||||
}
|
||||
```
|
||||
|
||||
**Note:** The response is sent before shutdown begins. There is a 100ms delay before the actual shutdown to ensure the response is delivered.
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-POST requests
|
||||
|
||||
---
|
||||
|
||||
### GET /health
|
||||
Simple health check endpoint to verify the API server is running.
|
||||
|
||||
**Response:**
|
||||
- **Status Code:** `200 OK`
|
||||
- **Content-Type:** `application/json`
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "ok"
|
||||
}
|
||||
```
|
||||
|
||||
**Error Responses:**
|
||||
- `405 Method Not Allowed` - Non-GET requests
|
||||
|
||||
---
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Connect to a peer
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/connect \
|
||||
curl -X POST http://localhost:9452/connect \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"id": "31frd0uzbjvp721",
|
||||
@@ -280,9 +470,51 @@ curl -X POST http://localhost:8080/connect \
|
||||
}'
|
||||
```
|
||||
|
||||
### Connect with additional options
|
||||
```bash
|
||||
curl -X POST http://localhost:9452/connect \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"id": "31frd0uzbjvp721",
|
||||
"secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
|
||||
"endpoint": "https://example.com",
|
||||
"mtu": 1400,
|
||||
"holepunch": true,
|
||||
"pingInterval": "5s"
|
||||
}'
|
||||
```
|
||||
|
||||
### Check connection status
|
||||
```bash
|
||||
curl http://localhost:8080/status
|
||||
curl http://localhost:9452/status
|
||||
```
|
||||
|
||||
### Switch organization
|
||||
```bash
|
||||
curl -X POST http://localhost:9452/switch-org \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"orgId": "org_456"}'
|
||||
```
|
||||
|
||||
### Disconnect from server
|
||||
```bash
|
||||
curl -X POST http://localhost:9452/disconnect
|
||||
```
|
||||
|
||||
### Health check
|
||||
```bash
|
||||
curl http://localhost:9452/health
|
||||
```
|
||||
|
||||
### Shutdown Olm
|
||||
```bash
|
||||
curl -X POST http://localhost:9452/exit
|
||||
```
|
||||
|
||||
### Using Unix socket (Linux/macOS)
|
||||
```bash
|
||||
curl --unix-socket /var/run/olm/olm.sock http://localhost/status
|
||||
curl --unix-socket /var/run/olm/olm.sock -X POST http://localhost/disconnect
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
282
api/api.go
282
api/api.go
@@ -9,14 +9,25 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
)
|
||||
|
||||
// ConnectionRequest defines the structure for an incoming connection request
|
||||
type ConnectionRequest struct {
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
UserToken string `json:"userToken,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
UserToken string `json:"userToken,omitempty"`
|
||||
MTU int `json:"mtu,omitempty"`
|
||||
DNS string `json:"dns,omitempty"`
|
||||
DNSProxyIP string `json:"dnsProxyIP,omitempty"`
|
||||
UpstreamDNS []string `json:"upstreamDNS,omitempty"`
|
||||
InterfaceName string `json:"interfaceName,omitempty"`
|
||||
Holepunch bool `json:"holepunch,omitempty"`
|
||||
TlsClientCert string `json:"tlsClientCert,omitempty"`
|
||||
PingInterval string `json:"pingInterval,omitempty"`
|
||||
PingTimeout string `json:"pingTimeout,omitempty"`
|
||||
OrgID string `json:"orgId,omitempty"`
|
||||
}
|
||||
|
||||
// SwitchOrgRequest defines the structure for switching organizations
|
||||
@@ -26,54 +37,55 @@ type SwitchOrgRequest struct {
|
||||
|
||||
// PeerStatus represents the status of a peer connection
|
||||
type PeerStatus struct {
|
||||
SiteID int `json:"siteId"`
|
||||
Connected bool `json:"connected"`
|
||||
RTT time.Duration `json:"rtt"`
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
IsRelay bool `json:"isRelay"`
|
||||
PeerIP string `json:"peerAddress,omitempty"`
|
||||
SiteID int `json:"siteId"`
|
||||
Name string `json:"name"`
|
||||
Connected bool `json:"connected"`
|
||||
RTT time.Duration `json:"rtt"`
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
IsRelay bool `json:"isRelay"`
|
||||
PeerIP string `json:"peerAddress,omitempty"`
|
||||
HolepunchConnected bool `json:"holepunchConnected"`
|
||||
}
|
||||
|
||||
// StatusResponse is returned by the status endpoint
|
||||
type StatusResponse struct {
|
||||
Connected bool `json:"connected"`
|
||||
Registered bool `json:"registered"`
|
||||
TunnelIP string `json:"tunnelIP,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
OrgID string `json:"orgId,omitempty"`
|
||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||
Connected bool `json:"connected"`
|
||||
Registered bool `json:"registered"`
|
||||
Terminated bool `json:"terminated"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Agent string `json:"agent,omitempty"`
|
||||
OrgID string `json:"orgId,omitempty"`
|
||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
|
||||
}
|
||||
|
||||
// API represents the HTTP server and its state
|
||||
type API struct {
|
||||
addr string
|
||||
socketPath string
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
connectionChan chan ConnectionRequest
|
||||
switchOrgChan chan SwitchOrgRequest
|
||||
shutdownChan chan struct{}
|
||||
disconnectChan chan struct{}
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
connectedAt time.Time
|
||||
isConnected bool
|
||||
isRegistered bool
|
||||
tunnelIP string
|
||||
version string
|
||||
orgID string
|
||||
addr string
|
||||
socketPath string
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
onConnect func(ConnectionRequest) error
|
||||
onSwitchOrg func(SwitchOrgRequest) error
|
||||
onDisconnect func() error
|
||||
onExit func() error
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
connectedAt time.Time
|
||||
isConnected bool
|
||||
isRegistered bool
|
||||
isTerminated bool
|
||||
version string
|
||||
agent string
|
||||
orgID string
|
||||
}
|
||||
|
||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||
func NewAPI(addr string) *API {
|
||||
s := &API{
|
||||
addr: addr,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
disconnectChan: make(chan struct{}, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
addr: addr,
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -82,17 +94,26 @@ func NewAPI(addr string) *API {
|
||||
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
|
||||
func NewAPISocket(socketPath string) *API {
|
||||
s := &API{
|
||||
socketPath: socketPath,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
disconnectChan: make(chan struct{}, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
socketPath: socketPath,
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetHandlers sets the callback functions for handling API requests
|
||||
func (s *API) SetHandlers(
|
||||
onConnect func(ConnectionRequest) error,
|
||||
onSwitchOrg func(SwitchOrgRequest) error,
|
||||
onDisconnect func() error,
|
||||
onExit func() error,
|
||||
) {
|
||||
s.onConnect = onConnect
|
||||
s.onSwitchOrg = onSwitchOrg
|
||||
s.onDisconnect = onDisconnect
|
||||
s.onExit = onExit
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
func (s *API) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
@@ -101,6 +122,7 @@ func (s *API) Start() error {
|
||||
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||
mux.HandleFunc("/disconnect", s.handleDisconnect)
|
||||
mux.HandleFunc("/exit", s.handleExit)
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
s.server = &http.Server{
|
||||
Handler: mux,
|
||||
@@ -149,24 +171,24 @@ func (s *API) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConnectionChannel returns the channel for receiving connection requests
|
||||
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
|
||||
return s.connectionChan
|
||||
}
|
||||
func (s *API) AddPeerStatus(siteID int, siteName string, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
||||
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
||||
return s.switchOrgChan
|
||||
}
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
// GetShutdownChannel returns the channel for receiving shutdown requests
|
||||
func (s *API) GetShutdownChannel() <-chan struct{} {
|
||||
return s.shutdownChan
|
||||
}
|
||||
|
||||
// GetDisconnectChannel returns the channel for receiving disconnect requests
|
||||
func (s *API) GetDisconnectChannel() <-chan struct{} {
|
||||
return s.disconnectChan
|
||||
status.Name = siteName
|
||||
status.Connected = connected
|
||||
status.RTT = rtt
|
||||
status.LastSeen = time.Now()
|
||||
status.Endpoint = endpoint
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||
@@ -189,6 +211,12 @@ func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, en
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
func (s *API) RemovePeerStatus(siteID int) { // remove the peer from the status map
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
delete(s.peerStatuses, siteID)
|
||||
}
|
||||
|
||||
// SetConnectionStatus sets the overall connection status
|
||||
func (s *API) SetConnectionStatus(isConnected bool) {
|
||||
s.statusMu.Lock()
|
||||
@@ -210,11 +238,17 @@ func (s *API) SetRegistered(registered bool) {
|
||||
s.isRegistered = registered
|
||||
}
|
||||
|
||||
// SetTunnelIP sets the tunnel IP address
|
||||
func (s *API) SetTunnelIP(tunnelIP string) {
|
||||
func (s *API) SetTerminated(terminated bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.tunnelIP = tunnelIP
|
||||
s.isTerminated = terminated
|
||||
}
|
||||
|
||||
// ClearPeerStatuses clears all peer statuses
|
||||
func (s *API) ClearPeerStatuses() {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.peerStatuses = make(map[int]*PeerStatus)
|
||||
}
|
||||
|
||||
// SetVersion sets the olm version
|
||||
@@ -224,6 +258,13 @@ func (s *API) SetVersion(version string) {
|
||||
s.version = version
|
||||
}
|
||||
|
||||
// SetAgent sets the olm agent
|
||||
func (s *API) SetAgent(agent string) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.agent = agent
|
||||
}
|
||||
|
||||
// SetOrgID sets the organization ID
|
||||
func (s *API) SetOrgID(orgID string) {
|
||||
s.statusMu.Lock()
|
||||
@@ -248,6 +289,22 @@ func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
// UpdatePeerHolepunchStatus updates the holepunch connection status of a peer
|
||||
func (s *API) UpdatePeerHolepunchStatus(siteID int, holepunchConnected bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.HolepunchConnected = holepunchConnected
|
||||
}
|
||||
|
||||
// handleConnect handles the /connect endpoint
|
||||
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -277,8 +334,13 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send the request to the main goroutine
|
||||
s.connectionChan <- req
|
||||
// Call the connect handler if set
|
||||
if s.onConnect != nil {
|
||||
if err := s.onConnect(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -299,18 +361,34 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
defer s.statusMu.RUnlock()
|
||||
|
||||
resp := StatusResponse{
|
||||
Connected: s.isConnected,
|
||||
Registered: s.isRegistered,
|
||||
TunnelIP: s.tunnelIP,
|
||||
Version: s.version,
|
||||
OrgID: s.orgID,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
Connected: s.isConnected,
|
||||
Registered: s.isRegistered,
|
||||
Terminated: s.isTerminated,
|
||||
Version: s.version,
|
||||
Agent: s.agent,
|
||||
OrgID: s.orgID,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
NetworkSettings: network.GetSettings(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// handleHealth handles the /health endpoint
|
||||
func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "ok",
|
||||
})
|
||||
}
|
||||
|
||||
// handleExit handles the /exit endpoint
|
||||
func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -320,20 +398,23 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logger.Info("Received exit request via API")
|
||||
|
||||
// Send shutdown signal
|
||||
select {
|
||||
case s.shutdownChan <- struct{}{}:
|
||||
// Signal sent successfully
|
||||
default:
|
||||
// Channel already has a signal, don't block
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
// Return a success response first
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "shutdown initiated",
|
||||
})
|
||||
|
||||
// Call the exit handler after responding, in a goroutine with a small delay
|
||||
// to ensure the response is fully sent before shutdown begins
|
||||
if s.onExit != nil {
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if err := s.onExit(); err != nil {
|
||||
logger.Error("Exit handler failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// handleSwitchOrg handles the /switch-org endpoint
|
||||
@@ -358,14 +439,12 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Send the request to the main goroutine
|
||||
select {
|
||||
case s.switchOrgChan <- req:
|
||||
// Signal sent successfully
|
||||
default:
|
||||
// Channel already has a pending request
|
||||
http.Error(w, "Org switch already in progress", http.StatusConflict)
|
||||
return
|
||||
// Call the switch org handler if set
|
||||
if s.onSwitchOrg != nil {
|
||||
if err := s.onSwitchOrg(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
@@ -394,12 +473,12 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logger.Info("Received disconnect request via API")
|
||||
|
||||
// Send disconnect signal
|
||||
select {
|
||||
case s.disconnectChan <- struct{}{}:
|
||||
// Signal sent successfully
|
||||
default:
|
||||
// Channel already has a signal, don't block
|
||||
// Call the disconnect handler if set
|
||||
if s.onDisconnect != nil {
|
||||
if err := s.onDisconnect(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
@@ -409,3 +488,16 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
||||
"status": "disconnect initiated",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *API) GetStatus() StatusResponse {
|
||||
return StatusResponse{
|
||||
Connected: s.isConnected,
|
||||
Registered: s.isRegistered,
|
||||
Terminated: s.isTerminated,
|
||||
Version: s.version,
|
||||
Agent: s.agent,
|
||||
OrgID: s.orgID,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
NetworkSettings: network.GetSettings(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,378 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// Endpoint represents a network endpoint for the SharedBind
|
||||
type Endpoint struct {
|
||||
AddrPort netip.AddrPort
|
||||
}
|
||||
|
||||
// ClearSrc implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) ClearSrc() {}
|
||||
|
||||
// DstIP implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) DstIP() netip.Addr {
|
||||
return e.AddrPort.Addr()
|
||||
}
|
||||
|
||||
// SrcIP implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
// DstToBytes implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) DstToBytes() []byte {
|
||||
b, _ := e.AddrPort.MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
// DstToString implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) DstToString() string {
|
||||
return e.AddrPort.String()
|
||||
}
|
||||
|
||||
// SrcToString implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
|
||||
// and hole punch senders. It wraps a single UDP connection and implements
|
||||
// reference counting to prevent premature closure.
|
||||
type SharedBind struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// The underlying UDP connection
|
||||
udpConn *net.UDPConn
|
||||
|
||||
// IPv4 and IPv6 packet connections for advanced features
|
||||
ipv4PC *ipv4.PacketConn
|
||||
ipv6PC *ipv6.PacketConn
|
||||
|
||||
// Reference counting to prevent closing while in use
|
||||
refCount atomic.Int32
|
||||
closed atomic.Bool
|
||||
|
||||
// Channels for receiving data
|
||||
recvFuncs []wgConn.ReceiveFunc
|
||||
|
||||
// Port binding information
|
||||
port uint16
|
||||
}
|
||||
|
||||
// New creates a new SharedBind from an existing UDP connection.
|
||||
// The SharedBind takes ownership of the connection and will close it
|
||||
// when all references are released.
|
||||
func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
||||
if udpConn == nil {
|
||||
return nil, fmt.Errorf("udpConn cannot be nil")
|
||||
}
|
||||
|
||||
bind := &SharedBind{
|
||||
udpConn: udpConn,
|
||||
}
|
||||
|
||||
// Initialize reference count to 1 (the creator holds the first reference)
|
||||
bind.refCount.Store(1)
|
||||
|
||||
// Get the local port
|
||||
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
|
||||
bind.port = uint16(addr.Port)
|
||||
}
|
||||
|
||||
return bind, nil
|
||||
}
|
||||
|
||||
// AddRef increments the reference count. Call this when sharing
|
||||
// the bind with another component.
|
||||
func (b *SharedBind) AddRef() {
|
||||
newCount := b.refCount.Add(1)
|
||||
// Optional: Add logging for debugging
|
||||
_ = newCount // Placeholder for potential logging
|
||||
}
|
||||
|
||||
// Release decrements the reference count. When it reaches zero,
|
||||
// the underlying UDP connection is closed.
|
||||
func (b *SharedBind) Release() error {
|
||||
newCount := b.refCount.Add(-1)
|
||||
// Optional: Add logging for debugging
|
||||
_ = newCount // Placeholder for potential logging
|
||||
|
||||
if newCount < 0 {
|
||||
// This should never happen with proper usage
|
||||
b.refCount.Store(0)
|
||||
return fmt.Errorf("SharedBind reference count went negative")
|
||||
}
|
||||
|
||||
if newCount == 0 {
|
||||
return b.closeConnection()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// closeConnection actually closes the UDP connection
|
||||
func (b *SharedBind) closeConnection() error {
|
||||
if !b.closed.CompareAndSwap(false, true) {
|
||||
// Already closed
|
||||
return nil
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
var err error
|
||||
if b.udpConn != nil {
|
||||
err = b.udpConn.Close()
|
||||
b.udpConn = nil
|
||||
}
|
||||
|
||||
b.ipv4PC = nil
|
||||
b.ipv6PC = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetUDPConn returns the underlying UDP connection.
|
||||
// The caller must not close this connection directly.
|
||||
func (b *SharedBind) GetUDPConn() *net.UDPConn {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return b.udpConn
|
||||
}
|
||||
|
||||
// GetRefCount returns the current reference count (for debugging)
|
||||
func (b *SharedBind) GetRefCount() int32 {
|
||||
return b.refCount.Load()
|
||||
}
|
||||
|
||||
// IsClosed returns whether the bind is closed
|
||||
func (b *SharedBind) IsClosed() bool {
|
||||
return b.closed.Load()
|
||||
}
|
||||
|
||||
// WriteToUDP writes data to a specific UDP address.
|
||||
// This is thread-safe and can be used by hole punch senders.
|
||||
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
return conn.WriteToUDP(data, addr)
|
||||
}
|
||||
|
||||
// Close implements the WireGuard Bind interface.
|
||||
// It decrements the reference count and closes the connection if no references remain.
|
||||
func (b *SharedBind) Close() error {
|
||||
return b.Release()
|
||||
}
|
||||
|
||||
// Open implements the WireGuard Bind interface.
|
||||
// Since the connection is already open, this just sets up the receive functions.
|
||||
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
if b.closed.Load() {
|
||||
return nil, 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.udpConn == nil {
|
||||
return nil, 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Set up IPv4 and IPv6 packet connections for advanced features
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
|
||||
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
|
||||
}
|
||||
|
||||
// Create receive functions
|
||||
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
|
||||
|
||||
// Add IPv4 receive function
|
||||
if b.ipv4PC != nil || runtime.GOOS != "linux" {
|
||||
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
|
||||
}
|
||||
|
||||
// Add IPv6 receive function if needed
|
||||
// For now, we focus on IPv4 for hole punching use case
|
||||
|
||||
b.recvFuncs = recvFuncs
|
||||
return recvFuncs, b.port, nil
|
||||
}
|
||||
|
||||
// makeReceiveIPv4 creates a receive function for IPv4 packets
|
||||
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
pc := b.ipv4PC
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Use batch reading on Linux for performance
|
||||
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||
}
|
||||
|
||||
// Fallback to simple read for other platforms
|
||||
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||
}
|
||||
}
|
||||
|
||||
// receiveIPv4Batch uses batch reading for better performance on Linux
|
||||
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
// Create messages for batch reading
|
||||
msgs := make([]ipv4.Message, len(bufs))
|
||||
for i := range bufs {
|
||||
msgs[i].Buffers = [][]byte{bufs[i]}
|
||||
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
|
||||
}
|
||||
|
||||
numMsgs, err := pc.ReadBatch(msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
sizes[i] = msgs[i].N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if msgs[i].Addr != nil {
|
||||
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
|
||||
addrPort := udpAddr.AddrPort()
|
||||
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return numMsgs, nil
|
||||
}
|
||||
|
||||
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
|
||||
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
n, addr, err := conn.ReadFromUDP(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
sizes[0] = n
|
||||
if addr != nil {
|
||||
addrPort := addr.AddrPort()
|
||||
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
}
|
||||
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// Send implements the WireGuard Bind interface.
|
||||
// It sends packets to the specified endpoint.
|
||||
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
if b.closed.Load() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
// Extract the destination address from the endpoint
|
||||
var destAddr *net.UDPAddr
|
||||
|
||||
// Try to cast to StdNetEndpoint first
|
||||
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
|
||||
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
|
||||
} else {
|
||||
// Fallback: construct from DstIP and DstToBytes
|
||||
dstBytes := ep.DstToBytes()
|
||||
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
|
||||
var addr netip.Addr
|
||||
var port uint16
|
||||
|
||||
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
|
||||
addr, _ = netip.AddrFromSlice(dstBytes[:16])
|
||||
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
|
||||
} else { // IPv4
|
||||
addr, _ = netip.AddrFromSlice(dstBytes[:4])
|
||||
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if destAddr == nil {
|
||||
return fmt.Errorf("could not extract destination address from endpoint")
|
||||
}
|
||||
|
||||
// Send all buffers to the destination
|
||||
for _, buf := range bufs {
|
||||
_, err := conn.WriteToUDP(buf, destAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMark implements the WireGuard Bind interface.
|
||||
// It's a no-op for this implementation.
|
||||
func (b *SharedBind) SetMark(mark uint32) error {
|
||||
// Not implemented for this use case
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchSize returns the preferred batch size for sending packets.
|
||||
func (b *SharedBind) BatchSize() int {
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
return wgConn.IdealBatchSize
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string address.
|
||||
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
|
||||
addrPort, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
@@ -1,424 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// TestSharedBindCreation tests basic creation and initialization
|
||||
func TestSharedBindCreation(t *testing.T) {
|
||||
// Create a UDP connection
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
// Create SharedBind
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
|
||||
if bind == nil {
|
||||
t.Fatal("SharedBind is nil")
|
||||
}
|
||||
|
||||
// Verify initial reference count
|
||||
if bind.refCount.Load() != 1 {
|
||||
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
// Clean up
|
||||
if err := bind.Close(); err != nil {
|
||||
t.Errorf("Failed to close SharedBind: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindReferenceCount tests reference counting
|
||||
func TestSharedBindReferenceCount(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
|
||||
// Add references
|
||||
bind.AddRef()
|
||||
if bind.refCount.Load() != 2 {
|
||||
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
bind.AddRef()
|
||||
if bind.refCount.Load() != 3 {
|
||||
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
// Release references
|
||||
bind.Release()
|
||||
if bind.refCount.Load() != 2 {
|
||||
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
bind.Release()
|
||||
bind.Release() // This should close the connection
|
||||
|
||||
if !bind.closed.Load() {
|
||||
t.Error("Expected bind to be closed after all references released")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
|
||||
func TestSharedBindWriteToUDP(t *testing.T) {
|
||||
// Create sender
|
||||
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||
}
|
||||
|
||||
senderBind, err := New(senderConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||
}
|
||||
defer senderBind.Close()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// Send data
|
||||
testData := []byte("Hello, SharedBind!")
|
||||
n, err := senderBind.WriteToUDP(testData, receiverAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteToUDP failed: %v", err)
|
||||
}
|
||||
|
||||
if n != len(testData) {
|
||||
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
|
||||
}
|
||||
|
||||
// Receive data
|
||||
buf := make([]byte, 1024)
|
||||
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, _, err = receiverConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive data: %v", err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(testData) {
|
||||
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindConcurrentWrites tests thread-safety
|
||||
func TestSharedBindConcurrentWrites(t *testing.T) {
|
||||
// Create sender
|
||||
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||
}
|
||||
|
||||
senderBind, err := New(senderConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||
}
|
||||
defer senderBind.Close()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// Launch concurrent writes
|
||||
numGoroutines := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
data := []byte{byte(id)}
|
||||
_, err := senderBind.WriteToUDP(data, receiverAddr)
|
||||
if err != nil {
|
||||
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
|
||||
func TestSharedBindWireGuardInterface(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
defer bind.Close()
|
||||
|
||||
// Test Open
|
||||
recvFuncs, port, err := bind.Open(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Open failed: %v", err)
|
||||
}
|
||||
|
||||
if len(recvFuncs) == 0 {
|
||||
t.Error("Expected at least one receive function")
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
t.Error("Expected non-zero port")
|
||||
}
|
||||
|
||||
// Test SetMark (should be a no-op)
|
||||
if err := bind.SetMark(0); err != nil {
|
||||
t.Errorf("SetMark failed: %v", err)
|
||||
}
|
||||
|
||||
// Test BatchSize
|
||||
batchSize := bind.BatchSize()
|
||||
if batchSize <= 0 {
|
||||
t.Error("Expected positive batch size")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindSend tests the Send method with WireGuard endpoints
|
||||
func TestSharedBindSend(t *testing.T) {
|
||||
// Create sender
|
||||
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||
}
|
||||
|
||||
senderBind, err := New(senderConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||
}
|
||||
defer senderBind.Close()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// Create an endpoint
|
||||
addrPort := receiverAddr.AddrPort()
|
||||
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
|
||||
// Send data
|
||||
testData := []byte("WireGuard packet")
|
||||
bufs := [][]byte{testData}
|
||||
err = senderBind.Send(bufs, endpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
|
||||
// Receive data
|
||||
buf := make([]byte, 1024)
|
||||
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, _, err := receiverConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive data: %v", err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(testData) {
|
||||
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
|
||||
func TestSharedBindMultipleUsers(t *testing.T) {
|
||||
// Create shared bind
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
sharedBind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
|
||||
// Add reference for hole punch sender
|
||||
sharedBind.AddRef()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Simulate WireGuard using the bind
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
addrPort := receiverAddr.AddrPort()
|
||||
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
data := []byte("WireGuard packet")
|
||||
bufs := [][]byte{data}
|
||||
if err := sharedBind.Send(bufs, endpoint); err != nil {
|
||||
t.Errorf("WireGuard Send failed: %v", err)
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
// Simulate hole punch sender using the bind
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
data := []byte("Hole punch packet")
|
||||
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
|
||||
t.Errorf("Hole punch WriteToUDP failed: %v", err)
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Release the hole punch reference
|
||||
sharedBind.Release()
|
||||
|
||||
// Close WireGuard's reference (should close the connection)
|
||||
sharedBind.Close()
|
||||
|
||||
if !sharedBind.closed.Load() {
|
||||
t.Error("Expected bind to be closed after all users released it")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEndpoint tests the Endpoint implementation
|
||||
func TestEndpoint(t *testing.T) {
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
addrPort := netip.AddrPortFrom(addr, 51820)
|
||||
|
||||
ep := &Endpoint{AddrPort: addrPort}
|
||||
|
||||
// Test DstIP
|
||||
if ep.DstIP() != addr {
|
||||
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
|
||||
}
|
||||
|
||||
// Test DstToString
|
||||
expected := "192.168.1.1:51820"
|
||||
if ep.DstToString() != expected {
|
||||
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
|
||||
}
|
||||
|
||||
// Test DstToBytes
|
||||
bytes := ep.DstToBytes()
|
||||
if len(bytes) == 0 {
|
||||
t.Error("Expected DstToBytes to return non-empty slice")
|
||||
}
|
||||
|
||||
// Test SrcIP (should be zero)
|
||||
if ep.SrcIP().IsValid() {
|
||||
t.Error("Expected SrcIP to be invalid")
|
||||
}
|
||||
|
||||
// Test ClearSrc (should not panic)
|
||||
ep.ClearSrc()
|
||||
}
|
||||
|
||||
// TestParseEndpoint tests the ParseEndpoint method
|
||||
func TestParseEndpoint(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
defer bind.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
checkAddr func(*testing.T, wgConn.Endpoint)
|
||||
}{
|
||||
{
|
||||
name: "valid IPv4",
|
||||
input: "192.168.1.1:51820",
|
||||
wantErr: false,
|
||||
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||
if ep.DstToString() != "192.168.1.1:51820" {
|
||||
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid IPv6",
|
||||
input: "[::1]:51820",
|
||||
wantErr: false,
|
||||
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||
if ep.DstToString() != "[::1]:51820" {
|
||||
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid - missing port",
|
||||
input: "192.168.1.1",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - bad format",
|
||||
input: "not-an-address",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ep, err := bind.ParseEndpoint(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && tt.checkAddr != nil {
|
||||
tt.checkAddr(t, ep)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
158
config.go
158
config.go
@@ -8,6 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -21,9 +22,10 @@ type OlmConfig struct {
|
||||
UserToken string `json:"userToken"`
|
||||
|
||||
// Network settings
|
||||
MTU int `json:"mtu"`
|
||||
DNS string `json:"dns"`
|
||||
InterfaceName string `json:"interface"`
|
||||
MTU int `json:"mtu"`
|
||||
DNS string `json:"dns"`
|
||||
UpstreamDNS []string `json:"upstreamDNS"`
|
||||
InterfaceName string `json:"interface"`
|
||||
|
||||
// Logging
|
||||
LogLevel string `json:"logLevel"`
|
||||
@@ -38,8 +40,10 @@ type OlmConfig struct {
|
||||
PingTimeout string `json:"pingTimeout"`
|
||||
|
||||
// Advanced
|
||||
Holepunch bool `json:"holepunch"`
|
||||
TlsClientCert string `json:"tlsClientCert"`
|
||||
DisableHolepunch bool `json:"disableHolepunch"`
|
||||
TlsClientCert string `json:"tlsClientCert"`
|
||||
OverrideDNS bool `json:"overrideDNS"`
|
||||
DisableRelay bool `json:"disableRelay"`
|
||||
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
|
||||
|
||||
// Parsed values (not in JSON)
|
||||
@@ -74,15 +78,16 @@ func DefaultConfig() *OlmConfig {
|
||||
}
|
||||
|
||||
config := &OlmConfig{
|
||||
MTU: 1280,
|
||||
DNS: "8.8.8.8",
|
||||
LogLevel: "INFO",
|
||||
InterfaceName: "olm",
|
||||
EnableAPI: false,
|
||||
SocketPath: socketPath,
|
||||
PingInterval: "3s",
|
||||
PingTimeout: "5s",
|
||||
Holepunch: false,
|
||||
MTU: 1280,
|
||||
DNS: "8.8.8.8",
|
||||
UpstreamDNS: []string{"8.8.8.8:53"},
|
||||
LogLevel: "INFO",
|
||||
InterfaceName: "olm",
|
||||
EnableAPI: false,
|
||||
SocketPath: socketPath,
|
||||
PingInterval: "3s",
|
||||
PingTimeout: "5s",
|
||||
DisableHolepunch: false,
|
||||
// DoNotCreateNewClient: false,
|
||||
sources: make(map[string]string),
|
||||
}
|
||||
@@ -90,6 +95,7 @@ func DefaultConfig() *OlmConfig {
|
||||
// Track default sources
|
||||
config.sources["mtu"] = string(SourceDefault)
|
||||
config.sources["dns"] = string(SourceDefault)
|
||||
config.sources["upstreamDNS"] = string(SourceDefault)
|
||||
config.sources["logLevel"] = string(SourceDefault)
|
||||
config.sources["interface"] = string(SourceDefault)
|
||||
config.sources["enableApi"] = string(SourceDefault)
|
||||
@@ -97,7 +103,9 @@ func DefaultConfig() *OlmConfig {
|
||||
config.sources["socketPath"] = string(SourceDefault)
|
||||
config.sources["pingInterval"] = string(SourceDefault)
|
||||
config.sources["pingTimeout"] = string(SourceDefault)
|
||||
config.sources["holepunch"] = string(SourceDefault)
|
||||
config.sources["disableHolepunch"] = string(SourceDefault)
|
||||
config.sources["overrideDNS"] = string(SourceDefault)
|
||||
config.sources["disableRelay"] = string(SourceDefault)
|
||||
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
|
||||
|
||||
return config
|
||||
@@ -213,6 +221,10 @@ func loadConfigFromEnv(config *OlmConfig) {
|
||||
config.DNS = val
|
||||
config.sources["dns"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("UPSTREAM_DNS"); val != "" {
|
||||
config.UpstreamDNS = []string{val}
|
||||
config.sources["upstreamDNS"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("LOG_LEVEL"); val != "" {
|
||||
config.LogLevel = val
|
||||
config.sources["logLevel"] = string(SourceEnv)
|
||||
@@ -241,9 +253,17 @@ func loadConfigFromEnv(config *OlmConfig) {
|
||||
config.SocketPath = val
|
||||
config.sources["socketPath"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("HOLEPUNCH"); val == "true" {
|
||||
config.Holepunch = true
|
||||
config.sources["holepunch"] = string(SourceEnv)
|
||||
if val := os.Getenv("DISABLE_HOLEPUNCH"); val == "true" {
|
||||
config.DisableHolepunch = true
|
||||
config.sources["disableHolepunch"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("OVERRIDE_DNS"); val == "true" {
|
||||
config.OverrideDNS = true
|
||||
config.sources["overrideDNS"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("DISABLE_RELAY"); val == "true" {
|
||||
config.DisableRelay = true
|
||||
config.sources["disableRelay"] = string(SourceEnv)
|
||||
}
|
||||
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
|
||||
// config.DoNotCreateNewClient = true
|
||||
@@ -257,21 +277,24 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||
|
||||
// Store original values to detect changes
|
||||
origValues := map[string]interface{}{
|
||||
"endpoint": config.Endpoint,
|
||||
"id": config.ID,
|
||||
"secret": config.Secret,
|
||||
"org": config.OrgID,
|
||||
"userToken": config.UserToken,
|
||||
"mtu": config.MTU,
|
||||
"dns": config.DNS,
|
||||
"logLevel": config.LogLevel,
|
||||
"interface": config.InterfaceName,
|
||||
"httpAddr": config.HTTPAddr,
|
||||
"socketPath": config.SocketPath,
|
||||
"pingInterval": config.PingInterval,
|
||||
"pingTimeout": config.PingTimeout,
|
||||
"enableApi": config.EnableAPI,
|
||||
"holepunch": config.Holepunch,
|
||||
"endpoint": config.Endpoint,
|
||||
"id": config.ID,
|
||||
"secret": config.Secret,
|
||||
"org": config.OrgID,
|
||||
"userToken": config.UserToken,
|
||||
"mtu": config.MTU,
|
||||
"dns": config.DNS,
|
||||
"upstreamDNS": fmt.Sprintf("%v", config.UpstreamDNS),
|
||||
"logLevel": config.LogLevel,
|
||||
"interface": config.InterfaceName,
|
||||
"httpAddr": config.HTTPAddr,
|
||||
"socketPath": config.SocketPath,
|
||||
"pingInterval": config.PingInterval,
|
||||
"pingTimeout": config.PingTimeout,
|
||||
"enableApi": config.EnableAPI,
|
||||
"disableHolepunch": config.DisableHolepunch,
|
||||
"overrideDNS": config.OverrideDNS,
|
||||
"disableRelay": config.DisableRelay,
|
||||
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||
}
|
||||
|
||||
@@ -283,6 +306,8 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
|
||||
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
|
||||
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
|
||||
var upstreamDNSFlag string
|
||||
serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8:53)")
|
||||
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
|
||||
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
|
||||
@@ -290,7 +315,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
|
||||
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
||||
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
|
||||
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
|
||||
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
|
||||
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
|
||||
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
|
||||
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
|
||||
|
||||
version := serviceFlags.Bool("version", false, "Print the version")
|
||||
@@ -301,6 +328,16 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
// Parse upstream DNS flag if provided
|
||||
if upstreamDNSFlag != "" {
|
||||
config.UpstreamDNS = []string{}
|
||||
for _, dns := range splitComma(upstreamDNSFlag) {
|
||||
if dns != "" {
|
||||
config.UpstreamDNS = append(config.UpstreamDNS, dns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track which values were changed by CLI args
|
||||
if config.Endpoint != origValues["endpoint"].(string) {
|
||||
config.sources["endpoint"] = string(SourceCLI)
|
||||
@@ -323,6 +360,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||
if config.DNS != origValues["dns"].(string) {
|
||||
config.sources["dns"] = string(SourceCLI)
|
||||
}
|
||||
if fmt.Sprintf("%v", config.UpstreamDNS) != origValues["upstreamDNS"].(string) {
|
||||
config.sources["upstreamDNS"] = string(SourceCLI)
|
||||
}
|
||||
if config.LogLevel != origValues["logLevel"].(string) {
|
||||
config.sources["logLevel"] = string(SourceCLI)
|
||||
}
|
||||
@@ -344,8 +384,14 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||
if config.EnableAPI != origValues["enableApi"].(bool) {
|
||||
config.sources["enableApi"] = string(SourceCLI)
|
||||
}
|
||||
if config.Holepunch != origValues["holepunch"].(bool) {
|
||||
config.sources["holepunch"] = string(SourceCLI)
|
||||
if config.DisableHolepunch != origValues["disableHolepunch"].(bool) {
|
||||
config.sources["disableHolepunch"] = string(SourceCLI)
|
||||
}
|
||||
if config.OverrideDNS != origValues["overrideDNS"].(bool) {
|
||||
config.sources["overrideDNS"] = string(SourceCLI)
|
||||
}
|
||||
if config.DisableRelay != origValues["disableRelay"].(bool) {
|
||||
config.sources["disableRelay"] = string(SourceCLI)
|
||||
}
|
||||
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
|
||||
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
|
||||
@@ -418,6 +464,10 @@ func mergeConfigs(dest, src *OlmConfig) {
|
||||
dest.DNS = src.DNS
|
||||
dest.sources["dns"] = string(SourceFile)
|
||||
}
|
||||
if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8:53]" {
|
||||
dest.UpstreamDNS = src.UpstreamDNS
|
||||
dest.sources["upstreamDNS"] = string(SourceFile)
|
||||
}
|
||||
if src.LogLevel != "" && src.LogLevel != "INFO" {
|
||||
dest.LogLevel = src.LogLevel
|
||||
dest.sources["logLevel"] = string(SourceFile)
|
||||
@@ -455,9 +505,17 @@ func mergeConfigs(dest, src *OlmConfig) {
|
||||
dest.EnableAPI = src.EnableAPI
|
||||
dest.sources["enableApi"] = string(SourceFile)
|
||||
}
|
||||
if src.Holepunch {
|
||||
dest.Holepunch = src.Holepunch
|
||||
dest.sources["holepunch"] = string(SourceFile)
|
||||
if src.DisableHolepunch {
|
||||
dest.DisableHolepunch = src.DisableHolepunch
|
||||
dest.sources["disableHolepunch"] = string(SourceFile)
|
||||
}
|
||||
if src.OverrideDNS {
|
||||
dest.OverrideDNS = src.OverrideDNS
|
||||
dest.sources["overrideDNS"] = string(SourceFile)
|
||||
}
|
||||
if src.DisableRelay {
|
||||
dest.DisableRelay = src.DisableRelay
|
||||
dest.sources["disableRelay"] = string(SourceFile)
|
||||
}
|
||||
// if src.DoNotCreateNewClient {
|
||||
// dest.DoNotCreateNewClient = src.DoNotCreateNewClient
|
||||
@@ -479,7 +537,7 @@ func SaveConfig(config *OlmConfig) error {
|
||||
func (c *OlmConfig) ShowConfig() {
|
||||
configPath := getOlmConfigPath()
|
||||
|
||||
fmt.Println("\n=== Olm Configuration ===\n")
|
||||
fmt.Print("\n=== Olm Configuration ===\n\n")
|
||||
fmt.Printf("Config File: %s\n", configPath)
|
||||
|
||||
// Check if config file exists
|
||||
@@ -490,7 +548,7 @@ func (c *OlmConfig) ShowConfig() {
|
||||
}
|
||||
|
||||
fmt.Println("\n--- Configuration Values ---")
|
||||
fmt.Println("(Format: Setting = Value [source])\n")
|
||||
fmt.Print("(Format: Setting = Value [source])\n\n")
|
||||
|
||||
// Helper to get source or default
|
||||
getSource := func(key string) string {
|
||||
@@ -526,6 +584,7 @@ func (c *OlmConfig) ShowConfig() {
|
||||
fmt.Println("\nNetwork:")
|
||||
fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu"))
|
||||
fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns"))
|
||||
fmt.Printf(" upstream-dns = %v [%s]\n", c.UpstreamDNS, getSource("upstreamDNS"))
|
||||
fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface"))
|
||||
|
||||
// Logging
|
||||
@@ -545,7 +604,9 @@ func (c *OlmConfig) ShowConfig() {
|
||||
|
||||
// Advanced
|
||||
fmt.Println("\nAdvanced:")
|
||||
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
|
||||
fmt.Printf(" disable-holepunch = %v [%s]\n", c.DisableHolepunch, getSource("disableHolepunch"))
|
||||
fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS"))
|
||||
fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay"))
|
||||
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
|
||||
if c.TlsClientCert != "" {
|
||||
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
|
||||
@@ -560,3 +621,16 @@ func (c *OlmConfig) ShowConfig() {
|
||||
fmt.Println("\nPriority: cli > environment > file > default")
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// splitComma splits a comma-separated string into a slice of trimmed strings
|
||||
func splitComma(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
43
create_test_creds.py
Normal file
43
create_test_creds.py
Normal file
@@ -0,0 +1,43 @@
|
||||
|
||||
import requests
|
||||
|
||||
def create_olm(base_url, user_token, olm_name, user_id):
|
||||
url = f"{base_url}/api/v1/user/{user_id}/olm"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "pangolin-cli",
|
||||
"X-CSRF-Token": "x-csrf-protection",
|
||||
"Cookie": f"p_session_token={user_token}"
|
||||
}
|
||||
payload = {"name": olm_name}
|
||||
response = requests.put(url, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
print(f"Response Data: {data}")
|
||||
|
||||
def create_client(base_url, user_token, client_name):
|
||||
url = f"{base_url}/api/v1/api/clients"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "pangolin-cli",
|
||||
"X-CSRF-Token": "x-csrf-protection",
|
||||
"Cookie": f"p_session_token={user_token}"
|
||||
}
|
||||
payload = {"name": client_name}
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
print(f"Response Data: {data}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
base_url = input("Enter base URL (e.g., http://localhost:3000): ")
|
||||
user_token = input("Enter user token: ")
|
||||
user_id = input("Enter user ID: ")
|
||||
olm_name = input("Enter OLM name: ")
|
||||
client_name = input("Enter client name: ")
|
||||
|
||||
create_olm(base_url, user_token, olm_name, user_id)
|
||||
# client_id = create_client(base_url, user_token, client_name)
|
||||
331
device/middle_device.go
Normal file
331
device/middle_device.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// PacketHandler processes intercepted packets and returns true if packet should be dropped
|
||||
type PacketHandler func(packet []byte) bool
|
||||
|
||||
// FilterRule defines a rule for packet filtering
|
||||
type FilterRule struct {
|
||||
DestIP netip.Addr
|
||||
Handler PacketHandler
|
||||
}
|
||||
|
||||
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
||||
type MiddleDevice struct {
|
||||
tun.Device
|
||||
rules []FilterRule
|
||||
mutex sync.RWMutex
|
||||
readCh chan readResult
|
||||
injectCh chan []byte
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
type readResult struct {
|
||||
bufs [][]byte
|
||||
sizes []int
|
||||
offset int
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
// NewMiddleDevice creates a new filtered TUN device wrapper
|
||||
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||
d := &MiddleDevice{
|
||||
Device: device,
|
||||
rules: make([]FilterRule, 0),
|
||||
readCh: make(chan readResult),
|
||||
injectCh: make(chan []byte, 100),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go d.pump()
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *MiddleDevice) pump() {
|
||||
const defaultOffset = 16
|
||||
batchSize := d.Device.BatchSize()
|
||||
logger.Debug("MiddleDevice: pump started")
|
||||
|
||||
for {
|
||||
// Check closed first with priority
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: pump exiting due to closed channel")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Allocate buffers for reading
|
||||
// We allocate new buffers for each read to avoid race conditions
|
||||
// since we pass them to the channel
|
||||
bufs := make([][]byte, batchSize)
|
||||
sizes := make([]int, batchSize)
|
||||
for i := range bufs {
|
||||
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
|
||||
}
|
||||
|
||||
n, err := d.Device.Read(bufs, sizes, defaultOffset)
|
||||
|
||||
// Check closed again after read returns
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Now try to send the result
|
||||
select {
|
||||
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Debug("MiddleDevice: pump exiting due to read error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
|
||||
func (d *MiddleDevice) InjectOutbound(packet []byte) {
|
||||
select {
|
||||
case d.injectCh <- packet:
|
||||
case <-d.closed:
|
||||
}
|
||||
}
|
||||
|
||||
// AddRule adds a packet filtering rule
|
||||
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
d.rules = append(d.rules, FilterRule{
|
||||
DestIP: destIP,
|
||||
Handler: handler,
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveRule removes all rules for a given destination IP
|
||||
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
newRules := make([]FilterRule, 0, len(d.rules))
|
||||
for _, rule := range d.rules {
|
||||
if rule.DestIP != destIP {
|
||||
newRules = append(newRules, rule)
|
||||
}
|
||||
}
|
||||
d.rules = newRules
|
||||
}
|
||||
|
||||
// Close stops the device
|
||||
func (d *MiddleDevice) Close() error {
|
||||
select {
|
||||
case <-d.closed:
|
||||
// Already closed
|
||||
return nil
|
||||
default:
|
||||
logger.Debug("MiddleDevice: Closing, signaling closed channel")
|
||||
close(d.closed)
|
||||
}
|
||||
logger.Debug("MiddleDevice: Closing underlying TUN device")
|
||||
err := d.Device.Close()
|
||||
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// extractDestIP extracts destination IP from packet (fast path)
|
||||
func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||
if len(packet) < 20 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
version := packet[0] >> 4
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
if len(packet) < 20 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
// Destination IP is at bytes 16-19 for IPv4
|
||||
ip := netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
||||
return ip, true
|
||||
case 6:
|
||||
if len(packet) < 40 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
// Destination IP is at bytes 24-39 for IPv6
|
||||
var ip16 [16]byte
|
||||
copy(ip16[:], packet[24:40])
|
||||
ip := netip.AddrFrom16(ip16)
|
||||
return ip, true
|
||||
}
|
||||
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
// Check if already closed first (non-blocking)
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
|
||||
return 0, os.ErrClosed
|
||||
default:
|
||||
}
|
||||
|
||||
// Now block waiting for data
|
||||
select {
|
||||
case res := <-d.readCh:
|
||||
if res.err != nil {
|
||||
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
|
||||
return 0, res.err
|
||||
}
|
||||
|
||||
// Copy packets from result to provided buffers
|
||||
count := 0
|
||||
for i := 0; i < res.n && i < len(bufs); i++ {
|
||||
// Handle offset mismatch if necessary
|
||||
// We assume the pump used defaultOffset (16)
|
||||
// If caller asks for different offset, we need to shift
|
||||
src := res.bufs[i]
|
||||
srcOffset := res.offset
|
||||
srcSize := res.sizes[i]
|
||||
|
||||
// Calculate where the packet data starts and ends in src
|
||||
pktData := src[srcOffset : srcOffset+srcSize]
|
||||
|
||||
// Ensure dest buffer is large enough
|
||||
if len(bufs[i]) < offset+len(pktData) {
|
||||
continue // Skip if buffer too small
|
||||
}
|
||||
|
||||
copy(bufs[i][offset:], pktData)
|
||||
sizes[i] = len(pktData)
|
||||
count++
|
||||
}
|
||||
n = count
|
||||
|
||||
case pkt := <-d.injectCh:
|
||||
if len(bufs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if len(bufs[0]) < offset+len(pkt) {
|
||||
return 0, nil // Buffer too small
|
||||
}
|
||||
copy(bufs[0][offset:], pkt)
|
||||
sizes[0] = len(pkt)
|
||||
n = 1
|
||||
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
|
||||
return 0, os.ErrClosed // Signal that device is closed
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
rules := d.rules
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Process packets and filter out handled ones
|
||||
writeIdx := 0
|
||||
for readIdx := 0; readIdx < n; readIdx++ {
|
||||
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
|
||||
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
// Can't parse, keep packet
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if packet matches any rule
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
// Packet was handled and should be dropped
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
// Keep packet
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return writeIdx, err
|
||||
}
|
||||
|
||||
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
||||
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
d.mutex.RLock()
|
||||
rules := d.rules
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return d.Device.Write(bufs, offset)
|
||||
}
|
||||
|
||||
// Filter packets going down
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
for _, buf := range bufs {
|
||||
if len(buf) <= offset {
|
||||
continue
|
||||
}
|
||||
|
||||
packet := buf[offset:]
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
// Can't parse, keep packet
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if packet matches any rule
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
// Packet was handled and should be dropped
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredBufs) == 0 {
|
||||
return len(bufs), nil // All packets were handled
|
||||
}
|
||||
|
||||
return d.Device.Write(filteredBufs, offset)
|
||||
}
|
||||
102
device/middle_device_test.go
Normal file
102
device/middle_device_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/fosrl/newt/util"
|
||||
)
|
||||
|
||||
func TestExtractDestIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
packet []byte
|
||||
wantIP string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 packet",
|
||||
packet: []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
|
||||
0x0a, 0x1e, 0x1e, 0x1e, // Dest IP: 10.30.30.30
|
||||
},
|
||||
wantIP: "10.30.30.30",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "Too short packet",
|
||||
packet: []byte{0x45, 0x00},
|
||||
wantIP: "",
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotIP, gotOk := extractDestIP(tt.packet)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("extractDestIP() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
return
|
||||
}
|
||||
if tt.wantOk {
|
||||
wantAddr := netip.MustParseAddr(tt.wantIP)
|
||||
if gotIP != wantAddr {
|
||||
t.Errorf("extractDestIP() ip = %v, want %v", gotIP, wantAddr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProtocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
packet []byte
|
||||
wantProto uint8
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "UDP packet",
|
||||
packet: []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, // Protocol: UDP (17) at byte 9
|
||||
0x0a, 0x1e, 0x1e, 0x1e,
|
||||
},
|
||||
wantProto: 17,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
packet: []byte{0x45, 0x00},
|
||||
wantProto: 0,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotProto, gotOk := util.GetProtocol(tt.packet)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
return
|
||||
}
|
||||
if gotProto != tt.wantProto {
|
||||
t.Errorf("GetProtocol() proto = %v, want %v", gotProto, tt.wantProto)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExtractDestIP(b *testing.B) {
|
||||
packet := []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
|
||||
0x0a, 0x1e, 0x1e, 0x1e,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
extractDestIP(packet)
|
||||
}
|
||||
}
|
||||
44
device/tun_unix.go
Normal file
44
device/tun_unix.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build !windows
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
dupTunFd, err := unix.Dup(int(tunFd))
|
||||
if err != nil {
|
||||
logger.Error("Unable to dup tun fd: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(dupTunFd, true)
|
||||
if err != nil {
|
||||
unix.Close(dupTunFd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
|
||||
device, err := tun.CreateTUNFromFile(file, mtuInt)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||
return ipc.UAPIOpen(interfaceName)
|
||||
}
|
||||
|
||||
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build windows
|
||||
|
||||
package olm
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -11,15 +11,15 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
return nil, errors.New("CreateTUNFromFile not supported on Windows")
|
||||
}
|
||||
|
||||
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
// On Windows, UAPIListen only takes one parameter
|
||||
return ipc.UAPIListen(interfaceName)
|
||||
}
|
||||
523
diff
523
diff
@@ -1,523 +0,0 @@
|
||||
diff --git a/api/api.go b/api/api.go
|
||||
index dd07751..0d2e4ef 100644
|
||||
--- a/api/api.go
|
||||
+++ b/api/api.go
|
||||
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
+// SwitchOrgRequest defines the structure for switching organizations
|
||||
+type SwitchOrgRequest struct {
|
||||
+ OrgID string `json:"orgId"`
|
||||
+}
|
||||
+
|
||||
// PeerStatus represents the status of a peer connection
|
||||
type PeerStatus struct {
|
||||
SiteID int `json:"siteId"`
|
||||
@@ -35,6 +40,7 @@ type StatusResponse struct {
|
||||
Registered bool `json:"registered"`
|
||||
TunnelIP string `json:"tunnelIP,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
+ OrgID string `json:"orgId,omitempty"`
|
||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||
}
|
||||
|
||||
@@ -46,6 +52,7 @@ type API struct {
|
||||
server *http.Server
|
||||
connectionChan chan ConnectionRequest
|
||||
shutdownChan chan struct{}
|
||||
+ switchOrgChan chan SwitchOrgRequest
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
connectedAt time.Time
|
||||
@@ -53,6 +60,7 @@ type API struct {
|
||||
isRegistered bool
|
||||
tunnelIP string
|
||||
version string
|
||||
+ orgID string
|
||||
}
|
||||
|
||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
|
||||
addr: addr,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
|
||||
socketPath: socketPath,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
@@ -85,6 +95,7 @@ func (s *API) Start() error {
|
||||
mux.HandleFunc("/connect", s.handleConnect)
|
||||
mux.HandleFunc("/status", s.handleStatus)
|
||||
mux.HandleFunc("/exit", s.handleExit)
|
||||
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||
|
||||
s.server = &http.Server{
|
||||
Handler: mux,
|
||||
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
|
||||
return s.shutdownChan
|
||||
}
|
||||
|
||||
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
||||
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
||||
+ return s.switchOrgChan
|
||||
+}
|
||||
+
|
||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
|
||||
s.version = version
|
||||
}
|
||||
|
||||
+// SetOrgID sets the org ID
|
||||
+func (s *API) SetOrgID(orgID string) {
|
||||
+ s.statusMu.Lock()
|
||||
+ defer s.statusMu.Unlock()
|
||||
+ s.orgID = orgID
|
||||
+}
|
||||
+
|
||||
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
Registered: s.isRegistered,
|
||||
TunnelIP: s.tunnelIP,
|
||||
Version: s.version,
|
||||
+ OrgID: s.orgID,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
}
|
||||
|
||||
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||
"status": "shutdown initiated",
|
||||
})
|
||||
}
|
||||
+
|
||||
+// handleSwitchOrg handles the /switch-org endpoint
|
||||
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||
+ if r.Method != http.MethodPost {
|
||||
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ var req SwitchOrgRequest
|
||||
+ decoder := json.NewDecoder(r.Body)
|
||||
+ if err := decoder.Decode(&req); err != nil {
|
||||
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ // Validate required fields
|
||||
+ if req.OrgID == "" {
|
||||
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||
+
|
||||
+ // Send the request to the main goroutine
|
||||
+ select {
|
||||
+ case s.switchOrgChan <- req:
|
||||
+ // Signal sent successfully
|
||||
+ default:
|
||||
+ // Channel already has a signal, don't block
|
||||
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ // Return a success response
|
||||
+ w.Header().Set("Content-Type", "application/json")
|
||||
+ w.WriteHeader(http.StatusAccepted)
|
||||
+ json.NewEncoder(w).Encode(map[string]string{
|
||||
+ "status": "org switch initiated",
|
||||
+ "orgId": req.OrgID,
|
||||
+ })
|
||||
+}
|
||||
diff --git a/olm/olm.go b/olm/olm.go
|
||||
index 78080c4..5e292d6 100644
|
||||
--- a/olm/olm.go
|
||||
+++ b/olm/olm.go
|
||||
@@ -58,6 +58,58 @@ type Config struct {
|
||||
OrgID string
|
||||
}
|
||||
|
||||
+// tunnelState holds all the active tunnel resources that need cleanup
|
||||
+type tunnelState struct {
|
||||
+ dev *device.Device
|
||||
+ tdev tun.Device
|
||||
+ uapiListener net.Listener
|
||||
+ peerMonitor *peermonitor.PeerMonitor
|
||||
+ stopRegister func()
|
||||
+ connected bool
|
||||
+}
|
||||
+
|
||||
+// teardownTunnel cleans up all tunnel resources
|
||||
+func teardownTunnel(state *tunnelState) {
|
||||
+ if state == nil {
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ logger.Info("Tearing down tunnel...")
|
||||
+
|
||||
+ // Stop registration messages
|
||||
+ if state.stopRegister != nil {
|
||||
+ state.stopRegister()
|
||||
+ state.stopRegister = nil
|
||||
+ }
|
||||
+
|
||||
+ // Stop peer monitor
|
||||
+ if state.peerMonitor != nil {
|
||||
+ state.peerMonitor.Stop()
|
||||
+ state.peerMonitor = nil
|
||||
+ }
|
||||
+
|
||||
+ // Close UAPI listener
|
||||
+ if state.uapiListener != nil {
|
||||
+ state.uapiListener.Close()
|
||||
+ state.uapiListener = nil
|
||||
+ }
|
||||
+
|
||||
+ // Close WireGuard device
|
||||
+ if state.dev != nil {
|
||||
+ state.dev.Close()
|
||||
+ state.dev = nil
|
||||
+ }
|
||||
+
|
||||
+ // Close TUN device
|
||||
+ if state.tdev != nil {
|
||||
+ state.tdev.Close()
|
||||
+ state.tdev = nil
|
||||
+ }
|
||||
+
|
||||
+ state.connected = false
|
||||
+ logger.Info("Tunnel teardown complete")
|
||||
+}
|
||||
+
|
||||
func Run(ctx context.Context, config Config) {
|
||||
// Create a cancellable context for internal shutdown control
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
|
||||
pingTimeout = config.PingTimeoutDuration
|
||||
doHolepunch = config.Holepunch
|
||||
privateKey wgtypes.Key
|
||||
- connected bool
|
||||
- dev *device.Device
|
||||
wgData WgData
|
||||
holePunchData HolePunchData
|
||||
- uapiListener net.Listener
|
||||
- tdev tun.Device
|
||||
+ orgID = config.OrgID
|
||||
)
|
||||
|
||||
+ // Tunnel state that can be torn down and recreated
|
||||
+ tunnel := &tunnelState{}
|
||||
+
|
||||
stopHolepunch = make(chan struct{})
|
||||
stopPing = make(chan struct{})
|
||||
|
||||
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
apiServer.SetVersion(config.Version)
|
||||
+ apiServer.SetOrgID(orgID)
|
||||
if err := apiServer.Start(); err != nil {
|
||||
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||
}
|
||||
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
|
||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
- if connected {
|
||||
+ if tunnel.connected {
|
||||
logger.Info("Already connected. Ignoring new connection request.")
|
||||
return
|
||||
}
|
||||
|
||||
- if stopRegister != nil {
|
||||
- stopRegister()
|
||||
- stopRegister = nil
|
||||
+ if tunnel.stopRegister != nil {
|
||||
+ tunnel.stopRegister()
|
||||
+ tunnel.stopRegister = nil
|
||||
}
|
||||
|
||||
close(stopHolepunch)
|
||||
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// if there is an existing tunnel then close it
|
||||
- if dev != nil {
|
||||
+ if tunnel.dev != nil {
|
||||
logger.Info("Got new message. Closing existing tunnel!")
|
||||
- dev.Close()
|
||||
+ tunnel.dev.Close()
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
|
||||
return
|
||||
}
|
||||
|
||||
- tdev, err = func() (tun.Device, error) {
|
||||
+ tunnel.tdev, err = func() (tun.Device, error) {
|
||||
if runtime.GOOS == "darwin" {
|
||||
interfaceName, err := findUnusedUTUN()
|
||||
if err != nil {
|
||||
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
|
||||
return
|
||||
}
|
||||
|
||||
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
|
||||
interfaceName = realInterfaceName
|
||||
}
|
||||
|
||||
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
|
||||
return
|
||||
}
|
||||
|
||||
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||
|
||||
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||
if err != nil {
|
||||
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||
os.Exit(1)
|
||||
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
|
||||
|
||||
go func() {
|
||||
for {
|
||||
- conn, err := uapiListener.Accept()
|
||||
+ conn, err := tunnel.uapiListener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
- go dev.IpcHandle(conn)
|
||||
+ go tunnel.dev.IpcHandle(conn)
|
||||
}
|
||||
}()
|
||||
logger.Info("UAPI listener started")
|
||||
|
||||
- if err = dev.Up(); err != nil {
|
||||
+ if err = tunnel.dev.Up(); err != nil {
|
||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||
}
|
||||
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
|
||||
apiServer.SetTunnelIP(wgData.TunnelIP)
|
||||
}
|
||||
|
||||
- peerMonitor = peermonitor.NewPeerMonitor(
|
||||
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
|
||||
func(siteID int, connected bool, rtt time.Duration) {
|
||||
if apiServer != nil {
|
||||
// Find the site config to get endpoint information
|
||||
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
|
||||
},
|
||||
fixKey(privateKey.String()),
|
||||
olm,
|
||||
- dev,
|
||||
+ tunnel.dev,
|
||||
doHolepunch,
|
||||
)
|
||||
|
||||
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
|
||||
// Format the endpoint before configuring the peer.
|
||||
site.Endpoint = formatEndpoint(site.Endpoint)
|
||||
|
||||
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
||||
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
|
||||
logger.Error("Failed to configure peer: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
|
||||
logger.Info("Configured peer %s", site.PublicKey)
|
||||
}
|
||||
|
||||
- peerMonitor.Start()
|
||||
+ tunnel.peerMonitor.Start()
|
||||
|
||||
if apiServer != nil {
|
||||
apiServer.SetRegistered(true)
|
||||
}
|
||||
|
||||
- connected = true
|
||||
+ tunnel.connected = true
|
||||
|
||||
logger.Info("WireGuard device created.")
|
||||
})
|
||||
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
// Update the peer in WireGuard
|
||||
- if dev != nil {
|
||||
+ if tunnel.dev != nil {
|
||||
// Find the existing peer to get old data
|
||||
var oldRemoteSubnets string
|
||||
var oldPublicKey string
|
||||
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
|
||||
// If the public key has changed, remove the old peer first
|
||||
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
||||
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
||||
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||
logger.Error("Failed to remove old peer: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
|
||||
// Format the endpoint before updating the peer.
|
||||
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||
|
||||
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
logger.Error("Failed to update peer: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
// Add the peer to WireGuard
|
||||
- if dev != nil {
|
||||
+ if tunnel.dev != nil {
|
||||
// Format the endpoint before adding the new peer.
|
||||
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||
|
||||
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
logger.Error("Failed to add peer: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
// Remove the peer from WireGuard
|
||||
- if dev != nil {
|
||||
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||
+ if tunnel.dev != nil {
|
||||
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||
logger.Error("Failed to remove peer: %v", err)
|
||||
// Send error response if needed
|
||||
return
|
||||
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
|
||||
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
||||
}
|
||||
|
||||
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
||||
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
|
||||
apiServer.SetConnectionStatus(true)
|
||||
}
|
||||
|
||||
- if connected {
|
||||
+ if tunnel.connected {
|
||||
logger.Debug("Already connected, skipping registration")
|
||||
return nil
|
||||
}
|
||||
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
|
||||
|
||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
||||
|
||||
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !doHolepunch,
|
||||
"olmVersion": config.Version,
|
||||
- "orgId": config.OrgID,
|
||||
+ "orgId": orgID,
|
||||
}, 1*time.Second)
|
||||
|
||||
go keepSendingPing(olm)
|
||||
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
defer olm.Close()
|
||||
|
||||
+ // Listen for org switch requests from the API (after olm is created)
|
||||
+ if apiServer != nil {
|
||||
+ go func() {
|
||||
+ for req := range apiServer.GetSwitchOrgChannel() {
|
||||
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
|
||||
+
|
||||
+ // Update the orgId
|
||||
+ orgID = req.OrgID
|
||||
+
|
||||
+ // Teardown existing tunnel
|
||||
+ teardownTunnel(tunnel)
|
||||
+
|
||||
+ // Reset tunnel state
|
||||
+ tunnel = &tunnelState{}
|
||||
+
|
||||
+ // Stop holepunch
|
||||
+ select {
|
||||
+ case <-stopHolepunch:
|
||||
+ // Channel already closed
|
||||
+ default:
|
||||
+ close(stopHolepunch)
|
||||
+ }
|
||||
+ stopHolepunch = make(chan struct{})
|
||||
+
|
||||
+ // Clear API server state
|
||||
+ apiServer.SetRegistered(false)
|
||||
+ apiServer.SetTunnelIP("")
|
||||
+ apiServer.SetOrgID(orgID)
|
||||
+
|
||||
+ // Send new registration message with updated orgId
|
||||
+ publicKey := privateKey.PublicKey()
|
||||
+ logger.Info("Sending registration message with new orgId: %s", orgID)
|
||||
+
|
||||
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
+ "publicKey": publicKey.String(),
|
||||
+ "relay": !doHolepunch,
|
||||
+ "olmVersion": config.Version,
|
||||
+ "orgId": orgID,
|
||||
+ }, 1*time.Second)
|
||||
+ }
|
||||
+ }()
|
||||
+ }
|
||||
+
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("Context cancelled")
|
||||
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
|
||||
close(stopHolepunch)
|
||||
}
|
||||
|
||||
- if stopRegister != nil {
|
||||
- stopRegister()
|
||||
- stopRegister = nil
|
||||
+ if tunnel.stopRegister != nil {
|
||||
+ tunnel.stopRegister()
|
||||
+ tunnel.stopRegister = nil
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
|
||||
close(stopPing)
|
||||
}
|
||||
|
||||
- if peerMonitor != nil {
|
||||
- peerMonitor.Stop()
|
||||
- }
|
||||
-
|
||||
- if uapiListener != nil {
|
||||
- uapiListener.Close()
|
||||
- }
|
||||
- if dev != nil {
|
||||
- dev.Close()
|
||||
- }
|
||||
+ // Use teardownTunnel to clean up all tunnel resources
|
||||
+ teardownTunnel(tunnel)
|
||||
|
||||
if apiServer != nil {
|
||||
apiServer.Stop()
|
||||
457
dns/dns_proxy.go
Normal file
457
dns/dns_proxy.go
Normal file
@@ -0,0 +1,457 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/device"
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
DNSPort = 53
|
||||
)
|
||||
|
||||
// DNSProxy implements a DNS proxy using gvisor netstack
|
||||
type DNSProxy struct {
|
||||
stack *stack.Stack
|
||||
ep *channel.Endpoint
|
||||
proxyIP netip.Addr
|
||||
upstreamDNS []string
|
||||
mtu int
|
||||
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
||||
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
|
||||
recordStore *DNSRecordStore // Local DNS records
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewDNSProxy creates a new DNS proxy
|
||||
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string) (*DNSProxy, error) {
|
||||
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
||||
}
|
||||
|
||||
if len(upstreamDns) == 0 {
|
||||
return nil, fmt.Errorf("at least one upstream DNS server must be specified")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
proxy := &DNSProxy{
|
||||
proxyIP: proxyIP,
|
||||
mtu: mtu,
|
||||
tunDevice: tunDevice,
|
||||
middleDevice: middleDevice,
|
||||
upstreamDNS: upstreamDns,
|
||||
recordStore: NewDNSRecordStore(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create gvisor netstack
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
|
||||
proxy.ep = channel.New(256, uint32(mtu), "")
|
||||
proxy.stack = stack.New(stackOpts)
|
||||
|
||||
// Create NIC
|
||||
if err := proxy.stack.CreateNIC(1, proxy.ep); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
// Add IP address
|
||||
// Parse the proxy IP to get the octets
|
||||
ipBytes := proxyIP.As4()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||
}
|
||||
|
||||
if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to add protocol address: %v", err)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
proxy.stack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// Start starts the DNS proxy and registers with the filter
|
||||
func (p *DNSProxy) Start() error {
|
||||
// Install packet filter rule
|
||||
p.middleDevice.AddRule(p.proxyIP, p.handlePacket)
|
||||
|
||||
// Start DNS listener
|
||||
p.wg.Add(2)
|
||||
go p.runDNSListener()
|
||||
go p.runPacketSender()
|
||||
|
||||
logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the DNS proxy
|
||||
func (p *DNSProxy) Stop() {
|
||||
if p.middleDevice != nil {
|
||||
p.middleDevice.RemoveRule(p.proxyIP)
|
||||
}
|
||||
p.cancel()
|
||||
|
||||
// Close the endpoint first to unblock any pending Read() calls in runPacketSender
|
||||
if p.ep != nil {
|
||||
p.ep.Close()
|
||||
}
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
if p.stack != nil {
|
||||
p.stack.Close()
|
||||
}
|
||||
|
||||
logger.Info("DNS proxy stopped")
|
||||
}
|
||||
|
||||
func (p *DNSProxy) GetProxyIP() netip.Addr {
|
||||
return p.proxyIP
|
||||
}
|
||||
|
||||
// handlePacket is called by the filter for packets destined to DNS proxy IP
|
||||
func (p *DNSProxy) handlePacket(packet []byte) bool {
|
||||
if len(packet) < 20 {
|
||||
return false // Don't drop, malformed
|
||||
}
|
||||
|
||||
// Quick check for UDP port 53
|
||||
proto, ok := util.GetProtocol(packet)
|
||||
if !ok || proto != 17 { // 17 = UDP
|
||||
return false // Not UDP, don't handle
|
||||
}
|
||||
|
||||
port, ok := util.GetDestPort(packet)
|
||||
if !ok || port != DNSPort {
|
||||
return false // Not DNS port
|
||||
}
|
||||
|
||||
// Inject packet into our netstack
|
||||
version := packet[0] >> 4
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
p.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||
case 6:
|
||||
p.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||
default:
|
||||
pkb.DecRef()
|
||||
return false
|
||||
}
|
||||
|
||||
pkb.DecRef()
|
||||
return true // Drop packet from normal path
|
||||
}
|
||||
|
||||
// runDNSListener listens for DNS queries on the netstack
|
||||
func (p *DNSProxy) runDNSListener() {
|
||||
defer p.wg.Done()
|
||||
|
||||
// Create UDP listener using gonet
|
||||
// Parse the proxy IP to get the octets
|
||||
ipBytes := p.proxyIP.As4()
|
||||
laddr := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4(ipBytes),
|
||||
Port: DNSPort,
|
||||
}
|
||||
|
||||
udpConn, err := gonet.DialUDP(p.stack, laddr, nil, ipv4.ProtocolNumber)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create DNS listener: %v", err)
|
||||
return
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
logger.Debug("DNS proxy listening on netstack")
|
||||
|
||||
// Handle DNS queries
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
udpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
n, remoteAddr, err := udpConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
logger.Error("DNS read error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
query := make([]byte, n)
|
||||
copy(query, buf[:n])
|
||||
|
||||
// Handle query in background
|
||||
go p.handleDNSQuery(udpConn, query, remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream
|
||||
func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) {
|
||||
// Parse the DNS query
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(queryData); err != nil {
|
||||
logger.Error("Failed to parse DNS query: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(msg.Question) == 0 {
|
||||
logger.Debug("DNS query has no questions")
|
||||
return
|
||||
}
|
||||
|
||||
question := msg.Question[0]
|
||||
logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype])
|
||||
|
||||
// Check if we have local records for this query
|
||||
var response *dns.Msg
|
||||
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
|
||||
response = p.checkLocalRecords(msg, question)
|
||||
}
|
||||
|
||||
// If no local records, forward to upstream
|
||||
if response == nil {
|
||||
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
|
||||
response = p.forwardToUpstream(msg)
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
logger.Error("Failed to get DNS response for %s", question.Name)
|
||||
return
|
||||
}
|
||||
|
||||
// Pack and send response
|
||||
responseData, err := response.Pack()
|
||||
if err != nil {
|
||||
logger.Error("Failed to pack DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = udpConn.WriteTo(responseData, clientAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to send DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// checkLocalRecords checks if we have local records for the query
|
||||
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
|
||||
var recordType RecordType
|
||||
if question.Qtype == dns.TypeA {
|
||||
recordType = RecordTypeA
|
||||
} else if question.Qtype == dns.TypeAAAA {
|
||||
recordType = RecordTypeAAAA
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
ips := p.recordStore.GetRecords(question.Name, recordType)
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
|
||||
|
||||
// Create response message
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(query)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add answer records
|
||||
for _, ip := range ips {
|
||||
var rr dns.RR
|
||||
if question.Qtype == dns.TypeA {
|
||||
rr = &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
A: ip.To4(),
|
||||
}
|
||||
} else { // TypeAAAA
|
||||
rr = &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
AAAA: ip.To16(),
|
||||
}
|
||||
}
|
||||
response.Answer = append(response.Answer, rr)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// forwardToUpstream forwards a DNS query to upstream DNS servers
|
||||
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
|
||||
// Try primary DNS server
|
||||
response, err := p.queryUpstream(p.upstreamDNS[0], query, 2*time.Second)
|
||||
if err != nil && len(p.upstreamDNS) > 1 {
|
||||
// Try secondary DNS server
|
||||
logger.Debug("Primary DNS failed, trying secondary: %v", err)
|
||||
response, err = p.queryUpstream(p.upstreamDNS[1], query, 2*time.Second)
|
||||
if err != nil {
|
||||
logger.Error("Both DNS servers failed: %v", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
// queryUpstream sends a DNS query to upstream server using miekg/dns
|
||||
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
||||
client := &dns.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
response, _, err := client.Exchange(query, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// runPacketSender sends packets from netstack back to TUN
|
||||
func (p *DNSProxy) runPacketSender() {
|
||||
defer p.wg.Done()
|
||||
|
||||
// MessageTransportHeaderSize is the offset used by WireGuard device
|
||||
// for reading/writing packets to the TUN interface
|
||||
const offset = 16
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read packets from netstack endpoint
|
||||
pkt := p.ep.Read()
|
||||
if pkt == nil {
|
||||
// No packet available, small sleep to avoid busy loop
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract packet data as slices
|
||||
slices := pkt.AsSlices()
|
||||
if len(slices) > 0 {
|
||||
// Flatten all slices into a single packet buffer
|
||||
var totalSize int
|
||||
for _, slice := range slices {
|
||||
totalSize += len(slice)
|
||||
}
|
||||
|
||||
// Allocate buffer with offset space for WireGuard transport header
|
||||
// The first 'offset' bytes are reserved for the transport header
|
||||
buf := make([]byte, offset+totalSize)
|
||||
|
||||
// Copy packet data after the offset
|
||||
pos := offset
|
||||
for _, slice := range slices {
|
||||
copy(buf[pos:], slice)
|
||||
pos += len(slice)
|
||||
}
|
||||
|
||||
// Write packet to TUN device
|
||||
// offset=16 indicates packet data starts at position 16 in the buffer
|
||||
_, err := p.tunDevice.Write([][]byte{buf}, offset)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write DNS response to TUN: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
// AddDNSRecord adds a DNS record to the local store
|
||||
// domain should be a domain name (e.g., "example.com" or "example.com.")
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
|
||||
return p.recordStore.AddRecord(domain, ip)
|
||||
}
|
||||
|
||||
// RemoveDNSRecord removes a DNS record from the local store
|
||||
// If ip is nil, removes all records for the domain
|
||||
func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
|
||||
p.recordStore.RemoveRecord(domain, ip)
|
||||
}
|
||||
|
||||
// GetDNSRecords returns all IP addresses for a domain and record type
|
||||
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
|
||||
return p.recordStore.GetRecords(domain, recordType)
|
||||
}
|
||||
|
||||
// ClearDNSRecords removes all DNS records from the local store
|
||||
func (p *DNSProxy) ClearDNSRecords() {
|
||||
p.recordStore.Clear()
|
||||
}
|
||||
|
||||
func PickIPFromSubnet(subnet string) (netip.Addr, error) {
|
||||
// given a subnet in CIDR notation, pick the first usable IP
|
||||
prefix, err := netip.ParsePrefix(subnet)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid subnet: %w", err)
|
||||
}
|
||||
|
||||
// Pick the first usable IP address from the subnet
|
||||
ip := prefix.Addr().Next()
|
||||
if !ip.IsValid() {
|
||||
return netip.Addr{}, fmt.Errorf("no valid IP address found in subnet: %s", subnet)
|
||||
}
|
||||
|
||||
return ip, nil
|
||||
}
|
||||
166
dns/dns_records.go
Normal file
166
dns/dns_records.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// RecordType represents the type of DNS record
|
||||
type RecordType uint16
|
||||
|
||||
const (
|
||||
RecordTypeA RecordType = RecordType(dns.TypeA)
|
||||
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
|
||||
)
|
||||
|
||||
// DNSRecordStore manages local DNS records for A and AAAA queries
|
||||
type DNSRecordStore struct {
|
||||
mu sync.RWMutex
|
||||
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
|
||||
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
|
||||
}
|
||||
|
||||
// NewDNSRecordStore creates a new DNS record store
|
||||
func NewDNSRecordStore() *DNSRecordStore {
|
||||
return &DNSRecordStore{
|
||||
aRecords: make(map[string][]net.IP),
|
||||
aaaaRecords: make(map[string][]net.IP),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRecord adds a DNS record mapping (A or AAAA)
|
||||
// domain should be in FQDN format (e.g., "example.com.")
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
if ip.To4() != nil {
|
||||
// IPv4 address
|
||||
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
||||
} else if ip.To16() != nil {
|
||||
// IPv6 address
|
||||
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
||||
} else {
|
||||
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRecord removes a specific DNS record mapping
|
||||
// If ip is nil, removes all records for the domain
|
||||
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
if ip == nil {
|
||||
// Remove all records for this domain
|
||||
delete(s.aRecords, domain)
|
||||
delete(s.aaaaRecords, domain)
|
||||
return
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
// Remove specific IPv4 address
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
s.aRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aRecords[domain]) == 0 {
|
||||
delete(s.aRecords, domain)
|
||||
}
|
||||
}
|
||||
} else if ip.To16() != nil {
|
||||
// Remove specific IPv6 address
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
s.aaaaRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aaaaRecords[domain]) == 0 {
|
||||
delete(s.aaaaRecords, domain)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRecords returns all IP addresses for a domain and record type
|
||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
var records []net.IP
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
}
|
||||
case RecordTypeAAAA:
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// HasRecord checks if a domain has any records of the specified type
|
||||
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
_, ok := s.aRecords[domain]
|
||||
return ok
|
||||
case RecordTypeAAAA:
|
||||
_, ok := s.aaaaRecords[domain]
|
||||
return ok
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Clear removes all records from the store
|
||||
func (s *DNSRecordStore) Clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.aRecords = make(map[string][]net.IP)
|
||||
s.aaaaRecords = make(map[string][]net.IP)
|
||||
}
|
||||
|
||||
// removeIP is a helper function to remove a specific IP from a slice
|
||||
func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
|
||||
result := make([]net.IP, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
if !ip.Equal(toRemove) {
|
||||
result = append(result, ip)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
68
dns/override/dns_override_darwin.go
Normal file
68
dns/override/dns_override_darwin.go
Normal file
@@ -0,0 +1,68 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/dns"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
|
||||
// Uses scutil for DNS configuration
|
||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
if dnsProxy == nil {
|
||||
return fmt.Errorf("DNS proxy is nil")
|
||||
}
|
||||
|
||||
var err error
|
||||
configurator, err = platform.NewDarwinDNSConfigurator()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Darwin DNS configurator: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Using Darwin scutil DNS configurator")
|
||||
|
||||
// Get current DNS servers before changing
|
||||
currentDNS, err := configurator.GetCurrentDNS()
|
||||
if err != nil {
|
||||
logger.Warn("Could not get current DNS: %v", err)
|
||||
} else {
|
||||
logger.Info("Current DNS servers: %v", currentDNS)
|
||||
}
|
||||
|
||||
// Set new DNS servers to point to our proxy
|
||||
newDNS := []netip.Addr{
|
||||
dnsProxy.GetProxyIP(),
|
||||
}
|
||||
|
||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||
originalDNS, err := configurator.SetDNS(newDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Original DNS servers backed up: %v", originalDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride restores the original DNS configuration
|
||||
func RestoreDNSOverride() error {
|
||||
if configurator == nil {
|
||||
logger.Debug("No DNS configurator to restore")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Restoring original DNS configuration")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
return fmt.Errorf("failed to restore DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("DNS configuration restored successfully")
|
||||
return nil
|
||||
}
|
||||
105
dns/override/dns_override_unix.go
Normal file
105
dns/override/dns_override_unix.go
Normal file
@@ -0,0 +1,105 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/dns"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
|
||||
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
|
||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
if dnsProxy == nil {
|
||||
return fmt.Errorf("DNS proxy is nil")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
|
||||
managerType := platform.DetectDNSManager(interfaceName)
|
||||
logger.Info("Detected DNS manager: %s", managerType.String())
|
||||
|
||||
// Create configurator based on detected manager
|
||||
switch managerType {
|
||||
case platform.SystemdResolvedManager:
|
||||
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
|
||||
if err == nil {
|
||||
logger.Info("Using systemd-resolved DNS configurator")
|
||||
return setDNS(dnsProxy, configurator)
|
||||
}
|
||||
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
|
||||
|
||||
case platform.NetworkManagerManager:
|
||||
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
|
||||
if err == nil {
|
||||
logger.Info("Using NetworkManager DNS configurator")
|
||||
return setDNS(dnsProxy, configurator)
|
||||
}
|
||||
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
|
||||
|
||||
case platform.ResolvconfManager:
|
||||
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
|
||||
if err == nil {
|
||||
logger.Info("Using resolvconf DNS configurator")
|
||||
return setDNS(dnsProxy, configurator)
|
||||
}
|
||||
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
|
||||
}
|
||||
|
||||
// Fall back to direct file manipulation
|
||||
configurator, err = platform.NewFileDNSConfigurator()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file DNS configurator: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Using file-based DNS configurator")
|
||||
return setDNS(dnsProxy, configurator)
|
||||
}
|
||||
|
||||
// setDNS is a helper function to set DNS and log the results
|
||||
func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
|
||||
// Get current DNS servers before changing
|
||||
currentDNS, err := conf.GetCurrentDNS()
|
||||
if err != nil {
|
||||
logger.Warn("Could not get current DNS: %v", err)
|
||||
} else {
|
||||
logger.Info("Current DNS servers: %v", currentDNS)
|
||||
}
|
||||
|
||||
// Set new DNS servers to point to our proxy
|
||||
newDNS := []netip.Addr{
|
||||
dnsProxy.GetProxyIP(),
|
||||
}
|
||||
|
||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||
originalDNS, err := conf.SetDNS(newDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Original DNS servers backed up: %v", originalDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride restores the original DNS configuration
|
||||
func RestoreDNSOverride() error {
|
||||
if configurator == nil {
|
||||
logger.Debug("No DNS configurator to restore")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Restoring original DNS configuration")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
return fmt.Errorf("failed to restore DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("DNS configuration restored successfully")
|
||||
return nil
|
||||
}
|
||||
68
dns/override/dns_override_windows.go
Normal file
68
dns/override/dns_override_windows.go
Normal file
@@ -0,0 +1,68 @@
|
||||
//go:build windows
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/dns"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
||||
// Uses registry-based configuration (automatically extracts interface GUID)
|
||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
if dnsProxy == nil {
|
||||
return fmt.Errorf("DNS proxy is nil")
|
||||
}
|
||||
|
||||
var err error
|
||||
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Windows DNS configurator: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Using Windows registry DNS configurator for interface: %s", interfaceName)
|
||||
|
||||
// Get current DNS servers before changing
|
||||
currentDNS, err := configurator.GetCurrentDNS()
|
||||
if err != nil {
|
||||
logger.Warn("Could not get current DNS: %v", err)
|
||||
} else {
|
||||
logger.Info("Current DNS servers: %v", currentDNS)
|
||||
}
|
||||
|
||||
// Set new DNS servers to point to our proxy
|
||||
newDNS := []netip.Addr{
|
||||
dnsProxy.GetProxyIP(),
|
||||
}
|
||||
|
||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||
originalDNS, err := configurator.SetDNS(newDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Original DNS servers backed up: %v", originalDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride restores the original DNS configuration
|
||||
func RestoreDNSOverride() error {
|
||||
if configurator == nil {
|
||||
logger.Debug("No DNS configurator to restore")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Restoring original DNS configuration")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
return fmt.Errorf("failed to restore DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("DNS configuration restored successfully")
|
||||
return nil
|
||||
}
|
||||
268
dns/platform/darwin.go
Normal file
268
dns/platform/darwin.go
Normal file
@@ -0,0 +1,268 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
scutilPath = "/usr/sbin/scutil"
|
||||
dscacheutilPath = "/usr/bin/dscacheutil"
|
||||
|
||||
dnsStateKeyFormat = "State:/Network/Service/Olm-%s/DNS"
|
||||
globalIPv4State = "State:/Network/Global/IPv4"
|
||||
primaryServiceFormat = "State:/Network/Service/%s/DNS"
|
||||
|
||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
|
||||
keyServerAddresses = "ServerAddresses"
|
||||
keyServerPort = "ServerPort"
|
||||
arraySymbol = "* "
|
||||
digitSymbol = "# "
|
||||
)
|
||||
|
||||
// DarwinDNSConfigurator manages DNS settings on macOS using scutil
|
||||
type DarwinDNSConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
originalState *DNSState
|
||||
}
|
||||
|
||||
// NewDarwinDNSConfigurator creates a new macOS DNS configurator
|
||||
func NewDarwinDNSConfigurator() (*DarwinDNSConfigurator, error) {
|
||||
return &DarwinDNSConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (d *DarwinDNSConfigurator) Name() string {
|
||||
return "darwin-scutil"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (d *DarwinDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := d.GetCurrentDNS()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current DNS: %w", err)
|
||||
}
|
||||
|
||||
// Store original state
|
||||
d.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: d.Name(),
|
||||
}
|
||||
|
||||
// Set new DNS servers
|
||||
if err := d.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := d.flushDNSCache(); err != nil {
|
||||
// Non-fatal, just log
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (d *DarwinDNSConfigurator) RestoreDNS() error {
|
||||
// Remove all created keys
|
||||
for key := range d.createdKeys {
|
||||
if err := d.removeKey(key); err != nil {
|
||||
return fmt.Errorf("remove key %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := d.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (d *DarwinDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
primaryServiceKey, err := d.getPrimaryServiceKey()
|
||||
if err != nil || primaryServiceKey == "" {
|
||||
return nil, fmt.Errorf("get primary service: %w", err)
|
||||
}
|
||||
|
||||
dnsKey := fmt.Sprintf(primaryServiceFormat, primaryServiceKey)
|
||||
cmd := fmt.Sprintf("show %s\n", dnsKey)
|
||||
|
||||
output, err := d.runScutil(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("run scutil: %w", err)
|
||||
}
|
||||
|
||||
servers := d.parseServerAddresses(output)
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// applyDNSServers applies the DNS server configuration
|
||||
func (d *DarwinDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
key := fmt.Sprintf(dnsStateKeyFormat, "Override")
|
||||
|
||||
// Use SupplementalMatchDomains with empty string to match ALL domains
|
||||
// This is the key to making DNS override work on macOS
|
||||
// Setting SupplementalMatchDomainsNoSearch to 0 enables search domain behavior
|
||||
err := d.addDNSState(key, "\"\"", servers[0], 53, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set DNS servers: %w", err)
|
||||
}
|
||||
|
||||
d.createdKeys[key] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addDNSState adds a DNS state entry with the specified configuration
|
||||
func (d *DarwinDNSConfigurator) addDNSState(state, domains string, dnsServer netip.Addr, port int, enableSearch bool) error {
|
||||
noSearch := "1"
|
||||
if enableSearch {
|
||||
noSearch = "0"
|
||||
}
|
||||
|
||||
// Build the scutil command following NetBird's approach
|
||||
var commands strings.Builder
|
||||
commands.WriteString("d.init\n")
|
||||
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomains, arraySymbol, domains))
|
||||
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomainsNoSearch, digitSymbol, noSearch))
|
||||
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerAddresses, arraySymbol, dnsServer.String()))
|
||||
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerPort, digitSymbol, strconv.Itoa(port)))
|
||||
commands.WriteString(fmt.Sprintf("set %s\n", state))
|
||||
|
||||
if _, err := d.runScutil(commands.String()); err != nil {
|
||||
return fmt.Errorf("applying state for domains %s, error: %w", domains, err)
|
||||
}
|
||||
|
||||
logger.Info("Added DNS override with server %s:%d for domains: %s", dnsServer.String(), port, domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeKey removes a DNS configuration key
|
||||
func (d *DarwinDNSConfigurator) removeKey(key string) error {
|
||||
cmd := fmt.Sprintf("remove %s\n", key)
|
||||
|
||||
if _, err := d.runScutil(cmd); err != nil {
|
||||
return fmt.Errorf("remove key: %w", err)
|
||||
}
|
||||
|
||||
delete(d.createdKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPrimaryServiceKey gets the primary network service key
|
||||
func (d *DarwinDNSConfigurator) getPrimaryServiceKey() (string, error) {
|
||||
cmd := fmt.Sprintf("show %s\n", globalIPv4State)
|
||||
|
||||
output, err := d.runScutil(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("run scutil: %w", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, "PrimaryService") {
|
||||
parts := strings.Split(line, ":")
|
||||
if len(parts) >= 2 {
|
||||
return strings.TrimSpace(parts[1]), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("scan output: %w", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("primary service not found")
|
||||
}
|
||||
|
||||
// parseServerAddresses parses DNS server addresses from scutil output
|
||||
func (d *DarwinDNSConfigurator) parseServerAddresses(output []byte) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
inServerArray := false
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if strings.HasPrefix(line, "ServerAddresses : <array> {") {
|
||||
inServerArray = true
|
||||
continue
|
||||
}
|
||||
|
||||
if line == "}" {
|
||||
inServerArray = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inServerArray {
|
||||
// Line format: "0 : 8.8.8.8"
|
||||
parts := strings.Split(line, " : ")
|
||||
if len(parts) >= 2 {
|
||||
if addr, err := netip.ParseAddr(parts[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// flushDNSCache flushes the system DNS cache
|
||||
func (d *DarwinDNSConfigurator) flushDNSCache() error {
|
||||
logger.Debug("Flushing dscacheutil cache")
|
||||
cmd := exec.Command(dscacheutilPath, "-flushcache")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("flush cache: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Flushing mDNSResponder cache")
|
||||
|
||||
cmd = exec.Command("killall", "-HUP", "mDNSResponder")
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Non-fatal, mDNSResponder might not be running
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runScutil executes an scutil command
|
||||
func (d *DarwinDNSConfigurator) runScutil(commands string) ([]byte, error) {
|
||||
// Wrap commands with open/quit
|
||||
wrapped := fmt.Sprintf("open\n%squit\n", commands)
|
||||
|
||||
logger.Debug("Running scutil with commands:\n%s\n", wrapped)
|
||||
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader(wrapped)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scutil command failed: %w, output: %s", err, output)
|
||||
}
|
||||
|
||||
logger.Debug("scutil output:\n%s\n", output)
|
||||
|
||||
return output, nil
|
||||
}
|
||||
158
dns/platform/detect_unix.go
Normal file
158
dns/platform/detect_unix.go
Normal file
@@ -0,0 +1,158 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
const defaultResolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// DNSManagerType represents the type of DNS manager detected
|
||||
type DNSManagerType int
|
||||
|
||||
const (
|
||||
// UnknownManager indicates we couldn't determine the DNS manager
|
||||
UnknownManager DNSManagerType = iota
|
||||
// SystemdResolvedManager indicates systemd-resolved is managing DNS
|
||||
SystemdResolvedManager
|
||||
// NetworkManagerManager indicates NetworkManager is managing DNS
|
||||
NetworkManagerManager
|
||||
// ResolvconfManager indicates resolvconf is managing DNS
|
||||
ResolvconfManager
|
||||
// FileManager indicates direct file management (no DNS manager)
|
||||
FileManager
|
||||
)
|
||||
|
||||
// DetectDNSManagerFromFile reads /etc/resolv.conf to determine which DNS manager is in use
|
||||
// This provides a hint based on comments in the file, similar to Netbird's approach
|
||||
func DetectDNSManagerFromFile() DNSManagerType {
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return UnknownManager
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
if len(text) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If we hit a non-comment line, default to file-based
|
||||
if text[0] != '#' {
|
||||
return FileManager
|
||||
}
|
||||
|
||||
// Check for DNS manager signatures in comments
|
||||
if strings.Contains(text, "NetworkManager") {
|
||||
return NetworkManagerManager
|
||||
}
|
||||
|
||||
if strings.Contains(text, "systemd-resolved") {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
return ResolvconfManager
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return UnknownManager
|
||||
}
|
||||
|
||||
// No indicators found, assume file-based management
|
||||
return FileManager
|
||||
}
|
||||
|
||||
// String returns a human-readable name for the DNS manager type
|
||||
func (d DNSManagerType) String() string {
|
||||
switch d {
|
||||
case SystemdResolvedManager:
|
||||
return "systemd-resolved"
|
||||
case NetworkManagerManager:
|
||||
return "NetworkManager"
|
||||
case ResolvconfManager:
|
||||
return "resolvconf"
|
||||
case FileManager:
|
||||
return "file"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// DetectDNSManager combines file detection with runtime availability checks
|
||||
// to determine the best DNS configurator to use
|
||||
func DetectDNSManager(interfaceName string) DNSManagerType {
|
||||
// First check what the file suggests
|
||||
fileHint := DetectDNSManagerFromFile()
|
||||
|
||||
// Verify the hint with runtime checks
|
||||
switch fileHint {
|
||||
case SystemdResolvedManager:
|
||||
// Verify systemd-resolved is actually running
|
||||
if IsSystemdResolvedAvailable() {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
logger.Warn("dns platform: Found systemd-resolved but it is not running. Falling back to file...")
|
||||
os.Exit(0)
|
||||
return FileManager
|
||||
|
||||
case NetworkManagerManager:
|
||||
// Verify NetworkManager is actually running
|
||||
if IsNetworkManagerAvailable() {
|
||||
// Check if NetworkManager is delegating to systemd-resolved
|
||||
if !IsNetworkManagerDNSModeSupported() {
|
||||
logger.Info("NetworkManager is delegating DNS to systemd-resolved, using systemd-resolved configurator")
|
||||
if IsSystemdResolvedAvailable() {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
}
|
||||
return NetworkManagerManager
|
||||
}
|
||||
logger.Warn("dns platform: Found network manager but it is not running. Falling back to file...")
|
||||
return FileManager
|
||||
|
||||
case ResolvconfManager:
|
||||
// Verify resolvconf is available
|
||||
if IsResolvconfAvailable() {
|
||||
return ResolvconfManager
|
||||
}
|
||||
// If resolvconf is mentioned but not available, fall back to file
|
||||
return FileManager
|
||||
|
||||
case FileManager:
|
||||
// File suggests direct file management
|
||||
// But we should still check if a manager is available that wasn't mentioned
|
||||
if IsSystemdResolvedAvailable() && interfaceName != "" {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
if IsNetworkManagerAvailable() && interfaceName != "" {
|
||||
return NetworkManagerManager
|
||||
}
|
||||
if IsResolvconfAvailable() && interfaceName != "" {
|
||||
return ResolvconfManager
|
||||
}
|
||||
return FileManager
|
||||
|
||||
default:
|
||||
// Unknown - do runtime detection
|
||||
if IsSystemdResolvedAvailable() && interfaceName != "" {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
if IsNetworkManagerAvailable() && interfaceName != "" {
|
||||
return NetworkManagerManager
|
||||
}
|
||||
if IsResolvconfAvailable() && interfaceName != "" {
|
||||
return ResolvconfManager
|
||||
}
|
||||
return FileManager
|
||||
}
|
||||
}
|
||||
192
dns/platform/file.go
Normal file
192
dns/platform/file.go
Normal file
@@ -0,0 +1,192 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
resolvConfPath = "/etc/resolv.conf"
|
||||
resolvConfBackupPath = "/etc/resolv.conf.olm.backup"
|
||||
resolvConfHeader = "# Generated by Olm DNS Manager\n# Original file backed up to " + resolvConfBackupPath + "\n\n"
|
||||
)
|
||||
|
||||
// FileDNSConfigurator manages DNS settings by directly modifying /etc/resolv.conf
|
||||
type FileDNSConfigurator struct {
|
||||
originalState *DNSState
|
||||
}
|
||||
|
||||
// NewFileDNSConfigurator creates a new file-based DNS configurator
|
||||
func NewFileDNSConfigurator() (*FileDNSConfigurator, error) {
|
||||
return &FileDNSConfigurator{}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (f *FileDNSConfigurator) Name() string {
|
||||
return "file-resolv.conf"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (f *FileDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := f.GetCurrentDNS()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current DNS: %w", err)
|
||||
}
|
||||
|
||||
// Backup original resolv.conf if not already backed up
|
||||
if !f.isBackupExists() {
|
||||
if err := f.backupResolvConf(); err != nil {
|
||||
return nil, fmt.Errorf("backup resolv.conf: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
f.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: f.Name(),
|
||||
}
|
||||
|
||||
// Write new resolv.conf
|
||||
if err := f.writeResolvConf(servers); err != nil {
|
||||
return nil, fmt.Errorf("write resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (f *FileDNSConfigurator) RestoreDNS() error {
|
||||
if !f.isBackupExists() {
|
||||
return fmt.Errorf("no backup file exists")
|
||||
}
|
||||
|
||||
// Copy backup back to original location
|
||||
if err := copyFile(resolvConfBackupPath, resolvConfPath); err != nil {
|
||||
return fmt.Errorf("restore from backup: %w", err)
|
||||
}
|
||||
|
||||
// Remove backup file
|
||||
if err := os.Remove(resolvConfBackupPath); err != nil {
|
||||
return fmt.Errorf("remove backup file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (f *FileDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
content, err := os.ReadFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return f.parseNameservers(string(content)), nil
|
||||
}
|
||||
|
||||
// backupResolvConf creates a backup of the current resolv.conf
|
||||
func (f *FileDNSConfigurator) backupResolvConf() error {
|
||||
// Get file info for permissions
|
||||
info, err := os.Stat(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
if err := copyFile(resolvConfPath, resolvConfBackupPath); err != nil {
|
||||
return fmt.Errorf("copy file: %w", err)
|
||||
}
|
||||
|
||||
// Preserve permissions
|
||||
if err := os.Chmod(resolvConfBackupPath, info.Mode()); err != nil {
|
||||
return fmt.Errorf("chmod backup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeResolvConf writes a new resolv.conf with the specified DNS servers
|
||||
func (f *FileDNSConfigurator) writeResolvConf(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Get file info for permissions
|
||||
info, err := os.Stat(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
var content strings.Builder
|
||||
content.WriteString(resolvConfHeader)
|
||||
|
||||
// Write nameservers
|
||||
for _, server := range servers {
|
||||
content.WriteString("nameserver ")
|
||||
content.WriteString(server.String())
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(resolvConfPath, []byte(content.String()), info.Mode()); err != nil {
|
||||
return fmt.Errorf("write resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isBackupExists checks if a backup file exists
|
||||
func (f *FileDNSConfigurator) isBackupExists() bool {
|
||||
_, err := os.Stat(resolvConfBackupPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// parseNameservers extracts nameserver entries from resolv.conf content
|
||||
func (f *FileDNSConfigurator) parseNameservers(content string) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// Skip comments and empty lines
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Look for nameserver lines
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
if addr, err := netip.ParseAddr(fields[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// copyFile copies a file from src to dst
|
||||
func copyFile(src, dst string) error {
|
||||
content, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read source: %w", err)
|
||||
}
|
||||
|
||||
// Get source file permissions
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat source: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(dst, content, info.Mode()); err != nil {
|
||||
return fmt.Errorf("write destination: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
294
dns/platform/network_manager.go
Normal file
294
dns/platform/network_manager.go
Normal file
@@ -0,0 +1,294 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbus "github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
// NetworkManager D-Bus constants
|
||||
networkManagerDest = "org.freedesktop.NetworkManager"
|
||||
networkManagerDbusObjectNode = "/org/freedesktop/NetworkManager"
|
||||
networkManagerDbusDNSManagerInterface = "org.freedesktop.NetworkManager.DnsManager"
|
||||
networkManagerDbusDNSManagerObjectNode = networkManagerDbusObjectNode + "/DnsManager"
|
||||
networkManagerDbusDNSManagerModeProperty = networkManagerDbusDNSManagerInterface + ".Mode"
|
||||
networkManagerDbusVersionProperty = "org.freedesktop.NetworkManager.Version"
|
||||
|
||||
// NetworkManager dispatcher script path
|
||||
networkManagerDispatcherDir = "/etc/NetworkManager/dispatcher.d"
|
||||
networkManagerConfDir = "/etc/NetworkManager/conf.d"
|
||||
networkManagerDNSConfFile = "olm-dns.conf"
|
||||
networkManagerDispatcherFile = "01-olm-dns"
|
||||
)
|
||||
|
||||
// NetworkManagerDNSConfigurator manages DNS settings using NetworkManager configuration files
|
||||
// This approach works with unmanaged interfaces by modifying NetworkManager's global DNS settings
|
||||
type NetworkManagerDNSConfigurator struct {
|
||||
ifaceName string
|
||||
originalState *DNSState
|
||||
confPath string
|
||||
dispatchPath string
|
||||
}
|
||||
|
||||
// NewNetworkManagerDNSConfigurator creates a new NetworkManager DNS configurator
|
||||
func NewNetworkManagerDNSConfigurator(ifaceName string) (*NetworkManagerDNSConfigurator, error) {
|
||||
if ifaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
// Check that NetworkManager conf.d directory exists
|
||||
if _, err := os.Stat(networkManagerConfDir); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("NetworkManager conf.d directory not found: %s", networkManagerConfDir)
|
||||
}
|
||||
|
||||
return &NetworkManagerDNSConfigurator{
|
||||
ifaceName: ifaceName,
|
||||
confPath: networkManagerConfDir + "/" + networkManagerDNSConfFile,
|
||||
dispatchPath: networkManagerDispatcherDir + "/" + networkManagerDispatcherFile,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (n *NetworkManagerDNSConfigurator) Name() string {
|
||||
return "network-manager"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (n *NetworkManagerDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := n.GetCurrentDNS()
|
||||
if err != nil {
|
||||
// If we can't get current DNS, proceed anyway
|
||||
originalServers = []netip.Addr{}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
n.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: n.Name(),
|
||||
}
|
||||
|
||||
// Apply new DNS servers
|
||||
if err := n.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (n *NetworkManagerDNSConfigurator) RestoreDNS() error {
|
||||
// Remove our configuration file
|
||||
if err := os.Remove(n.confPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove DNS config file: %w", err)
|
||||
}
|
||||
|
||||
// Reload NetworkManager to apply the change
|
||||
if err := n.reloadNetworkManager(); err != nil {
|
||||
return fmt.Errorf("reload NetworkManager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers by reading /etc/resolv.conf
|
||||
func (n *NetworkManagerDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
content, err := os.ReadFile("/etc/resolv.conf")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
var servers []netip.Addr
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
if addr, err := netip.ParseAddr(fields[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// applyDNSServers applies DNS server configuration via NetworkManager config file
|
||||
func (n *NetworkManagerDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Build DNS server list
|
||||
var dnsServers []string
|
||||
for _, server := range servers {
|
||||
dnsServers = append(dnsServers, server.String())
|
||||
}
|
||||
|
||||
// Create NetworkManager configuration file that sets global DNS
|
||||
// This overrides DNS for all connections
|
||||
configContent := fmt.Sprintf(`# Generated by Olm DNS Manager - DO NOT EDIT
|
||||
# This file configures NetworkManager to use Olm's DNS proxy
|
||||
|
||||
[global-dns-domain-*]
|
||||
servers=%s
|
||||
`, strings.Join(dnsServers, ","))
|
||||
|
||||
// Write the configuration file
|
||||
if err := os.WriteFile(n.confPath, []byte(configContent), 0644); err != nil {
|
||||
return fmt.Errorf("write DNS config file: %w", err)
|
||||
}
|
||||
|
||||
// Reload NetworkManager to apply the new configuration
|
||||
if err := n.reloadNetworkManager(); err != nil {
|
||||
// Try to clean up
|
||||
os.Remove(n.confPath)
|
||||
return fmt.Errorf("reload NetworkManager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadNetworkManager tells NetworkManager to reload its configuration
|
||||
func (n *NetworkManagerDNSConfigurator) reloadNetworkManager() error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Call Reload method with flags=0 (reload everything)
|
||||
// See: https://networkmanager.dev/docs/api/latest/gdbus-org.freedesktop.NetworkManager.html#gdbus-method-org-freedesktop-NetworkManager.Reload
|
||||
err = obj.CallWithContext(ctx, networkManagerDest+".Reload", 0, uint32(0)).Store()
|
||||
if err != nil {
|
||||
return fmt.Errorf("call Reload: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsNetworkManagerAvailable checks if NetworkManager is available and responsive
|
||||
func IsNetworkManagerAvailable() bool {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Try to ping NetworkManager
|
||||
if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsNetworkManagerDNSModeSupported checks if NetworkManager's DNS mode is one we can work with
|
||||
// Some DNS modes delegate to other systems (like systemd-resolved) which we should use directly
|
||||
func IsNetworkManagerDNSModeSupported() bool {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
||||
|
||||
modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty)
|
||||
if err != nil {
|
||||
// If we can't get the mode, assume it's not supported
|
||||
return false
|
||||
}
|
||||
|
||||
mode, ok := modeVariant.Value().(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// If NetworkManager is delegating DNS to systemd-resolved, we should use
|
||||
// systemd-resolved directly for better control
|
||||
switch mode {
|
||||
case "systemd-resolved":
|
||||
// NetworkManager is delegating to systemd-resolved
|
||||
// We should use systemd-resolved configurator instead
|
||||
return false
|
||||
case "dnsmasq", "unbound":
|
||||
// NetworkManager is using a local resolver that it controls
|
||||
// We can configure DNS through NetworkManager
|
||||
return true
|
||||
case "default", "none", "":
|
||||
// NetworkManager is managing DNS directly or not at all
|
||||
// We can configure DNS through NetworkManager
|
||||
return true
|
||||
default:
|
||||
// Unknown mode, try to use it
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// GetNetworkManagerDNSMode returns the current DNS mode of NetworkManager
|
||||
func GetNetworkManagerDNSMode() (string, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
||||
|
||||
modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get DNS mode property: %w", err)
|
||||
}
|
||||
|
||||
mode, ok := modeVariant.Value().(string)
|
||||
if !ok {
|
||||
return "", errors.New("DNS mode is not a string")
|
||||
}
|
||||
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
// GetNetworkManagerVersion returns the version of NetworkManager
|
||||
func GetNetworkManagerVersion() (string, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
|
||||
|
||||
versionVariant, err := obj.GetProperty(networkManagerDbusVersionProperty)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get version property: %w", err)
|
||||
}
|
||||
|
||||
version, ok := versionVariant.Value().(string)
|
||||
if !ok {
|
||||
return "", errors.New("version is not a string")
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
192
dns/platform/resolvconf.go
Normal file
192
dns/platform/resolvconf.go
Normal file
@@ -0,0 +1,192 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const resolvconfCommand = "resolvconf"
|
||||
|
||||
// ResolvconfDNSConfigurator manages DNS settings using the resolvconf utility
|
||||
type ResolvconfDNSConfigurator struct {
|
||||
ifaceName string
|
||||
implType string
|
||||
originalState *DNSState
|
||||
}
|
||||
|
||||
// NewResolvconfDNSConfigurator creates a new resolvconf DNS configurator
|
||||
func NewResolvconfDNSConfigurator(ifaceName string) (*ResolvconfDNSConfigurator, error) {
|
||||
if ifaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
// Detect resolvconf implementation type
|
||||
implType, err := detectResolvconfType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("detect resolvconf type: %w", err)
|
||||
}
|
||||
|
||||
return &ResolvconfDNSConfigurator{
|
||||
ifaceName: ifaceName,
|
||||
implType: implType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (r *ResolvconfDNSConfigurator) Name() string {
|
||||
return fmt.Sprintf("resolvconf-%s", r.implType)
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (r *ResolvconfDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := r.GetCurrentDNS()
|
||||
if err != nil {
|
||||
// If we can't get current DNS, proceed anyway
|
||||
originalServers = []netip.Addr{}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
r.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: r.Name(),
|
||||
}
|
||||
|
||||
// Apply new DNS servers
|
||||
if err := r.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (r *ResolvconfDNSConfigurator) RestoreDNS() error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch r.implType {
|
||||
case "openresolv":
|
||||
// Force delete with -f
|
||||
cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||
default:
|
||||
cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
|
||||
}
|
||||
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("delete resolvconf config: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (r *ResolvconfDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
// resolvconf doesn't provide a direct way to query per-interface DNS
|
||||
// We can try to read /etc/resolv.conf but it's merged from all sources
|
||||
content, err := exec.Command(resolvconfCommand, "-l").CombinedOutput()
|
||||
if err != nil {
|
||||
// Fall back to reading resolv.conf
|
||||
return readResolvConfServers()
|
||||
}
|
||||
|
||||
// Parse the output (format varies by implementation)
|
||||
return parseResolvconfOutput(string(content)), nil
|
||||
}
|
||||
|
||||
// applyDNSServers applies DNS server configuration via resolvconf
|
||||
func (r *ResolvconfDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Build resolv.conf content
|
||||
var content bytes.Buffer
|
||||
content.WriteString("# Generated by Olm DNS Manager\n\n")
|
||||
|
||||
for _, server := range servers {
|
||||
content.WriteString("nameserver ")
|
||||
content.WriteString(server.String())
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// Apply via resolvconf
|
||||
var cmd *exec.Cmd
|
||||
switch r.implType {
|
||||
case "openresolv":
|
||||
// OpenResolv supports exclusive mode with -x
|
||||
cmd = exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||
default:
|
||||
cmd = exec.Command(resolvconfCommand, "-a", r.ifaceName)
|
||||
}
|
||||
|
||||
cmd.Stdin = &content
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("apply resolvconf config: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectResolvconfType detects which resolvconf implementation is being used
|
||||
func detectResolvconfType() (string, error) {
|
||||
cmd := exec.Command(resolvconfCommand, "--version")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("detect resolvconf type: %w", err)
|
||||
}
|
||||
|
||||
if strings.Contains(string(out), "openresolv") {
|
||||
return "openresolv", nil
|
||||
}
|
||||
|
||||
return "resolvconf", nil
|
||||
}
|
||||
|
||||
// parseResolvconfOutput parses resolvconf -l output for DNS servers
|
||||
func parseResolvconfOutput(output string) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
|
||||
lines := strings.Split(output, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// Skip comments and empty lines
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Look for nameserver lines
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
if addr, err := netip.ParseAddr(fields[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// readResolvConfServers reads DNS servers from /etc/resolv.conf
|
||||
func readResolvConfServers() ([]netip.Addr, error) {
|
||||
cmd := exec.Command("cat", "/etc/resolv.conf")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return parseResolvconfOutput(string(out)), nil
|
||||
}
|
||||
|
||||
// IsResolvconfAvailable checks if resolvconf is available
|
||||
func IsResolvconfAvailable() bool {
|
||||
cmd := exec.Command(resolvconfCommand, "--version")
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
286
dns/platform/systemd.go
Normal file
286
dns/platform/systemd.go
Normal file
@@ -0,0 +1,286 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
dbus "github.com/godbus/dbus/v5"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
systemdResolvedDest = "org.freedesktop.resolve1"
|
||||
systemdDbusObjectNode = "/org/freedesktop/resolve1"
|
||||
systemdDbusManagerIface = "org.freedesktop.resolve1.Manager"
|
||||
systemdDbusGetLinkMethod = systemdDbusManagerIface + ".GetLink"
|
||||
systemdDbusFlushCachesMethod = systemdDbusManagerIface + ".FlushCaches"
|
||||
systemdDbusLinkInterface = "org.freedesktop.resolve1.Link"
|
||||
systemdDbusSetDNSMethod = systemdDbusLinkInterface + ".SetDNS"
|
||||
systemdDbusSetDefaultRouteMethod = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethod = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusSetDNSSECMethod = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||
systemdDbusSetDNSOverTLSMethod = systemdDbusLinkInterface + ".SetDNSOverTLS"
|
||||
systemdDbusRevertMethod = systemdDbusLinkInterface + ".Revert"
|
||||
|
||||
// RootZone is the root DNS zone that matches all queries
|
||||
RootZone = "."
|
||||
)
|
||||
|
||||
// systemdDbusDNSInput maps to (iay) dbus input for SetDNS method
|
||||
type systemdDbusDNSInput struct {
|
||||
Family int32
|
||||
Address []byte
|
||||
}
|
||||
|
||||
// systemdDbusDomainsInput maps to (sb) dbus input for SetDomains method
|
||||
type systemdDbusDomainsInput struct {
|
||||
Domain string
|
||||
MatchOnly bool
|
||||
}
|
||||
|
||||
// SystemdResolvedDNSConfigurator manages DNS settings using systemd-resolved D-Bus API
|
||||
type SystemdResolvedDNSConfigurator struct {
|
||||
ifaceName string
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
originalState *DNSState
|
||||
}
|
||||
|
||||
// NewSystemdResolvedDNSConfigurator creates a new systemd-resolved DNS configurator
|
||||
func NewSystemdResolvedDNSConfigurator(ifaceName string) (*SystemdResolvedDNSConfigurator, error) {
|
||||
// Get network interface
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get interface: %w", err)
|
||||
}
|
||||
|
||||
// Connect to D-Bus
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
|
||||
|
||||
// Get the link object for this interface
|
||||
var linkPath string
|
||||
if err := obj.Call(systemdDbusGetLinkMethod, 0, iface.Index).Store(&linkPath); err != nil {
|
||||
return nil, fmt.Errorf("get link: %w", err)
|
||||
}
|
||||
|
||||
return &SystemdResolvedDNSConfigurator{
|
||||
ifaceName: ifaceName,
|
||||
dbusLinkObject: dbus.ObjectPath(linkPath),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (s *SystemdResolvedDNSConfigurator) Name() string {
|
||||
return "systemd-resolved"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (s *SystemdResolvedDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := s.GetCurrentDNS()
|
||||
if err != nil {
|
||||
// If we can't get current DNS, proceed anyway
|
||||
originalServers = []netip.Addr{}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
s.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: s.Name(),
|
||||
}
|
||||
|
||||
// Apply new DNS servers
|
||||
if err := s.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (s *SystemdResolvedDNSConfigurator) RestoreDNS() error {
|
||||
// Call Revert method to restore systemd-resolved defaults
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := obj.CallWithContext(ctx, systemdDbusRevertMethod, 0).Store(); err != nil {
|
||||
return fmt.Errorf("revert DNS settings: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache after reverting
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
// Note: systemd-resolved doesn't easily expose current per-link DNS servers via D-Bus
|
||||
// This is a placeholder that returns an empty list
|
||||
func (s *SystemdResolvedDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
// systemd-resolved's D-Bus API doesn't have a simple way to query current DNS servers
|
||||
// We would need to parse resolvectl status output or read from /run/systemd/resolve/
|
||||
// For now, return empty list
|
||||
return []netip.Addr{}, nil
|
||||
}
|
||||
|
||||
// applyDNSServers applies DNS server configuration via systemd-resolved
|
||||
func (s *SystemdResolvedDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Convert servers to systemd-resolved format
|
||||
var dnsInputs []systemdDbusDNSInput
|
||||
for _, server := range servers {
|
||||
family := unix.AF_INET
|
||||
if server.Is6() {
|
||||
family = unix.AF_INET6
|
||||
}
|
||||
|
||||
dnsInputs = append(dnsInputs, systemdDbusDNSInput{
|
||||
Family: int32(family),
|
||||
Address: server.AsSlice(),
|
||||
})
|
||||
}
|
||||
|
||||
// Connect to D-Bus
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Call SetDNS method to set the DNS servers
|
||||
if err := obj.CallWithContext(ctx, systemdDbusSetDNSMethod, 0, dnsInputs).Store(); err != nil {
|
||||
return fmt.Errorf("set DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Set this interface as the default route for DNS
|
||||
// This ensures all DNS queries prefer this interface
|
||||
if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethod, true); err != nil {
|
||||
return fmt.Errorf("set default route: %w", err)
|
||||
}
|
||||
|
||||
// Set the root zone "." as a match-only domain
|
||||
// This captures ALL DNS queries and routes them through this interface
|
||||
domainsInput := []systemdDbusDomainsInput{
|
||||
{
|
||||
Domain: RootZone,
|
||||
MatchOnly: true,
|
||||
},
|
||||
}
|
||||
if err := s.callLinkMethod(systemdDbusSetDomainsMethod, domainsInput); err != nil {
|
||||
return fmt.Errorf("set domains: %w", err)
|
||||
}
|
||||
|
||||
// Disable DNSSEC - we don't support it and it may be enabled by default
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSSECMethod, "no"); err != nil {
|
||||
// Log warning but don't fail - this is optional
|
||||
fmt.Printf("warning: failed to disable DNSSEC: %v\n", err)
|
||||
}
|
||||
|
||||
// Disable DNSOverTLS - we don't support it and it may be enabled by default
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethod, "no"); err != nil {
|
||||
// Log warning but don't fail - this is optional
|
||||
fmt.Printf("warning: failed to disable DNSOverTLS: %v\n", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache to ensure new settings take effect immediately
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callLinkMethod is a helper to call methods on the link object
|
||||
func (s *SystemdResolvedDNSConfigurator) callLinkMethod(method string, value any) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if value != nil {
|
||||
if err := obj.CallWithContext(ctx, method, 0, value).Store(); err != nil {
|
||||
return fmt.Errorf("call %s: %w", method, err)
|
||||
}
|
||||
} else {
|
||||
if err := obj.CallWithContext(ctx, method, 0).Store(); err != nil {
|
||||
return fmt.Errorf("call %s: %w", method, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// flushDNSCache flushes the systemd-resolved DNS cache
|
||||
func (s *SystemdResolvedDNSConfigurator) flushDNSCache() error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, 0).Store(); err != nil {
|
||||
return fmt.Errorf("flush caches: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsSystemdResolvedAvailable checks if systemd-resolved is available and responsive
|
||||
func IsSystemdResolvedAvailable() bool {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Try to ping systemd-resolved
|
||||
if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
41
dns/platform/types.go
Normal file
41
dns/platform/types.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package dns
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// DNSConfigurator provides an interface for managing system DNS settings
|
||||
// across different platforms and implementations
|
||||
type DNSConfigurator interface {
|
||||
// SetDNS overrides the system DNS servers with the specified ones
|
||||
// Returns the original DNS servers that were replaced
|
||||
SetDNS(servers []netip.Addr) ([]netip.Addr, error)
|
||||
|
||||
// RestoreDNS restores the original DNS servers
|
||||
RestoreDNS() error
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
GetCurrentDNS() ([]netip.Addr, error)
|
||||
|
||||
// Name returns the name of this configurator implementation
|
||||
Name() string
|
||||
}
|
||||
|
||||
// DNSConfig contains the configuration for DNS override
|
||||
type DNSConfig struct {
|
||||
// Servers is the list of DNS servers to use
|
||||
Servers []netip.Addr
|
||||
|
||||
// SearchDomains is an optional list of search domains
|
||||
SearchDomains []string
|
||||
}
|
||||
|
||||
// DNSState represents the saved state of DNS configuration
|
||||
type DNSState struct {
|
||||
// OriginalServers are the DNS servers before override
|
||||
OriginalServers []netip.Addr
|
||||
|
||||
// OriginalSearchDomains are the search domains before override
|
||||
OriginalSearchDomains []string
|
||||
|
||||
// ConfiguratorName is the name of the configurator that saved this state
|
||||
ConfiguratorName string
|
||||
}
|
||||
343
dns/platform/windows.go
Normal file
343
dns/platform/windows.go
Normal file
@@ -0,0 +1,343 @@
|
||||
//go:build windows
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
|
||||
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
|
||||
)
|
||||
|
||||
const (
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServer = "NameServer"
|
||||
interfaceConfigDhcpNameServer = "DhcpNameServer"
|
||||
)
|
||||
|
||||
// WindowsDNSConfigurator manages DNS settings on Windows using the registry
|
||||
type WindowsDNSConfigurator struct {
|
||||
guid string
|
||||
originalState *DNSState
|
||||
}
|
||||
|
||||
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
|
||||
// Accepts an interface name and extracts the GUID internally
|
||||
func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) {
|
||||
if interfaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
guid, err := getInterfaceGUIDString(interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
|
||||
}
|
||||
|
||||
return &WindowsDNSConfigurator{
|
||||
guid: guid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newWindowsDNSConfiguratorFromGUID creates a configurator from a GUID string
|
||||
// This is an internal function for use by DetectBestConfigurator
|
||||
func newWindowsDNSConfiguratorFromGUID(guid string) (*WindowsDNSConfigurator, error) {
|
||||
if guid == "" {
|
||||
return nil, fmt.Errorf("interface GUID is required")
|
||||
}
|
||||
|
||||
return &WindowsDNSConfigurator{
|
||||
guid: guid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (w *WindowsDNSConfigurator) Name() string {
|
||||
return "windows-registry"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (w *WindowsDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := w.GetCurrentDNS()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current DNS: %w", err)
|
||||
}
|
||||
|
||||
// Store original state
|
||||
w.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: w.Name(),
|
||||
}
|
||||
|
||||
// Set new DNS servers
|
||||
if err := w.setDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("set DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := w.flushDNSCache(); err != nil {
|
||||
// Non-fatal, just log
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (w *WindowsDNSConfigurator) RestoreDNS() error {
|
||||
if w.originalState == nil {
|
||||
return fmt.Errorf("no original state to restore")
|
||||
}
|
||||
|
||||
// Clear the static DNS setting
|
||||
if err := w.clearDNSServers(); err != nil {
|
||||
return fmt.Errorf("clear DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := w.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (w *WindowsDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
regKey, err := w.getInterfaceRegistryKey(registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closeKey(regKey)
|
||||
|
||||
// Try to get static DNS first
|
||||
nameServer, _, err := regKey.GetStringValue(interfaceConfigNameServer)
|
||||
if err == nil && nameServer != "" {
|
||||
return w.parseServerList(nameServer), nil
|
||||
}
|
||||
|
||||
// Fall back to DHCP DNS
|
||||
dhcpNameServer, _, err := regKey.GetStringValue(interfaceConfigDhcpNameServer)
|
||||
if err == nil && dhcpNameServer != "" {
|
||||
return w.parseServerList(dhcpNameServer), nil
|
||||
}
|
||||
|
||||
return []netip.Addr{}, nil
|
||||
}
|
||||
|
||||
// setDNSServers sets the DNS servers in the registry
|
||||
func (w *WindowsDNSConfigurator) setDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closeKey(regKey)
|
||||
|
||||
// Build comma-separated or space-separated list of servers
|
||||
var serverList string
|
||||
for i, server := range servers {
|
||||
if i > 0 {
|
||||
serverList += ","
|
||||
}
|
||||
serverList += server.String()
|
||||
}
|
||||
|
||||
if err := regKey.SetStringValue(interfaceConfigNameServer, serverList); err != nil {
|
||||
return fmt.Errorf("set NameServer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearDNSServers clears the static DNS server setting
|
||||
func (w *WindowsDNSConfigurator) clearDNSServers() error {
|
||||
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closeKey(regKey)
|
||||
|
||||
// Set empty string to revert to DHCP
|
||||
if err := regKey.SetStringValue(interfaceConfigNameServer, ""); err != nil {
|
||||
return fmt.Errorf("clear NameServer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInterfaceRegistryKey opens the registry key for the network interface
|
||||
func (w *WindowsDNSConfigurator) getInterfaceRegistryKey(access uint32) (registry.Key, error) {
|
||||
regKeyPath := interfaceConfigPath + `\` + w.guid
|
||||
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, access)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
|
||||
}
|
||||
|
||||
return regKey, nil
|
||||
}
|
||||
|
||||
// parseServerList parses a comma or space-separated list of DNS servers
|
||||
func (w *WindowsDNSConfigurator) parseServerList(serverList string) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
|
||||
// Split by comma or space
|
||||
parts := splitByDelimiters(serverList, []rune{',', ' '})
|
||||
|
||||
for _, part := range parts {
|
||||
if addr, err := netip.ParseAddr(part); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// flushDNSCache flushes the Windows DNS resolver cache
|
||||
func (w *WindowsDNSConfigurator) flushDNSCache() error {
|
||||
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
fmt.Printf("warning: DnsFlushResolverCache panicked: %v\n", rec)
|
||||
}
|
||||
}()
|
||||
|
||||
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||
if ret == 0 {
|
||||
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("DnsFlushResolverCache failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitByDelimiters splits a string by multiple delimiters
|
||||
func splitByDelimiters(s string, delimiters []rune) []string {
|
||||
var result []string
|
||||
var current []rune
|
||||
|
||||
for _, char := range s {
|
||||
isDelimiter := false
|
||||
for _, delim := range delimiters {
|
||||
if char == delim {
|
||||
isDelimiter = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isDelimiter {
|
||||
if len(current) > 0 {
|
||||
result = append(result, string(current))
|
||||
current = []rune{}
|
||||
}
|
||||
} else {
|
||||
current = append(current, char)
|
||||
}
|
||||
}
|
||||
|
||||
if len(current) > 0 {
|
||||
result = append(result, string(current))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// closeKey closes a registry key and logs errors
|
||||
func closeKey(closer io.Closer) {
|
||||
if err := closer.Close(); err != nil {
|
||||
fmt.Printf("warning: failed to close registry key: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
|
||||
// This is required for registry-based DNS configuration on Windows
|
||||
func getInterfaceGUIDString(interfaceName string) (string, error) {
|
||||
if interfaceName == "" {
|
||||
return "", fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
|
||||
}
|
||||
|
||||
luid, err := indexToLUID(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to convert index to LUID: %w", err)
|
||||
}
|
||||
|
||||
// Convert LUID to GUID using Windows API
|
||||
guid, err := luidToGUID(luid)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to convert LUID to GUID: %w", err)
|
||||
}
|
||||
|
||||
return guid, nil
|
||||
}
|
||||
|
||||
// indexToLUID converts a Windows interface index to a LUID
|
||||
func indexToLUID(index uint32) (uint64, error) {
|
||||
var luid uint64
|
||||
|
||||
// Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function
|
||||
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid")
|
||||
|
||||
// Call the Windows API
|
||||
ret, _, err := convertInterfaceIndexToLuid.Call(
|
||||
uintptr(index),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err)
|
||||
}
|
||||
|
||||
return luid, nil
|
||||
}
|
||||
|
||||
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
|
||||
// using the Windows ConvertInterface* APIs
|
||||
func luidToGUID(luid uint64) (string, error) {
|
||||
var guid windows.GUID
|
||||
|
||||
// Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function
|
||||
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid")
|
||||
|
||||
// Call the Windows API
|
||||
// NET_LUID is a 64-bit value on Windows
|
||||
ret, _, err := convertLuidToGuid.Call(
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
uintptr(unsafe.Pointer(&guid)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err)
|
||||
}
|
||||
|
||||
// Format the GUID as a string with curly braces
|
||||
guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}",
|
||||
guid.Data1, guid.Data2, guid.Data3,
|
||||
guid.Data4[0], guid.Data4[1], guid.Data4[2], guid.Data4[3],
|
||||
guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7])
|
||||
|
||||
return guidStr, nil
|
||||
}
|
||||
23
go.mod
23
go.mod
@@ -4,20 +4,29 @@ go 1.25
|
||||
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2
|
||||
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7
|
||||
github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552
|
||||
github.com/godbus/dbus/v5 v5.2.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
||||
golang.org/x/net v0.45.0
|
||||
golang.org/x/sys v0.37.0
|
||||
github.com/miekg/dns v1.1.68
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/vishvananda/netlink v1.3.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
)
|
||||
|
||||
36
go.sum
36
go.sum
@@ -1,36 +1,48 @@
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o=
|
||||
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM=
|
||||
github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552 h1:51pHUtoqQhYPS9OiBDHLgYV44X/CBzR5J7GuWO3izhU=
|
||||
github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
|
||||
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
|
||||
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
|
||||
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c=
|
||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
package holepunch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/bind"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// DomainResolver is a function type for resolving domains to IP addresses
|
||||
type DomainResolver func(string) (string, error)
|
||||
|
||||
// ExitNode represents a WireGuard exit node for hole punching
|
||||
type ExitNode struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
// Manager handles UDP hole punching operations
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
olmID string
|
||||
token string
|
||||
domainResolver DomainResolver
|
||||
}
|
||||
|
||||
// NewManager creates a new hole punch manager
|
||||
func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager {
|
||||
return &Manager{
|
||||
sharedBind: sharedBind,
|
||||
olmID: olmID,
|
||||
domainResolver: domainResolver,
|
||||
}
|
||||
}
|
||||
|
||||
// SetToken updates the authentication token used for hole punching
|
||||
func (m *Manager) SetToken(token string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.token = token
|
||||
}
|
||||
|
||||
// IsRunning returns whether hole punching is currently active
|
||||
func (m *Manager) IsRunning() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.running
|
||||
}
|
||||
|
||||
// Stop stops any ongoing hole punch operations
|
||||
func (m *Manager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.running {
|
||||
return
|
||||
}
|
||||
|
||||
if m.stopChan != nil {
|
||||
close(m.stopChan)
|
||||
m.stopChan = nil
|
||||
}
|
||||
|
||||
m.running = false
|
||||
logger.Info("Hole punch manager stopped")
|
||||
}
|
||||
|
||||
// StartMultipleExitNodes starts hole punching to multiple exit nodes
|
||||
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
|
||||
m.mu.Lock()
|
||||
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return fmt.Errorf("hole punch already running")
|
||||
}
|
||||
|
||||
if len(exitNodes) == 0 {
|
||||
m.mu.Unlock()
|
||||
logger.Warn("No exit nodes provided for hole punching")
|
||||
return fmt.Errorf("no exit nodes provided")
|
||||
}
|
||||
|
||||
m.running = true
|
||||
m.stopChan = make(chan struct{})
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||
|
||||
go m.runMultipleExitNodes(exitNodes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
|
||||
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
|
||||
m.mu.Lock()
|
||||
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return fmt.Errorf("hole punch already running")
|
||||
}
|
||||
|
||||
m.running = true
|
||||
m.stopChan = make(chan struct{})
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
||||
|
||||
go m.runSingleEndpoint(endpoint, serverPubKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMultipleExitNodes performs hole punching to multiple exit nodes
|
||||
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||
}()
|
||||
|
||||
// Resolve all endpoints upfront
|
||||
type resolvedExitNode struct {
|
||||
remoteAddr *net.UDPAddr
|
||||
publicKey string
|
||||
endpointName string
|
||||
}
|
||||
|
||||
var resolvedNodes []resolvedExitNode
|
||||
for _, exitNode := range exitNodes {
|
||||
host, err := m.domainResolver(exitNode.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||
remoteAddr: remoteAddr,
|
||||
publicKey: exitNode.PublicKey,
|
||||
endpointName: exitNode.Endpoint,
|
||||
})
|
||||
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||
}
|
||||
|
||||
if len(resolvedNodes) == 0 {
|
||||
logger.Error("No exit nodes could be resolved")
|
||||
return
|
||||
}
|
||||
|
||||
// Send initial hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopChan:
|
||||
logger.Debug("Hole punch stopped by signal")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Debug("Hole punch timeout reached")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Send hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runSingleEndpoint performs hole punching to a single endpoint
|
||||
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||
}()
|
||||
|
||||
host, err := m.domainResolver(endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
||||
return
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute once immediately before starting the loop
|
||||
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||
logger.Warn("Failed to send initial hole punch: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopChan:
|
||||
logger.Debug("Hole punch stopped by signal")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Debug("Hole punch timeout reached")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||
logger.Debug("Failed to send hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendHolePunch sends an encrypted hole punch packet using the shared bind
|
||||
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
|
||||
m.mu.Lock()
|
||||
token := m.token
|
||||
olmID := m.olmID
|
||||
m.mu.Unlock()
|
||||
|
||||
if serverPubKey == "" || token == "" {
|
||||
return fmt.Errorf("server public key or OLM token is empty")
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
OlmID string `json:"olmId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
OlmID: olmID,
|
||||
Token: token,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(encryptedPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||
}
|
||||
|
||||
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
|
||||
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
95
main.go
95
main.go
@@ -9,6 +9,7 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/updates"
|
||||
"github.com/fosrl/olm/olm"
|
||||
)
|
||||
|
||||
@@ -154,25 +155,30 @@ func main() {
|
||||
}
|
||||
|
||||
// Create a context that will be cancelled on interrupt signals
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
signalCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
// Create a separate context for programmatic shutdown (e.g., via API exit)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Run in console mode
|
||||
runOlmMainWithArgs(ctx, os.Args[1:])
|
||||
runOlmMainWithArgs(ctx, cancel, signalCtx, os.Args[1:])
|
||||
}
|
||||
|
||||
func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCtx context.Context, args []string) {
|
||||
// Setup Windows event logging if on Windows
|
||||
if runtime.GOOS != "windows" {
|
||||
if runtime.GOOS == "windows" {
|
||||
setupWindowsEventLog()
|
||||
} else {
|
||||
// Initialize logger for non-Windows platforms
|
||||
logger.Init()
|
||||
logger.Init(nil)
|
||||
}
|
||||
|
||||
// Load configuration from file, env vars, and CLI args
|
||||
// Priority: CLI args > Env vars > Config file > Defaults
|
||||
config, showVersion, showConfig, err := LoadConfig(os.Args[1:])
|
||||
// Use the passed args parameter instead of os.Args[1:] to support Windows service mode
|
||||
config, showVersion, showConfig, err := LoadConfig(args)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to load configuration: %v\n", err)
|
||||
return
|
||||
@@ -189,7 +195,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
fmt.Println("Olm version " + olmVersion)
|
||||
os.Exit(0)
|
||||
}
|
||||
logger.Info("Olm version " + olmVersion)
|
||||
logger.Info("Olm version %s", olmVersion)
|
||||
|
||||
config.Version = olmVersion
|
||||
|
||||
@@ -199,27 +205,60 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
logger.Debug("Saved full olm config with all options")
|
||||
}
|
||||
|
||||
// Create a new olm.Config struct and copy values from the main config
|
||||
olmConfig := olm.Config{
|
||||
Endpoint: config.Endpoint,
|
||||
ID: config.ID,
|
||||
Secret: config.Secret,
|
||||
UserToken: config.UserToken,
|
||||
MTU: config.MTU,
|
||||
DNS: config.DNS,
|
||||
InterfaceName: config.InterfaceName,
|
||||
LogLevel: config.LogLevel,
|
||||
EnableAPI: config.EnableAPI,
|
||||
HTTPAddr: config.HTTPAddr,
|
||||
SocketPath: config.SocketPath,
|
||||
Holepunch: config.Holepunch,
|
||||
TlsClientCert: config.TlsClientCert,
|
||||
PingIntervalDuration: config.PingIntervalDuration,
|
||||
PingTimeoutDuration: config.PingTimeoutDuration,
|
||||
Version: config.Version,
|
||||
OrgID: config.OrgID,
|
||||
// DoNotCreateNewClient: config.DoNotCreateNewClient,
|
||||
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
|
||||
logger.Debug("Failed to check for updates: %v", err)
|
||||
}
|
||||
|
||||
olm.Run(ctx, olmConfig)
|
||||
// Create a new olm.Config struct and copy values from the main config
|
||||
olmConfig := olm.GlobalConfig{
|
||||
LogLevel: config.LogLevel,
|
||||
EnableAPI: config.EnableAPI,
|
||||
HTTPAddr: config.HTTPAddr,
|
||||
SocketPath: config.SocketPath,
|
||||
Version: config.Version,
|
||||
Agent: "Olm CLI",
|
||||
OnExit: cancel, // Pass cancel function directly to trigger shutdown
|
||||
OnTerminated: cancel,
|
||||
}
|
||||
|
||||
olm.Init(ctx, olmConfig)
|
||||
if err := olm.StartApi(); err != nil {
|
||||
logger.Fatal("Failed to start API server: %v", err)
|
||||
}
|
||||
|
||||
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
|
||||
tunnelConfig := olm.TunnelConfig{
|
||||
Endpoint: config.Endpoint,
|
||||
ID: config.ID,
|
||||
Secret: config.Secret,
|
||||
UserToken: config.UserToken,
|
||||
MTU: config.MTU,
|
||||
DNS: config.DNS,
|
||||
UpstreamDNS: config.UpstreamDNS,
|
||||
InterfaceName: config.InterfaceName,
|
||||
Holepunch: !config.DisableHolepunch,
|
||||
TlsClientCert: config.TlsClientCert,
|
||||
PingIntervalDuration: config.PingIntervalDuration,
|
||||
PingTimeoutDuration: config.PingTimeoutDuration,
|
||||
OrgID: config.OrgID,
|
||||
OverrideDNS: config.OverrideDNS,
|
||||
DisableRelay: config.DisableRelay,
|
||||
EnableUAPI: true,
|
||||
}
|
||||
go olm.StartTunnel(tunnelConfig)
|
||||
} else {
|
||||
logger.Info("Incomplete tunnel configuration, not starting tunnel")
|
||||
}
|
||||
|
||||
// Wait for either signal or programmatic shutdown
|
||||
select {
|
||||
case <-signalCtx.Done():
|
||||
logger.Info("Shutdown signal received, cleaning up...")
|
||||
case <-ctx.Done():
|
||||
logger.Info("Shutdown requested via API, cleaning up...")
|
||||
}
|
||||
|
||||
// Clean up resources
|
||||
olm.Close()
|
||||
logger.Info("Shutdown complete")
|
||||
}
|
||||
|
||||
126
namespace.sh
Normal file
126
namespace.sh
Normal file
@@ -0,0 +1,126 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Configuration
|
||||
NS_NAME="isolated_ns" # Name of the namespace
|
||||
VETH_HOST="veth_host" # Interface name on host side
|
||||
VETH_NS="veth_ns" # Interface name inside namespace
|
||||
HOST_IP="192.168.15.1" # Gateway IP for the namespace (host side)
|
||||
NS_IP="192.168.15.2" # IP address for the namespace
|
||||
SUBNET_CIDR="24" # Subnet mask
|
||||
DNS_SERVER="8.8.8.8" # DNS to use inside namespace
|
||||
|
||||
# Detect the main physical interface (gateway to internet)
|
||||
PHY_IFACE=$(ip route get 8.8.8.8 | awk -- '{printf $5}')
|
||||
|
||||
# Helper function to check for root
|
||||
check_root() {
|
||||
if [ "$EUID" -ne 0 ]; then
|
||||
echo "Error: This script must be run as root."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
setup_ns() {
|
||||
echo "Bringing up namespace '$NS_NAME'..."
|
||||
|
||||
# 1. Create the network namespace
|
||||
if ip netns list | grep -q "$NS_NAME"; then
|
||||
echo "Namespace $NS_NAME already exists. Run 'down' first."
|
||||
exit 1
|
||||
fi
|
||||
ip netns add "$NS_NAME"
|
||||
|
||||
# 2. Create veth pair
|
||||
ip link add "$VETH_HOST" type veth peer name "$VETH_NS"
|
||||
|
||||
# 3. Move peer interface to namespace
|
||||
ip link set "$VETH_NS" netns "$NS_NAME"
|
||||
|
||||
# 4. Configure Host Side Interface
|
||||
ip addr add "${HOST_IP}/${SUBNET_CIDR}" dev "$VETH_HOST"
|
||||
ip link set "$VETH_HOST" up
|
||||
|
||||
# 5. Configure Namespace Side Interface
|
||||
ip netns exec "$NS_NAME" ip addr add "${NS_IP}/${SUBNET_CIDR}" dev "$VETH_NS"
|
||||
ip netns exec "$NS_NAME" ip link set "$VETH_NS" up
|
||||
|
||||
# 6. Bring up loopback inside namespace (crucial for many apps)
|
||||
ip netns exec "$NS_NAME" ip link set lo up
|
||||
|
||||
# 7. Routing: Add default gateway inside namespace pointing to host
|
||||
ip netns exec "$NS_NAME" ip route add default via "$HOST_IP"
|
||||
|
||||
# 8. Enable IP forwarding on host
|
||||
echo 1 > /proc/sys/net/ipv4/ip_forward
|
||||
|
||||
# 9. NAT/Masquerade: Allow traffic from namespace to go out physical interface
|
||||
# We verify rule doesn't exist first to avoid duplicates
|
||||
iptables -t nat -C POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null || \
|
||||
iptables -t nat -A POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE
|
||||
|
||||
# Allow forwarding from host veth to WAN and back
|
||||
iptables -C FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null || \
|
||||
iptables -A FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT
|
||||
|
||||
iptables -C FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null || \
|
||||
iptables -A FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT
|
||||
|
||||
# 10. DNS Setup
|
||||
# Netns uses /etc/netns/<name>/resolv.conf if it exists
|
||||
mkdir -p "/etc/netns/$NS_NAME"
|
||||
echo "nameserver $DNS_SERVER" > "/etc/netns/$NS_NAME/resolv.conf"
|
||||
|
||||
echo "Namespace $NS_NAME is UP."
|
||||
echo "To enter shell: sudo ip netns exec $NS_NAME bash"
|
||||
}
|
||||
|
||||
teardown_ns() {
|
||||
echo "Tearing down namespace '$NS_NAME'..."
|
||||
|
||||
# 1. Remove Namespace (this automatically deletes the veth pair inside it)
|
||||
# The host side veth usually disappears when the peer is destroyed.
|
||||
if ip netns list | grep -q "$NS_NAME"; then
|
||||
ip netns del "$NS_NAME"
|
||||
else
|
||||
echo "Namespace $NS_NAME does not exist."
|
||||
fi
|
||||
|
||||
# 2. Clean up veth host side if it still lingers
|
||||
if ip link show "$VETH_HOST" > /dev/null 2>&1; then
|
||||
ip link delete "$VETH_HOST"
|
||||
fi
|
||||
|
||||
# 3. Remove iptables rules
|
||||
# We use -D to delete the specific rules we added
|
||||
iptables -t nat -D POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null
|
||||
iptables -D FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null
|
||||
iptables -D FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null
|
||||
|
||||
# 4. Remove DNS config
|
||||
rm -rf "/etc/netns/$NS_NAME"
|
||||
|
||||
echo "Namespace $NS_NAME is DOWN."
|
||||
}
|
||||
|
||||
test_connectivity() {
|
||||
echo "Testing connectivity inside $NS_NAME..."
|
||||
ip netns exec "$NS_NAME" ping -c 3 8.8.8.8
|
||||
}
|
||||
|
||||
# Main execution logic
|
||||
check_root
|
||||
|
||||
case "$1" in
|
||||
up)
|
||||
setup_ns
|
||||
;;
|
||||
down)
|
||||
teardown_ns
|
||||
;;
|
||||
test)
|
||||
test_connectivity
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {up|down|test}"
|
||||
exit 1
|
||||
esac
|
||||
70
olm.iss
70
olm.iss
@@ -57,13 +57,13 @@ Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"
|
||||
; The 'Path' variable is located under 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'.
|
||||
; ValueType: expandsz allows for environment variables (like %ProgramFiles%) in the path.
|
||||
; ValueData: "{olddata};{app}" appends the current application directory to the existing PATH.
|
||||
; Flags: uninsdeletevalue ensures the entry is removed upon uninstallation.
|
||||
; Check: IsWin64 ensures this is applied on 64-bit systems, which matches ArchitecturesAllowed.
|
||||
; Note: Removal during uninstallation is handled by CurUninstallStepChanged procedure in [Code] section.
|
||||
; Check: NeedsAddPath ensures this is applied only if the path is not already present.
|
||||
[Registry]
|
||||
; Add the application's installation directory to the system PATH.
|
||||
Root: HKLM; Subkey: "SYSTEM\CurrentControlSet\Control\Session Manager\Environment"; \
|
||||
ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \
|
||||
Flags: uninsdeletevalue; Check: NeedsAddPath(ExpandConstant('{app}'))
|
||||
Check: NeedsAddPath(ExpandConstant('{app}'))
|
||||
|
||||
[Code]
|
||||
function NeedsAddPath(Path: string): boolean;
|
||||
@@ -85,4 +85,68 @@ begin
|
||||
Result := False
|
||||
else
|
||||
Result := True;
|
||||
end;
|
||||
|
||||
procedure RemovePathEntry(PathToRemove: string);
|
||||
var
|
||||
OrigPath: string;
|
||||
NewPath: string;
|
||||
PathList: TStringList;
|
||||
I: Integer;
|
||||
begin
|
||||
if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
|
||||
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||
'Path', OrigPath)
|
||||
then begin
|
||||
// Path variable doesn't exist, nothing to remove
|
||||
exit;
|
||||
end;
|
||||
|
||||
// Create a string list to parse the PATH entries
|
||||
PathList := TStringList.Create;
|
||||
try
|
||||
// Split the PATH by semicolons
|
||||
PathList.Delimiter := ';';
|
||||
PathList.StrictDelimiter := True;
|
||||
PathList.DelimitedText := OrigPath;
|
||||
|
||||
// Find and remove the matching entry (case-insensitive)
|
||||
for I := PathList.Count - 1 downto 0 do
|
||||
begin
|
||||
if CompareText(Trim(PathList[I]), Trim(PathToRemove)) = 0 then
|
||||
begin
|
||||
Log('Found and removing PATH entry: ' + PathList[I]);
|
||||
PathList.Delete(I);
|
||||
end;
|
||||
end;
|
||||
|
||||
// Reconstruct the PATH
|
||||
NewPath := PathList.DelimitedText;
|
||||
|
||||
// Write the new PATH back to the registry
|
||||
if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE,
|
||||
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||
'Path', NewPath)
|
||||
then
|
||||
Log('Successfully removed path entry: ' + PathToRemove)
|
||||
else
|
||||
Log('Failed to write modified PATH to registry');
|
||||
finally
|
||||
PathList.Free;
|
||||
end;
|
||||
end;
|
||||
|
||||
procedure CurUninstallStepChanged(CurUninstallStep: TUninstallStep);
|
||||
var
|
||||
AppPath: string;
|
||||
begin
|
||||
if CurUninstallStep = usUninstall then
|
||||
begin
|
||||
// Get the application installation path
|
||||
AppPath := ExpandConstant('{app}');
|
||||
Log('Removing PATH entry for: ' + AppPath);
|
||||
|
||||
// Remove only our path entry from the system PATH
|
||||
RemovePathEntry(AppPath);
|
||||
end;
|
||||
end;
|
||||
885
olm/common.go
885
olm/common.go
@@ -1,885 +0,0 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type WgData struct {
|
||||
Sites []SiteConfig `json:"sites"`
|
||||
TunnelIP string `json:"tunnelIP"`
|
||||
}
|
||||
|
||||
type SiteConfig struct {
|
||||
SiteId int `json:"siteId"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
ServerIP string `json:"serverIP"`
|
||||
ServerPort uint16 `json:"serverPort"`
|
||||
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||
}
|
||||
|
||||
type TargetsByType struct {
|
||||
UDP []string `json:"udp"`
|
||||
TCP []string `json:"tcp"`
|
||||
}
|
||||
|
||||
type TargetData struct {
|
||||
Targets []string `json:"targets"`
|
||||
}
|
||||
|
||||
type HolePunchMessage struct {
|
||||
NewtID string `json:"newtId"`
|
||||
}
|
||||
|
||||
type ExitNode struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
type HolePunchData struct {
|
||||
ExitNodes []ExitNode `json:"exitNodes"`
|
||||
}
|
||||
|
||||
type EncryptedHolePunchMessage struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}
|
||||
|
||||
var (
|
||||
peerMonitor *peermonitor.PeerMonitor
|
||||
stopHolepunch chan struct{}
|
||||
stopRegister func()
|
||||
stopPing chan struct{}
|
||||
olmToken string
|
||||
holePunchRunning bool
|
||||
)
|
||||
|
||||
const (
|
||||
ENV_WG_TUN_FD = "WG_TUN_FD"
|
||||
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
||||
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
||||
)
|
||||
|
||||
// PeerAction represents a request to add, update, or remove a peer
|
||||
type PeerAction struct {
|
||||
Action string `json:"action"` // "add", "update", or "remove"
|
||||
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
|
||||
}
|
||||
|
||||
// UpdatePeerData represents the data needed to update a peer
|
||||
type UpdatePeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
ServerIP string `json:"serverIP"`
|
||||
ServerPort uint16 `json:"serverPort"`
|
||||
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||
}
|
||||
|
||||
// AddPeerData represents the data needed to add a peer
|
||||
type AddPeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
ServerIP string `json:"serverIP"`
|
||||
ServerPort uint16 `json:"serverPort"`
|
||||
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||
}
|
||||
|
||||
// RemovePeerData represents the data needed to remove a peer
|
||||
type RemovePeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
}
|
||||
|
||||
type RelayPeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
// Helper function to format endpoints correctly
|
||||
func formatEndpoint(endpoint string) string {
|
||||
if endpoint == "" {
|
||||
return ""
|
||||
}
|
||||
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
|
||||
_, _, err := net.SplitHostPort(endpoint)
|
||||
if err == nil {
|
||||
return endpoint // Already valid, no change needed
|
||||
}
|
||||
|
||||
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
|
||||
lastColon := strings.LastIndex(endpoint, ":")
|
||||
if lastColon > 0 { // Ensure there is a colon and it's not the first character
|
||||
hostPart := endpoint[:lastColon]
|
||||
// Check if the host part is a literal IPv6 address
|
||||
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
|
||||
// It is! Reformat it with brackets.
|
||||
portPart := endpoint[lastColon+1:]
|
||||
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
|
||||
}
|
||||
}
|
||||
|
||||
// If it's not the specific malformed case, return it as is.
|
||||
return endpoint
|
||||
}
|
||||
|
||||
func fixKey(key string) string {
|
||||
// Remove any whitespace
|
||||
key = strings.TrimSpace(key)
|
||||
|
||||
// Decode from base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
logger.Fatal("Error decoding base64")
|
||||
}
|
||||
|
||||
// Convert to hex
|
||||
return hex.EncodeToString(decoded)
|
||||
}
|
||||
|
||||
func parseLogLevel(level string) logger.LogLevel {
|
||||
switch strings.ToUpper(level) {
|
||||
case "DEBUG":
|
||||
return logger.DEBUG
|
||||
case "INFO":
|
||||
return logger.INFO
|
||||
case "WARN":
|
||||
return logger.WARN
|
||||
case "ERROR":
|
||||
return logger.ERROR
|
||||
case "FATAL":
|
||||
return logger.FATAL
|
||||
default:
|
||||
return logger.INFO // default to INFO if invalid level provided
|
||||
}
|
||||
}
|
||||
|
||||
func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
||||
switch level {
|
||||
case logger.DEBUG:
|
||||
return device.LogLevelVerbose
|
||||
// case logger.INFO:
|
||||
// return device.LogLevel
|
||||
case logger.WARN:
|
||||
return device.LogLevelError
|
||||
case logger.ERROR, logger.FATAL:
|
||||
return device.LogLevelSilent
|
||||
default:
|
||||
return device.LogLevelSilent
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveDomain(domain string) (string, error) {
|
||||
// First handle any protocol prefix
|
||||
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
|
||||
|
||||
// if there are any trailing slashes, remove them
|
||||
domain = strings.TrimSuffix(domain, "/")
|
||||
|
||||
// Now split host and port
|
||||
host, port, err := net.SplitHostPort(domain)
|
||||
if err != nil {
|
||||
// No port found, use the domain as is
|
||||
host = domain
|
||||
port = ""
|
||||
}
|
||||
|
||||
// Lookup IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||
}
|
||||
|
||||
// Get the first IPv4 address if available
|
||||
var ipAddr string
|
||||
for _, ip := range ips {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ipAddr = ipv4.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no IPv4 found, use the first IP (might be IPv6)
|
||||
if ipAddr == "" {
|
||||
ipAddr = ips[0].String()
|
||||
}
|
||||
|
||||
// Add port back if it existed
|
||||
if port != "" {
|
||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||
}
|
||||
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
// Create a slice of all ports in the range
|
||||
portRange := make([]uint16, maxPort-minPort+1)
|
||||
for i := range portRange {
|
||||
portRange[i] = minPort + uint16(i)
|
||||
}
|
||||
|
||||
// Fisher-Yates shuffle to randomize the port order
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
for i := len(portRange) - 1; i > 0; i-- {
|
||||
j := rand.Intn(i + 1)
|
||||
portRange[i], portRange[j] = portRange[j], portRange[i]
|
||||
}
|
||||
|
||||
// Try each port in the randomized order
|
||||
for _, port := range portRange {
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: int(port),
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
continue // Port is in use or there was an error, try next port
|
||||
}
|
||||
_ = conn.SetDeadline(time.Now())
|
||||
conn.Close()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
func sendPing(olm *websocket.Client) error {
|
||||
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
||||
"timestamp": time.Now().Unix(),
|
||||
"userToken": olm.GetConfig().UserToken,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send ping message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Debug("Sent ping message")
|
||||
return nil
|
||||
}
|
||||
|
||||
func keepSendingPing(olm *websocket.Client) {
|
||||
// Send ping immediately on startup
|
||||
if err := sendPing(olm); err != nil {
|
||||
logger.Error("Failed to send initial ping: %v", err)
|
||||
} else {
|
||||
logger.Info("Sent initial ping message")
|
||||
}
|
||||
|
||||
// Set up ticker for one minute intervals
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopPing:
|
||||
logger.Info("Stopping ping messages")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := sendPing(olm); err != nil {
|
||||
logger.Error("Failed to send periodic ping: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
||||
siteHost, err := ResolveDomain(siteConfig.Endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||
}
|
||||
|
||||
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
|
||||
allowedIp := strings.Split(siteConfig.ServerIP, "/")
|
||||
if len(allowedIp) > 1 {
|
||||
allowedIp[1] = "32"
|
||||
} else {
|
||||
allowedIp = append(allowedIp, "32")
|
||||
}
|
||||
allowedIpStr := strings.Join(allowedIp, "/")
|
||||
|
||||
// Collect all allowed IPs in a slice
|
||||
var allowedIPs []string
|
||||
allowedIPs = append(allowedIPs, allowedIpStr)
|
||||
|
||||
// If we have anything in remoteSubnets, add those as well
|
||||
if siteConfig.RemoteSubnets != "" {
|
||||
// Split remote subnets by comma and add each one
|
||||
remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",")
|
||||
for _, subnet := range remoteSubnets {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet != "" {
|
||||
allowedIPs = append(allowedIPs, subnet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct WireGuard config for this peer
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String())))
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey)))
|
||||
|
||||
// Add each allowed IP separately
|
||||
for _, allowedIP := range allowedIPs {
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
}
|
||||
|
||||
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||
configBuilder.WriteString("persistent_keepalive_interval=1\n")
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Configuring peer with config: %s", config)
|
||||
|
||||
err = dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
// Set up peer monitoring
|
||||
if peerMonitor != nil {
|
||||
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||
|
||||
primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||
}
|
||||
|
||||
wgConfig := &peermonitor.WireGuardConfig{
|
||||
SiteID: siteConfig.SiteId,
|
||||
PublicKey: fixKey(siteConfig.PublicKey),
|
||||
ServerIP: strings.Split(siteConfig.ServerIP, "/")[0],
|
||||
Endpoint: siteConfig.Endpoint,
|
||||
PrimaryRelay: primaryRelay,
|
||||
}
|
||||
|
||||
err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
|
||||
} else {
|
||||
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeer removes a peer from the WireGuard device
|
||||
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
|
||||
// Construct WireGuard config to remove the peer
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(publicKey)))
|
||||
configBuilder.WriteString("remove=true\n")
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Removing peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
// Stop monitoring this peer
|
||||
if peerMonitor != nil {
|
||||
peerMonitor.RemovePeer(siteId)
|
||||
logger.Info("Stopped monitoring for site %d", siteId)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||
func ConfigureInterface(interfaceName string, wgData WgData) error {
|
||||
var ipAddr string = wgData.TunnelIP
|
||||
|
||||
// Parse the IP address and network
|
||||
ip, ipNet, err := net.ParseCIDR(ipAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IP address: %v", err)
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return configureLinux(interfaceName, ip, ipNet)
|
||||
case "darwin":
|
||||
return configureDarwin(interfaceName, ip, ipNet)
|
||||
case "windows":
|
||||
return configureWindows(interfaceName, ip, ipNet)
|
||||
default:
|
||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
logger.Info("Configuring Windows interface: %s", interfaceName)
|
||||
|
||||
// Calculate mask string (e.g., 255.255.255.0)
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
mask := net.CIDRMask(maskBits, 32)
|
||||
maskIP := net.IP(mask)
|
||||
|
||||
// Set the IP address using netsh
|
||||
cmd := exec.Command("netsh", "interface", "ipv4", "set", "address",
|
||||
fmt.Sprintf("name=%s", interfaceName),
|
||||
"source=static",
|
||||
fmt.Sprintf("addr=%s", ip.String()),
|
||||
fmt.Sprintf("mask=%s", maskIP.String()))
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("netsh command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Bring up the interface if needed (in Windows, setting the IP usually brings it up)
|
||||
// But we'll explicitly enable it to be sure
|
||||
cmd = exec.Command("netsh", "interface", "set", "interface",
|
||||
interfaceName,
|
||||
"admin=enable")
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
out, err = cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
// delay 2 seconds
|
||||
time.Sleep(8 * time.Second)
|
||||
|
||||
// Wait for the interface to be up and have the correct IP
|
||||
err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interface did not come up within timeout: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForInterfaceUp polls the network interface until it's up or times out
|
||||
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
||||
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||
deadline := time.Now().Add(timeout)
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
// Check if interface exists and is up
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err == nil {
|
||||
// Check if interface is up
|
||||
if iface.Flags&net.FlagUp != 0 {
|
||||
// Check if it has the expected IP
|
||||
addrs, err := iface.Addrs()
|
||||
if err == nil {
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if ok && ipNet.IP.Equal(expectedIP) {
|
||||
logger.Info("Interface %s is up with correct IP", interfaceName)
|
||||
return nil // Interface is up with correct IP
|
||||
}
|
||||
}
|
||||
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Interface %s exists but is not up yet", interfaceName)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Interface %s not found yet: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
// Wait before next check
|
||||
time.Sleep(pollInterval)
|
||||
}
|
||||
|
||||
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||
}
|
||||
|
||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
|
||||
// Parse destination to get the IP and subnet
|
||||
ip, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Calculate the subnet mask
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
mask := net.CIDRMask(maskBits, 32)
|
||||
maskIP := net.IP(mask)
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
cmd = exec.Command("route", "add",
|
||||
ip.String(),
|
||||
"mask", maskIP.String(),
|
||||
gateway,
|
||||
"metric", "1")
|
||||
} else if interfaceName != "" {
|
||||
// First, get the interface index
|
||||
indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces")
|
||||
output, err := indexCmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface index: %v, output: %s", err, output)
|
||||
}
|
||||
|
||||
// Parse the output to find the interface index
|
||||
lines := strings.Split(string(output), "\n")
|
||||
var ifIndex string
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, interfaceName) {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) > 0 {
|
||||
ifIndex = fields[0]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ifIndex == "" {
|
||||
return fmt.Errorf("could not find index for interface %s", interfaceName)
|
||||
}
|
||||
|
||||
// Convert to integer to validate
|
||||
idx, err := strconv.Atoi(ifIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid interface index: %v", err)
|
||||
}
|
||||
|
||||
// Route via interface using the index
|
||||
cmd = exec.Command("route", "add",
|
||||
ip.String(),
|
||||
"mask", maskIP.String(),
|
||||
"0.0.0.0",
|
||||
"if", strconv.Itoa(idx))
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func WindowsRemoveRoute(destination string) error {
|
||||
// Parse destination to get the IP
|
||||
ip, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Calculate the subnet mask
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
mask := net.CIDRMask(maskBits, 32)
|
||||
maskIP := net.IP(mask)
|
||||
|
||||
cmd := exec.Command("route", "delete",
|
||||
ip.String(),
|
||||
"mask", maskIP.String())
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func findUnusedUTUN() (string, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
||||
}
|
||||
used := make(map[int]bool)
|
||||
re := regexp.MustCompile(`^utun(\d+)$`)
|
||||
for _, iface := range ifaces {
|
||||
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
|
||||
if num, err := strconv.Atoi(matches[1]); err == nil {
|
||||
used[num] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try utun0 up to utun255.
|
||||
for i := 0; i < 256; i++ {
|
||||
if !used[i] {
|
||||
return fmt.Sprintf("utun%d", i), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no unused utun interface found")
|
||||
}
|
||||
|
||||
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
logger.Info("Configuring darwin interface: %s", interfaceName)
|
||||
|
||||
prefix, _ := ipNet.Mask.Size()
|
||||
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
||||
|
||||
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
cmd = exec.Command("ifconfig", interfaceName, "up")
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err = cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
// Get the interface
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
// Create the IP address attributes
|
||||
addr := &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: ipNet.Mask,
|
||||
},
|
||||
}
|
||||
|
||||
// Add the IP address to the interface
|
||||
if err := netlink.AddrAdd(link, addr); err != nil {
|
||||
return fmt.Errorf("failed to add IP address: %v", err)
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
if err := netlink.LinkSetUp(link); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
|
||||
} else if interfaceName != "" {
|
||||
// Route via interface
|
||||
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DarwinRemoveRoute(destination string) error {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
cmd = exec.Command("ip", "route", "add", destination, "via", gateway)
|
||||
} else if interfaceName != "" {
|
||||
// Route via interface
|
||||
cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName)
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ip route command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func LinuxRemoveRoute(destination string) error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.Command("ip", "route", "del", destination)
|
||||
logger.Info("Running command: %v", cmd)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||
func addRouteForServerIP(serverIP, interfaceName string) error {
|
||||
if runtime.GOOS == "darwin" {
|
||||
return DarwinAddRoute(serverIP, "", interfaceName)
|
||||
}
|
||||
// else if runtime.GOOS == "windows" {
|
||||
// return WindowsAddRoute(serverIP, "", interfaceName)
|
||||
// } else if runtime.GOOS == "linux" {
|
||||
// return LinuxAddRoute(serverIP, "", interfaceName)
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRouteForServerIP removes an OS-specific route for the server IP
|
||||
func removeRouteForServerIP(serverIP string) error {
|
||||
if runtime.GOOS == "darwin" {
|
||||
return DarwinRemoveRoute(serverIP)
|
||||
}
|
||||
// else if runtime.GOOS == "windows" {
|
||||
// return WindowsRemoveRoute(serverIP)
|
||||
// } else if runtime.GOOS == "linux" {
|
||||
// return LinuxRemoveRoute(serverIP)
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets
|
||||
func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error {
|
||||
if remoteSubnets == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Split remote subnets by comma and add routes for each one
|
||||
subnets := strings.Split(remoteSubnets, ",")
|
||||
for _, subnet := range subnets {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add route based on operating system
|
||||
if runtime.GOOS == "darwin" {
|
||||
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
||||
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
} else if runtime.GOOS == "windows" {
|
||||
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
||||
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
} else if runtime.GOOS == "linux" {
|
||||
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
||||
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Added route for remote subnet: %s", subnet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets
|
||||
func removeRoutesForRemoteSubnets(remoteSubnets string) error {
|
||||
if remoteSubnets == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Split remote subnets by comma and remove routes for each one
|
||||
subnets := strings.Split(remoteSubnets, ",")
|
||||
for _, subnet := range subnets {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove route based on operating system
|
||||
if runtime.GOOS == "darwin" {
|
||||
if err := DarwinRemoveRoute(subnet); err != nil {
|
||||
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
} else if runtime.GOOS == "windows" {
|
||||
if err := WindowsRemoveRoute(subnet); err != nil {
|
||||
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
} else if runtime.GOOS == "linux" {
|
||||
if err := LinuxRemoveRoute(subnet); err != nil {
|
||||
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Removed route for remote subnet: %s", subnet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
1230
olm/olm.go
1230
olm/olm.go
File diff suppressed because it is too large
Load Diff
66
olm/types.go
Normal file
66
olm/types.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/olm/peers"
|
||||
)
|
||||
|
||||
type WgData struct {
|
||||
Sites []peers.SiteConfig `json:"sites"`
|
||||
TunnelIP string `json:"tunnelIP"`
|
||||
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
|
||||
}
|
||||
|
||||
type GlobalConfig struct {
|
||||
// Logging
|
||||
LogLevel string
|
||||
|
||||
// HTTP server
|
||||
EnableAPI bool
|
||||
HTTPAddr string
|
||||
SocketPath string
|
||||
Version string
|
||||
Agent string
|
||||
|
||||
// Callbacks
|
||||
OnRegistered func()
|
||||
OnConnected func()
|
||||
OnTerminated func()
|
||||
OnAuthError func(statusCode int, message string) // Called when auth fails (401/403)
|
||||
OnExit func() // Called when exit is requested via API
|
||||
}
|
||||
|
||||
type TunnelConfig struct {
|
||||
// Connection settings
|
||||
Endpoint string
|
||||
ID string
|
||||
Secret string
|
||||
UserToken string
|
||||
|
||||
// Network settings
|
||||
MTU int
|
||||
DNS string
|
||||
UpstreamDNS []string
|
||||
InterfaceName string
|
||||
|
||||
// Advanced
|
||||
Holepunch bool
|
||||
TlsClientCert string
|
||||
|
||||
// Parsed values (not in JSON)
|
||||
PingIntervalDuration time.Duration
|
||||
PingTimeoutDuration time.Duration
|
||||
|
||||
OrgID string
|
||||
// DoNotCreateNewClient bool
|
||||
|
||||
FileDescriptorTun uint32
|
||||
FileDescriptorUAPI uint32
|
||||
|
||||
EnableUAPI bool
|
||||
|
||||
OverrideDNS bool
|
||||
|
||||
DisableRelay bool
|
||||
}
|
||||
35
olm/unix.go
35
olm/unix.go
@@ -1,35 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(int(fd), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
return tun.CreateTUNFromFile(file, mtuInt)
|
||||
}
|
||||
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||
return ipc.UAPIOpen(interfaceName)
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||
}
|
||||
98
olm/util.go
Normal file
98
olm/util.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
)
|
||||
|
||||
// Helper function to format endpoints correctly
|
||||
func formatEndpoint(endpoint string) string {
|
||||
if endpoint == "" {
|
||||
return ""
|
||||
}
|
||||
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
|
||||
_, _, err := net.SplitHostPort(endpoint)
|
||||
if err == nil {
|
||||
return endpoint // Already valid, no change needed
|
||||
}
|
||||
|
||||
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
|
||||
lastColon := strings.LastIndex(endpoint, ":")
|
||||
if lastColon > 0 { // Ensure there is a colon and it's not the first character
|
||||
hostPart := endpoint[:lastColon]
|
||||
// Check if the host part is a literal IPv6 address
|
||||
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
|
||||
// It is! Reformat it with brackets.
|
||||
portPart := endpoint[lastColon+1:]
|
||||
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
|
||||
}
|
||||
}
|
||||
|
||||
// If it's not the specific malformed case, return it as is.
|
||||
return endpoint
|
||||
}
|
||||
|
||||
func sendPing(olm *websocket.Client) error {
|
||||
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
||||
"timestamp": time.Now().Unix(),
|
||||
"userToken": olm.GetConfig().UserToken,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send ping message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Debug("Sent ping message")
|
||||
return nil
|
||||
}
|
||||
|
||||
func keepSendingPing(olm *websocket.Client) {
|
||||
// Send ping immediately on startup
|
||||
if err := sendPing(olm); err != nil {
|
||||
logger.Error("Failed to send initial ping: %v", err)
|
||||
} else {
|
||||
logger.Info("Sent initial ping message")
|
||||
}
|
||||
|
||||
// Set up ticker for one minute intervals
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopPing:
|
||||
logger.Info("Stopping ping messages")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := sendPing(olm); err != nil {
|
||||
logger.Error("Failed to send periodic ping: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetNetworkSettingsJSON() (string, error) {
|
||||
return network.GetJSON()
|
||||
}
|
||||
|
||||
func GetNetworkSettingsIncrementor() int {
|
||||
return network.GetIncrementor()
|
||||
}
|
||||
|
||||
// stringSlicesEqual compares two string slices for equality
|
||||
func stringSlicesEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,331 +0,0 @@
|
||||
package peermonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"github.com/fosrl/olm/wgtester"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
// PeerMonitorCallback is the function type for connection status change callbacks
|
||||
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
|
||||
|
||||
// WireGuardConfig holds the WireGuard configuration for a peer
|
||||
type WireGuardConfig struct {
|
||||
SiteID int
|
||||
PublicKey string
|
||||
ServerIP string
|
||||
Endpoint string
|
||||
PrimaryRelay string // The primary relay endpoint
|
||||
}
|
||||
|
||||
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
|
||||
type PeerMonitor struct {
|
||||
monitors map[int]*wgtester.Client
|
||||
configs map[int]*WireGuardConfig
|
||||
callback PeerMonitorCallback
|
||||
mutex sync.Mutex
|
||||
running bool
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
privateKey string
|
||||
wsClient *websocket.Client
|
||||
device *device.Device
|
||||
handleRelaySwitch bool // Whether to handle relay switching
|
||||
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
|
||||
return &PeerMonitor{
|
||||
monitors: make(map[int]*wgtester.Client),
|
||||
configs: make(map[int]*WireGuardConfig),
|
||||
callback: callback,
|
||||
interval: 1 * time.Second, // Default check interval
|
||||
timeout: 2500 * time.Millisecond,
|
||||
maxAttempts: 8,
|
||||
privateKey: privateKey,
|
||||
wsClient: wsClient,
|
||||
device: device,
|
||||
handleRelaySwitch: handleRelaySwitch,
|
||||
}
|
||||
}
|
||||
|
||||
// SetInterval changes how frequently peers are checked
|
||||
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.interval = interval
|
||||
|
||||
// Update interval for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetPacketInterval(interval)
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeout changes the timeout for waiting for responses
|
||||
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.timeout = timeout
|
||||
|
||||
// Update timeout for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.maxAttempts = attempts
|
||||
|
||||
// Update max attempts for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetMaxAttempts(attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a new peer to monitor
|
||||
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// Check if we're already monitoring this peer
|
||||
if _, exists := pm.monitors[siteID]; exists {
|
||||
// Update the endpoint instead of creating a new monitor
|
||||
pm.removePeerUnlocked(siteID)
|
||||
}
|
||||
|
||||
client, err := wgtester.NewClient(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the client with our settings
|
||||
client.SetPacketInterval(pm.interval)
|
||||
client.SetTimeout(pm.timeout)
|
||||
client.SetMaxAttempts(pm.maxAttempts)
|
||||
|
||||
// Store the client and config
|
||||
pm.monitors[siteID] = client
|
||||
pm.configs[siteID] = wgConfig
|
||||
|
||||
// If monitor is already running, start monitoring this peer
|
||||
if pm.running {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||
})
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
||||
// This function assumes the mutex is already held by the caller
|
||||
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
|
||||
client, exists := pm.monitors[siteID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
client.StopMonitor()
|
||||
client.Close()
|
||||
delete(pm.monitors, siteID)
|
||||
delete(pm.configs, siteID)
|
||||
}
|
||||
|
||||
// RemovePeer stops monitoring a peer and removes it from the monitor
|
||||
func (pm *PeerMonitor) RemovePeer(siteID int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.removePeerUnlocked(siteID)
|
||||
}
|
||||
|
||||
// Start begins monitoring all peers
|
||||
func (pm *PeerMonitor) Start() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if pm.running {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
pm.running = true
|
||||
|
||||
// Start monitoring all peers
|
||||
for siteID, client := range pm.monitors {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
|
||||
continue
|
||||
}
|
||||
logger.Info("Started monitoring peer %d\n", siteID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnectionStatusChange is called when a peer's connection status changes
|
||||
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) {
|
||||
// Call the user-provided callback first
|
||||
if pm.callback != nil {
|
||||
pm.callback(siteID, status.Connected, status.RTT)
|
||||
}
|
||||
|
||||
// If disconnected, handle failover
|
||||
if !status.Connected {
|
||||
// Send relay message to the server
|
||||
if pm.wsClient != nil {
|
||||
pm.sendRelay(siteID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleFailover handles failover to the relay server when a peer is disconnected
|
||||
func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
|
||||
pm.mutex.Lock()
|
||||
config, exists := pm.configs[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for IPv6 and format the endpoint correctly
|
||||
formattedEndpoint := relayEndpoint
|
||||
if strings.Contains(relayEndpoint, ":") {
|
||||
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
|
||||
}
|
||||
|
||||
// Configure WireGuard to use the relay
|
||||
wgConfig := fmt.Sprintf(`private_key=%s
|
||||
public_key=%s
|
||||
allowed_ip=%s/32
|
||||
endpoint=%s:21820
|
||||
persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, formattedEndpoint)
|
||||
|
||||
err := pm.device.IpcSet(wgConfig)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure WireGuard device: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Adjusted peer %d to point to relay!\n", siteID)
|
||||
}
|
||||
|
||||
// sendRelay sends a relay message to the server
|
||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||
if !pm.handleRelaySwitch {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent relay message")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops monitoring all peers
|
||||
func (pm *PeerMonitor) Stop() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
pm.running = false
|
||||
|
||||
// Stop all monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.StopMonitor()
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops monitoring and cleans up resources
|
||||
func (pm *PeerMonitor) Close() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// Stop and close all clients
|
||||
for siteID, client := range pm.monitors {
|
||||
client.StopMonitor()
|
||||
client.Close()
|
||||
delete(pm.monitors, siteID)
|
||||
}
|
||||
|
||||
pm.running = false
|
||||
}
|
||||
|
||||
// TestPeer tests connectivity to a specific peer
|
||||
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||
pm.mutex.Lock()
|
||||
client, exists := pm.monitors[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
defer cancel()
|
||||
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
return connected, rtt, nil
|
||||
}
|
||||
|
||||
// TestAllPeers tests connectivity to all peers
|
||||
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
} {
|
||||
pm.mutex.Lock()
|
||||
peers := make(map[int]*wgtester.Client, len(pm.monitors))
|
||||
for siteID, client := range pm.monitors {
|
||||
peers[siteID] = client
|
||||
}
|
||||
pm.mutex.Unlock()
|
||||
|
||||
results := make(map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
})
|
||||
for siteID, client := range peers {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
cancel()
|
||||
|
||||
results[siteID] = struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
}{
|
||||
Connected: connected,
|
||||
RTT: rtt,
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
884
peers/manager.go
Normal file
884
peers/manager.go
Normal file
@@ -0,0 +1,884 @@
|
||||
package peers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/api"
|
||||
olmDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/dns"
|
||||
"github.com/fosrl/olm/peers/monitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// PeerManagerConfig contains the configuration for creating a PeerManager
|
||||
type PeerManagerConfig struct {
|
||||
Device *device.Device
|
||||
DNSProxy *dns.DNSProxy
|
||||
InterfaceName string
|
||||
PrivateKey wgtypes.Key
|
||||
// For peer monitoring
|
||||
MiddleDev *olmDevice.MiddleDevice
|
||||
LocalIP string
|
||||
SharedBind *bind.SharedBind
|
||||
// WSClient is optional - if nil, relay messages won't be sent
|
||||
WSClient *websocket.Client
|
||||
APIServer *api.API
|
||||
}
|
||||
|
||||
type PeerManager struct {
|
||||
mu sync.RWMutex
|
||||
device *device.Device
|
||||
peers map[int]SiteConfig
|
||||
peerMonitor *monitor.PeerMonitor
|
||||
dnsProxy *dns.DNSProxy
|
||||
interfaceName string
|
||||
privateKey wgtypes.Key
|
||||
// allowedIPOwners tracks which peer currently "owns" each allowed IP in WireGuard
|
||||
// key is the CIDR string, value is the siteId that has it configured in WG
|
||||
allowedIPOwners map[string]int
|
||||
// allowedIPClaims tracks all peers that claim each allowed IP
|
||||
// key is the CIDR string, value is a set of siteIds that want this IP
|
||||
allowedIPClaims map[string]map[int]bool
|
||||
APIServer *api.API
|
||||
}
|
||||
|
||||
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
|
||||
func NewPeerManager(config PeerManagerConfig) *PeerManager {
|
||||
pm := &PeerManager{
|
||||
device: config.Device,
|
||||
peers: make(map[int]SiteConfig),
|
||||
dnsProxy: config.DNSProxy,
|
||||
interfaceName: config.InterfaceName,
|
||||
privateKey: config.PrivateKey,
|
||||
allowedIPOwners: make(map[string]int),
|
||||
allowedIPClaims: make(map[string]map[int]bool),
|
||||
APIServer: config.APIServer,
|
||||
}
|
||||
|
||||
// Create the peer monitor
|
||||
pm.peerMonitor = monitor.NewPeerMonitor(
|
||||
config.WSClient,
|
||||
config.MiddleDev,
|
||||
config.LocalIP,
|
||||
config.SharedBind,
|
||||
config.APIServer,
|
||||
)
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
peer, ok := pm.peers[siteId]
|
||||
return peer, ok
|
||||
}
|
||||
|
||||
func (pm *PeerManager) GetAllPeers() []SiteConfig {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
peers := make([]SiteConfig, 0, len(pm.peers))
|
||||
for _, peer := range pm.peers {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
// build the allowed IPs list from the remote subnets and aliases and add them to the peer
|
||||
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
|
||||
allowedIPs = append(allowedIPs, siteConfig.RemoteSubnets...)
|
||||
for _, alias := range siteConfig.Aliases {
|
||||
allowedIPs = append(allowedIPs, alias.AliasAddress+"/32")
|
||||
}
|
||||
siteConfig.AllowedIps = allowedIPs
|
||||
|
||||
// Register claims for all allowed IPs and determine which ones this peer will own
|
||||
ownedIPs := make([]string, 0, len(allowedIPs))
|
||||
for _, ip := range allowedIPs {
|
||||
pm.claimAllowedIP(siteConfig.SiteId, ip)
|
||||
// Check if this peer became the owner
|
||||
if pm.allowedIPOwners[ip] == siteConfig.SiteId {
|
||||
ownedIPs = append(ownedIPs, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a config with only the owned IPs for WireGuard
|
||||
wgConfig := siteConfig
|
||||
wgConfig.AllowedIps = ownedIPs
|
||||
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := network.AddRouteForServerIP(siteConfig.ServerIP, pm.interfaceName); err != nil {
|
||||
logger.Error("Failed to add route for server IP: %v", err)
|
||||
}
|
||||
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
|
||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||
}
|
||||
for _, alias := range siteConfig.Aliases {
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address == nil {
|
||||
continue
|
||||
}
|
||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
||||
}
|
||||
|
||||
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||
|
||||
err := pm.peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, siteConfig.Endpoint) // always use the real site endpoint for hole punch monitoring
|
||||
if err != nil {
|
||||
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
|
||||
} else {
|
||||
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||
}
|
||||
|
||||
pm.peers[siteConfig.SiteId] = siteConfig
|
||||
|
||||
pm.APIServer.AddPeerStatus(siteConfig.SiteId, siteConfig.Name, false, 0, siteConfig.Endpoint, false)
|
||||
|
||||
// Perform rapid initial holepunch test (outside of lock to avoid blocking)
|
||||
// This quickly determines if holepunch is viable and triggers relay if not
|
||||
go pm.performRapidInitialTest(siteConfig.SiteId, siteConfig.Endpoint)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *PeerManager) RemovePeer(siteId int) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
if err := RemovePeer(pm.device, siteId, peer.PublicKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := network.RemoveRouteForServerIP(peer.ServerIP, pm.interfaceName); err != nil {
|
||||
logger.Error("Failed to remove route for server IP: %v", err)
|
||||
}
|
||||
|
||||
// Only remove routes for subnets that aren't used by other peers
|
||||
for _, subnet := range peer.RemoteSubnets {
|
||||
subnetStillInUse := false
|
||||
for otherSiteId, otherPeer := range pm.peers {
|
||||
if otherSiteId == siteId {
|
||||
continue // Skip the peer being removed
|
||||
}
|
||||
for _, otherSubnet := range otherPeer.RemoteSubnets {
|
||||
if otherSubnet == subnet {
|
||||
subnetStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if subnetStillInUse {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !subnetStillInUse {
|
||||
if err := network.RemoveRoutes([]string{subnet}); err != nil {
|
||||
logger.Error("Failed to remove route for remote subnet %s: %v", subnet, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For aliases
|
||||
for _, alias := range peer.Aliases {
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address == nil {
|
||||
continue
|
||||
}
|
||||
pm.dnsProxy.RemoveDNSRecord(alias.Alias, address)
|
||||
}
|
||||
|
||||
// Release all IP claims and promote other peers as needed
|
||||
// Collect promotions first to avoid modifying while iterating
|
||||
type promotion struct {
|
||||
newOwner int
|
||||
cidr string
|
||||
}
|
||||
var promotions []promotion
|
||||
|
||||
for _, ip := range peer.AllowedIps {
|
||||
newOwner, promoted := pm.releaseAllowedIP(siteId, ip)
|
||||
if promoted && newOwner >= 0 {
|
||||
promotions = append(promotions, promotion{newOwner: newOwner, cidr: ip})
|
||||
}
|
||||
}
|
||||
|
||||
// Apply promotions - update WireGuard config for newly promoted peers
|
||||
// Group by peer to avoid multiple config updates
|
||||
promotedPeers := make(map[int]bool)
|
||||
for _, p := range promotions {
|
||||
promotedPeers[p.newOwner] = true
|
||||
logger.Info("Promoted peer %d to owner of IP %s", p.newOwner, p.cidr)
|
||||
}
|
||||
|
||||
for promotedPeerId := range promotedPeers {
|
||||
if promotedPeer, exists := pm.peers[promotedPeerId]; exists {
|
||||
// Build the list of IPs this peer now owns
|
||||
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||
wgConfig := promotedPeer
|
||||
wgConfig.AllowedIps = ownedIPs
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
|
||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop monitoring this peer
|
||||
pm.peerMonitor.RemovePeer(siteId)
|
||||
logger.Info("Stopped monitoring for site %d", siteId)
|
||||
|
||||
pm.APIServer.RemovePeerStatus(siteId)
|
||||
|
||||
delete(pm.peers, siteId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
oldPeer, exists := pm.peers[siteConfig.SiteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId)
|
||||
}
|
||||
|
||||
// If public key changed, remove old peer first
|
||||
if siteConfig.PublicKey != oldPeer.PublicKey {
|
||||
if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil {
|
||||
logger.Error("Failed to remove old peer: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build the new allowed IPs list
|
||||
newAllowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
|
||||
newAllowedIPs = append(newAllowedIPs, siteConfig.RemoteSubnets...)
|
||||
for _, alias := range siteConfig.Aliases {
|
||||
newAllowedIPs = append(newAllowedIPs, alias.AliasAddress+"/32")
|
||||
}
|
||||
siteConfig.AllowedIps = newAllowedIPs
|
||||
|
||||
// Handle allowed IP claim changes
|
||||
oldAllowedIPs := make(map[string]bool)
|
||||
for _, ip := range oldPeer.AllowedIps {
|
||||
oldAllowedIPs[ip] = true
|
||||
}
|
||||
newAllowedIPsSet := make(map[string]bool)
|
||||
for _, ip := range newAllowedIPs {
|
||||
newAllowedIPsSet[ip] = true
|
||||
}
|
||||
|
||||
// Track peers that need WireGuard config updates due to promotions
|
||||
peersToUpdate := make(map[int]bool)
|
||||
|
||||
// Release claims for removed IPs and handle promotions
|
||||
for ip := range oldAllowedIPs {
|
||||
if !newAllowedIPsSet[ip] {
|
||||
newOwner, promoted := pm.releaseAllowedIP(siteConfig.SiteId, ip)
|
||||
if promoted && newOwner >= 0 {
|
||||
peersToUpdate[newOwner] = true
|
||||
logger.Info("Promoted peer %d to owner of IP %s", newOwner, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add claims for new IPs
|
||||
for ip := range newAllowedIPsSet {
|
||||
if !oldAllowedIPs[ip] {
|
||||
pm.claimAllowedIP(siteConfig.SiteId, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Build the list of IPs this peer owns for WireGuard config
|
||||
ownedIPs := pm.getOwnedAllowedIPs(siteConfig.SiteId)
|
||||
wgConfig := siteConfig
|
||||
wgConfig.AllowedIps = ownedIPs
|
||||
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update WireGuard config for any promoted peers
|
||||
for promotedPeerId := range peersToUpdate {
|
||||
if promotedPeer, exists := pm.peers[promotedPeerId]; exists {
|
||||
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||
promotedWgConfig := promotedPeer
|
||||
promotedWgConfig.AllowedIps = promotedOwnedIPs
|
||||
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
|
||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle remote subnet route changes
|
||||
// Calculate added and removed subnets
|
||||
oldSubnets := make(map[string]bool)
|
||||
for _, s := range oldPeer.RemoteSubnets {
|
||||
oldSubnets[s] = true
|
||||
}
|
||||
newSubnets := make(map[string]bool)
|
||||
for _, s := range siteConfig.RemoteSubnets {
|
||||
newSubnets[s] = true
|
||||
}
|
||||
|
||||
var addedSubnets []string
|
||||
var removedSubnets []string
|
||||
|
||||
for s := range newSubnets {
|
||||
if !oldSubnets[s] {
|
||||
addedSubnets = append(addedSubnets, s)
|
||||
}
|
||||
}
|
||||
for s := range oldSubnets {
|
||||
if !newSubnets[s] {
|
||||
removedSubnets = append(removedSubnets, s)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove routes for removed subnets (only if no other peer needs them)
|
||||
for _, subnet := range removedSubnets {
|
||||
subnetStillInUse := false
|
||||
for otherSiteId, otherPeer := range pm.peers {
|
||||
if otherSiteId == siteConfig.SiteId {
|
||||
continue // Skip the current peer (already updated)
|
||||
}
|
||||
for _, otherSubnet := range otherPeer.RemoteSubnets {
|
||||
if otherSubnet == subnet {
|
||||
subnetStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if subnetStillInUse {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !subnetStillInUse {
|
||||
if err := network.RemoveRoutes([]string{subnet}); err != nil {
|
||||
logger.Error("Failed to remove route for subnet %s: %v", subnet, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add routes for added subnets
|
||||
if len(addedSubnets) > 0 {
|
||||
if err := network.AddRoutes(addedSubnets, pm.interfaceName); err != nil {
|
||||
logger.Error("Failed to add routes: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update aliases
|
||||
// Remove old aliases
|
||||
for _, alias := range oldPeer.Aliases {
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address == nil {
|
||||
continue
|
||||
}
|
||||
pm.dnsProxy.RemoveDNSRecord(alias.Alias, address)
|
||||
}
|
||||
// Add new aliases
|
||||
for _, alias := range siteConfig.Aliases {
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address == nil {
|
||||
continue
|
||||
}
|
||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
||||
}
|
||||
|
||||
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
|
||||
|
||||
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||
pm.peerMonitor.UpdatePeerEndpoint(siteConfig.SiteId, monitorPeer) // +1 for monitor port
|
||||
|
||||
pm.peers[siteConfig.SiteId] = siteConfig
|
||||
return nil
|
||||
}
|
||||
|
||||
// claimAllowedIP registers a peer's claim to an allowed IP.
|
||||
// If no other peer owns it in WireGuard, this peer becomes the owner.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) claimAllowedIP(siteId int, cidr string) {
|
||||
// Add to claims
|
||||
if pm.allowedIPClaims[cidr] == nil {
|
||||
pm.allowedIPClaims[cidr] = make(map[int]bool)
|
||||
}
|
||||
pm.allowedIPClaims[cidr][siteId] = true
|
||||
|
||||
// If no owner yet, this peer becomes the owner
|
||||
if _, hasOwner := pm.allowedIPOwners[cidr]; !hasOwner {
|
||||
pm.allowedIPOwners[cidr] = siteId
|
||||
}
|
||||
}
|
||||
|
||||
// releaseAllowedIP removes a peer's claim to an allowed IP.
|
||||
// If this peer was the owner, it promotes another claimant to owner.
|
||||
// Returns the new owner's siteId (or -1 if no new owner) and whether promotion occurred.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) releaseAllowedIP(siteId int, cidr string) (newOwner int, promoted bool) {
|
||||
// Remove from claims
|
||||
if claims, exists := pm.allowedIPClaims[cidr]; exists {
|
||||
delete(claims, siteId)
|
||||
if len(claims) == 0 {
|
||||
delete(pm.allowedIPClaims, cidr)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this peer was the owner
|
||||
owner, isOwned := pm.allowedIPOwners[cidr]
|
||||
if !isOwned || owner != siteId {
|
||||
return -1, false // Not the owner, nothing to promote
|
||||
}
|
||||
|
||||
// This peer was the owner, need to find a new owner
|
||||
delete(pm.allowedIPOwners, cidr)
|
||||
|
||||
// Find another claimant to promote
|
||||
if claims, exists := pm.allowedIPClaims[cidr]; exists && len(claims) > 0 {
|
||||
for claimantId := range claims {
|
||||
pm.allowedIPOwners[cidr] = claimantId
|
||||
return claimantId, true
|
||||
}
|
||||
}
|
||||
|
||||
return -1, false
|
||||
}
|
||||
|
||||
// getOwnedAllowedIPs returns the list of allowed IPs that a peer currently owns in WireGuard.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) getOwnedAllowedIPs(siteId int) []string {
|
||||
var owned []string
|
||||
for cidr, owner := range pm.allowedIPOwners {
|
||||
if owner == siteId {
|
||||
owned = append(owned, cidr)
|
||||
}
|
||||
}
|
||||
return owned
|
||||
}
|
||||
|
||||
// addAllowedIp adds an IP (subnet) to the allowed IPs list of a peer
|
||||
// and updates WireGuard configuration if this peer owns the IP.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) addAllowedIp(siteId int, ip string) error {
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
// Check if IP already exists in AllowedIps
|
||||
for _, allowedIp := range peer.AllowedIps {
|
||||
if allowedIp == ip {
|
||||
return nil // Already exists
|
||||
}
|
||||
}
|
||||
|
||||
// Register our claim to this IP
|
||||
pm.claimAllowedIP(siteId, ip)
|
||||
|
||||
peer.AllowedIps = append(peer.AllowedIps, ip)
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
// Only update WireGuard if we own this IP
|
||||
if pm.allowedIPOwners[ip] == siteId {
|
||||
if err := AddAllowedIP(pm.device, peer.PublicKey, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeAllowedIp removes an IP (subnet) from the allowed IPs list of a peer
|
||||
// and updates WireGuard configuration. If this peer owned the IP, it promotes
|
||||
// another peer that also claims this IP. Must be called with lock held.
|
||||
func (pm *PeerManager) removeAllowedIp(siteId int, cidr string) error {
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
found := false
|
||||
|
||||
// Remove from AllowedIps
|
||||
newAllowedIps := make([]string, 0, len(peer.AllowedIps))
|
||||
for _, allowedIp := range peer.AllowedIps {
|
||||
if allowedIp == cidr {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newAllowedIps = append(newAllowedIps, allowedIp)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil // Not found
|
||||
}
|
||||
|
||||
peer.AllowedIps = newAllowedIps
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
// Release our claim and check if we need to promote another peer
|
||||
newOwner, promoted := pm.releaseAllowedIP(siteId, cidr)
|
||||
|
||||
// Build the list of IPs this peer currently owns for the replace operation
|
||||
ownedIPs := pm.getOwnedAllowedIPs(siteId)
|
||||
// Also include the server IP which is always owned
|
||||
serverIP := strings.Split(peer.ServerIP, "/")[0] + "/32"
|
||||
hasServerIP := false
|
||||
for _, ip := range ownedIPs {
|
||||
if ip == serverIP {
|
||||
hasServerIP = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasServerIP {
|
||||
ownedIPs = append([]string{serverIP}, ownedIPs...)
|
||||
}
|
||||
|
||||
// Update WireGuard for this peer using replace_allowed_ips
|
||||
if err := RemoveAllowedIP(pm.device, peer.PublicKey, ownedIPs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If another peer was promoted to owner, add the IP to their WireGuard config
|
||||
if promoted && newOwner >= 0 {
|
||||
if newOwnerPeer, exists := pm.peers[newOwner]; exists {
|
||||
if err := AddAllowedIP(pm.device, newOwnerPeer.PublicKey, cidr); err != nil {
|
||||
logger.Error("Failed to promote peer %d for IP %s: %v", newOwner, cidr, err)
|
||||
} else {
|
||||
logger.Info("Promoted peer %d to owner of IP %s", newOwner, cidr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRemoteSubnet adds an IP (subnet) to the allowed IPs list of a peer
|
||||
func (pm *PeerManager) AddRemoteSubnet(siteId int, cidr string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
// Check if IP already exists in RemoteSubnets
|
||||
for _, subnet := range peer.RemoteSubnets {
|
||||
if subnet == cidr {
|
||||
return nil // Already exists
|
||||
}
|
||||
}
|
||||
|
||||
peer.RemoteSubnets = append(peer.RemoteSubnets, cidr)
|
||||
pm.peers[siteId] = peer // Save before calling addAllowedIp which reads from pm.peers
|
||||
|
||||
// Add to allowed IPs
|
||||
if err := pm.addAllowedIp(siteId, cidr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add route
|
||||
if err := network.AddRoutes([]string{cidr}, pm.interfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRemoteSubnet removes an IP (subnet) from the allowed IPs list of a peer
|
||||
func (pm *PeerManager) RemoveRemoteSubnet(siteId int, ip string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
found := false
|
||||
|
||||
// Remove from RemoteSubnets
|
||||
newSubnets := make([]string, 0, len(peer.RemoteSubnets))
|
||||
for _, subnet := range peer.RemoteSubnets {
|
||||
if subnet == ip {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newSubnets = append(newSubnets, subnet)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil // Not found
|
||||
}
|
||||
|
||||
peer.RemoteSubnets = newSubnets
|
||||
pm.peers[siteId] = peer // Save before calling removeAllowedIp which reads from pm.peers
|
||||
|
||||
// Remove from allowed IPs (this also handles promotion of other peers)
|
||||
if err := pm.removeAllowedIp(siteId, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if any other peer still has this subnet before removing the route
|
||||
subnetStillInUse := false
|
||||
for otherSiteId, otherPeer := range pm.peers {
|
||||
if otherSiteId == siteId {
|
||||
continue // Skip the current peer (already updated above)
|
||||
}
|
||||
for _, subnet := range otherPeer.RemoteSubnets {
|
||||
if subnet == ip {
|
||||
subnetStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if subnetStillInUse {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Only remove route if no other peer needs it
|
||||
if !subnetStillInUse {
|
||||
if err := network.RemoveRoutes([]string{ip}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddAlias adds an alias to a peer
|
||||
func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
peer.Aliases = append(peer.Aliases, alias)
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address != nil {
|
||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
||||
}
|
||||
|
||||
// Add an allowed IP for the alias
|
||||
if err := pm.addAllowedIp(siteId, alias.AliasAddress+"/32"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAlias removes an alias from a peer
|
||||
func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
var aliasToRemove *Alias
|
||||
newAliases := make([]Alias, 0, len(peer.Aliases))
|
||||
for _, a := range peer.Aliases {
|
||||
if a.Alias == aliasName {
|
||||
aliasToRemove = &a
|
||||
continue
|
||||
}
|
||||
newAliases = append(newAliases, a)
|
||||
}
|
||||
|
||||
if aliasToRemove != nil {
|
||||
address := net.ParseIP(aliasToRemove.AliasAddress)
|
||||
if address != nil {
|
||||
pm.dnsProxy.RemoveDNSRecord(aliasName, address)
|
||||
}
|
||||
}
|
||||
|
||||
peer.Aliases = newAliases
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
// Check if any other alias is still using this IP address before removing from allowed IPs
|
||||
ipStillInUse := false
|
||||
aliasIP := aliasToRemove.AliasAddress + "/32"
|
||||
for _, a := range newAliases {
|
||||
if a.AliasAddress+"/32" == aliasIP {
|
||||
ipStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Only remove the allowed IP if no other alias is using it
|
||||
if !ipStillInUse {
|
||||
if err := pm.removeAllowedIp(siteId, aliasIP); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayPeer handles failover to the relay server when a peer is disconnected
|
||||
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
|
||||
pm.mu.Lock()
|
||||
peer, exists := pm.peers[siteId]
|
||||
if exists {
|
||||
// Store the relay endpoint
|
||||
peer.RelayEndpoint = relayEndpoint
|
||||
pm.peers[siteId] = peer
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for IPv6 and format the endpoint correctly
|
||||
formattedEndpoint := relayEndpoint
|
||||
if strings.Contains(relayEndpoint, ":") {
|
||||
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
|
||||
}
|
||||
|
||||
// Update only the endpoint for this peer (update_only preserves other settings)
|
||||
wgConfig := fmt.Sprintf(`public_key=%s
|
||||
update_only=true
|
||||
endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint)
|
||||
|
||||
err := pm.device.IpcSet(wgConfig)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure WireGuard device: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the peer as relayed in the monitor
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.MarkPeerRelayed(siteId, true)
|
||||
}
|
||||
|
||||
logger.Info("Adjusted peer %d to point to relay!\n", siteId)
|
||||
}
|
||||
|
||||
// performRapidInitialTest performs a rapid holepunch test for a newly added peer.
|
||||
// If the test fails, it immediately requests relay to minimize connection delay.
|
||||
// This runs in a goroutine to avoid blocking AddPeer.
|
||||
func (pm *PeerManager) performRapidInitialTest(siteId int, endpoint string) {
|
||||
if pm.peerMonitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Perform rapid test - this takes ~1-2 seconds max
|
||||
holepunchViable := pm.peerMonitor.RapidTestPeer(siteId, endpoint)
|
||||
|
||||
if !holepunchViable {
|
||||
// Holepunch failed rapid test, request relay immediately
|
||||
logger.Info("Rapid test failed for site %d, requesting relay", siteId)
|
||||
if err := pm.peerMonitor.RequestRelay(siteId); err != nil {
|
||||
logger.Error("Failed to request relay for site %d: %v", siteId, err)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Rapid test passed for site %d, using direct connection", siteId)
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the peer monitor
|
||||
func (pm *PeerManager) Start() {
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.Start()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the peer monitor
|
||||
func (pm *PeerManager) Stop() {
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the peer monitor and cleans up resources
|
||||
func (pm *PeerManager) Close() {
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.Close()
|
||||
pm.peerMonitor = nil
|
||||
}
|
||||
}
|
||||
|
||||
// MarkPeerRelayed marks a peer as currently using relay
|
||||
func (pm *PeerManager) MarkPeerRelayed(siteID int, relayed bool) {
|
||||
pm.mu.Lock()
|
||||
if peer, exists := pm.peers[siteID]; exists {
|
||||
if relayed {
|
||||
// We're being relayed, store the current endpoint as the original
|
||||
// (RelayEndpoint is set by HandleFailover)
|
||||
} else {
|
||||
// Clear relay endpoint when switching back to direct
|
||||
peer.RelayEndpoint = ""
|
||||
pm.peers[siteID] = peer
|
||||
}
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.MarkPeerRelayed(siteID, relayed)
|
||||
}
|
||||
}
|
||||
|
||||
// UnRelayPeer switches a peer from relay back to direct connection
|
||||
func (pm *PeerManager) UnRelayPeer(siteId int, endpoint string) error {
|
||||
pm.mu.Lock()
|
||||
peer, exists := pm.peers[siteId]
|
||||
if exists {
|
||||
// Store the relay endpoint
|
||||
peer.Endpoint = endpoint
|
||||
pm.peers[siteId] = peer
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update WireGuard to use the direct endpoint
|
||||
wgConfig := fmt.Sprintf(`public_key=%s
|
||||
update_only=true
|
||||
endpoint=%s`, util.FixKey(peer.PublicKey), endpoint)
|
||||
|
||||
err := pm.device.IpcSet(wgConfig)
|
||||
if err != nil {
|
||||
logger.Error("Failed to switch peer %d to direct connection: %v", siteId, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark as not relayed in monitor
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.MarkPeerRelayed(siteId, false)
|
||||
}
|
||||
|
||||
logger.Info("Switched peer %d back to direct connection at %s", siteId, endpoint)
|
||||
return nil
|
||||
}
|
||||
924
peers/monitor/monitor.go
Normal file
924
peers/monitor/monitor.go
Normal file
@@ -0,0 +1,924 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/api"
|
||||
middleDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
|
||||
type PeerMonitor struct {
|
||||
monitors map[int]*Client
|
||||
mutex sync.Mutex
|
||||
running bool
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
wsClient *websocket.Client
|
||||
|
||||
// Netstack fields
|
||||
middleDev *middleDevice.MiddleDevice
|
||||
localIP string
|
||||
stack *stack.Stack
|
||||
ep *channel.Endpoint
|
||||
activePorts map[uint16]bool
|
||||
portsLock sync.Mutex
|
||||
nsCtx context.Context
|
||||
nsCancel context.CancelFunc
|
||||
nsWg sync.WaitGroup
|
||||
|
||||
// Holepunch testing fields
|
||||
sharedBind *bind.SharedBind
|
||||
holepunchTester *holepunch.HolepunchTester
|
||||
holepunchInterval time.Duration
|
||||
holepunchTimeout time.Duration
|
||||
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
|
||||
holepunchStatus map[int]bool // siteID -> connected status
|
||||
holepunchStopChan chan struct{}
|
||||
|
||||
// Relay tracking fields
|
||||
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
|
||||
holepunchMaxAttempts int // max consecutive failures before triggering relay
|
||||
holepunchFailures map[int]int // siteID -> consecutive failure count
|
||||
|
||||
// Rapid initial test fields
|
||||
rapidTestInterval time.Duration // interval between rapid test attempts
|
||||
rapidTestTimeout time.Duration // timeout for each rapid test attempt
|
||||
rapidTestMaxAttempts int // max attempts during rapid test phase
|
||||
|
||||
// API server for status updates
|
||||
apiServer *api.API
|
||||
|
||||
// WG connection status tracking
|
||||
wgConnectionStatus map[int]bool // siteID -> WG connected status
|
||||
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pm := &PeerMonitor{
|
||||
monitors: make(map[int]*Client),
|
||||
interval: 2 * time.Second, // Default check interval (faster)
|
||||
timeout: 3 * time.Second,
|
||||
maxAttempts: 3,
|
||||
wsClient: wsClient,
|
||||
middleDev: middleDev,
|
||||
localIP: localIP,
|
||||
activePorts: make(map[uint16]bool),
|
||||
nsCtx: ctx,
|
||||
nsCancel: cancel,
|
||||
sharedBind: sharedBind,
|
||||
holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
|
||||
holepunchTimeout: 2 * time.Second, // Faster timeout
|
||||
holepunchEndpoints: make(map[int]string),
|
||||
holepunchStatus: make(map[int]bool),
|
||||
relayedPeers: make(map[int]bool),
|
||||
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
|
||||
holepunchFailures: make(map[int]int),
|
||||
// Rapid initial test settings: complete within ~1.5 seconds
|
||||
rapidTestInterval: 200 * time.Millisecond, // 200ms between attempts
|
||||
rapidTestTimeout: 400 * time.Millisecond, // 400ms timeout per attempt
|
||||
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
|
||||
apiServer: apiServer,
|
||||
wgConnectionStatus: make(map[int]bool),
|
||||
}
|
||||
|
||||
if err := pm.initNetstack(); err != nil {
|
||||
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
|
||||
}
|
||||
|
||||
// Initialize holepunch tester if sharedBind is available
|
||||
if sharedBind != nil {
|
||||
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind)
|
||||
}
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
// SetInterval changes how frequently peers are checked
|
||||
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.interval = interval
|
||||
|
||||
// Update interval for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetPacketInterval(interval)
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeout changes the timeout for waiting for responses
|
||||
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.timeout = timeout
|
||||
|
||||
// Update timeout for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.maxAttempts = attempts
|
||||
|
||||
// Update max attempts for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetMaxAttempts(attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a new peer to monitor
|
||||
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if _, exists := pm.monitors[siteID]; exists {
|
||||
return nil // Already monitoring
|
||||
}
|
||||
|
||||
// Use our custom dialer that uses netstack
|
||||
client, err := NewClient(endpoint, pm.dial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.SetPacketInterval(pm.interval)
|
||||
client.SetTimeout(pm.timeout)
|
||||
client.SetMaxAttempts(pm.maxAttempts)
|
||||
|
||||
pm.monitors[siteID] = client
|
||||
|
||||
pm.holepunchEndpoints[siteID] = holepunchEndpoint
|
||||
pm.holepunchStatus[siteID] = false // Initially unknown/disconnected
|
||||
|
||||
if pm.running {
|
||||
if err := client.StartMonitor(func(status ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteID, status)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// update holepunch endpoint for a peer
|
||||
func (pm *PeerMonitor) UpdateHolepunchEndpoint(siteID int, endpoint string) {
|
||||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
pm.holepunchEndpoints[siteID] = endpoint
|
||||
}()
|
||||
}
|
||||
|
||||
// RapidTestPeer performs a rapid connectivity test for a newly added peer.
|
||||
// This is designed to quickly determine if holepunch is viable within ~1-2 seconds.
|
||||
// Returns true if the connection is viable (holepunch works), false if it should relay.
|
||||
func (pm *PeerMonitor) RapidTestPeer(siteID int, endpoint string) bool {
|
||||
if pm.holepunchTester == nil {
|
||||
logger.Warn("Cannot perform rapid test: holepunch tester not initialized")
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mutex.Lock()
|
||||
interval := pm.rapidTestInterval
|
||||
timeout := pm.rapidTestTimeout
|
||||
maxAttempts := pm.rapidTestMaxAttempts
|
||||
pm.mutex.Unlock()
|
||||
|
||||
logger.Info("Starting rapid holepunch test for site %d at %s (max %d attempts, %v timeout each)",
|
||||
siteID, endpoint, maxAttempts, timeout)
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
|
||||
|
||||
if result.Success {
|
||||
logger.Info("Rapid test: site %d holepunch SUCCEEDED on attempt %d (RTT: %v)",
|
||||
siteID, attempt, result.RTT)
|
||||
|
||||
// Update status
|
||||
pm.mutex.Lock()
|
||||
pm.holepunchStatus[siteID] = true
|
||||
pm.holepunchFailures[siteID] = 0
|
||||
pm.mutex.Unlock()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if attempt < maxAttempts {
|
||||
time.Sleep(interval)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Warn("Rapid test: site %d holepunch FAILED after %d attempts, will relay",
|
||||
siteID, maxAttempts)
|
||||
|
||||
// Update status to reflect failure
|
||||
pm.mutex.Lock()
|
||||
pm.holepunchStatus[siteID] = false
|
||||
pm.holepunchFailures[siteID] = maxAttempts
|
||||
pm.mutex.Unlock()
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdatePeerEndpoint updates the monitor endpoint for a peer
|
||||
func (pm *PeerMonitor) UpdatePeerEndpoint(siteID int, monitorPeer string) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
client, exists := pm.monitors[siteID]
|
||||
if !exists {
|
||||
logger.Warn("Cannot update endpoint: peer %d not found in monitor", siteID)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the client's server address
|
||||
client.UpdateServerAddr(monitorPeer)
|
||||
|
||||
logger.Info("Updated monitor endpoint for site %d to %s", siteID, monitorPeer)
|
||||
}
|
||||
|
||||
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
||||
// This function assumes the mutex is already held by the caller
|
||||
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
|
||||
client, exists := pm.monitors[siteID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
client.StopMonitor()
|
||||
client.Close()
|
||||
delete(pm.monitors, siteID)
|
||||
}
|
||||
|
||||
// RemovePeer stops monitoring a peer and removes it from the monitor
|
||||
func (pm *PeerMonitor) RemovePeer(siteID int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// remove the holepunch endpoint info
|
||||
delete(pm.holepunchEndpoints, siteID)
|
||||
delete(pm.holepunchStatus, siteID)
|
||||
delete(pm.relayedPeers, siteID)
|
||||
delete(pm.holepunchFailures, siteID)
|
||||
|
||||
pm.removePeerUnlocked(siteID)
|
||||
}
|
||||
|
||||
// Start begins monitoring all peers
|
||||
func (pm *PeerMonitor) Start() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if pm.running {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
pm.running = true
|
||||
|
||||
// Start monitoring all peers
|
||||
for siteID, client := range pm.monitors {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
err := client.StartMonitor(func(status ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
|
||||
continue
|
||||
}
|
||||
logger.Info("Started monitoring peer %d\n", siteID)
|
||||
}
|
||||
|
||||
pm.startHolepunchMonitor()
|
||||
}
|
||||
|
||||
// handleConnectionStatusChange is called when a peer's connection status changes
|
||||
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) {
|
||||
pm.mutex.Lock()
|
||||
previousStatus, exists := pm.wgConnectionStatus[siteID]
|
||||
pm.wgConnectionStatus[siteID] = status.Connected
|
||||
isRelayed := pm.relayedPeers[siteID]
|
||||
endpoint := pm.holepunchEndpoints[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
// Log status changes
|
||||
if !exists || previousStatus != status.Connected {
|
||||
if status.Connected {
|
||||
logger.Info("WireGuard connection to site %d is CONNECTED (RTT: %v)", siteID, status.RTT)
|
||||
} else {
|
||||
logger.Warn("WireGuard connection to site %d is DISCONNECTED", siteID)
|
||||
}
|
||||
}
|
||||
|
||||
// Update API with connection status
|
||||
if pm.apiServer != nil {
|
||||
pm.apiServer.UpdatePeerStatus(siteID, status.Connected, status.RTT, endpoint, isRelayed)
|
||||
}
|
||||
}
|
||||
|
||||
// sendRelay sends a relay message to the server
|
||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent relay message")
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestRelay is a public method to request relay for a peer.
|
||||
// This is used when rapid initial testing determines holepunch is not viable.
|
||||
func (pm *PeerMonitor) RequestRelay(siteID int) error {
|
||||
return pm.sendRelay(siteID)
|
||||
}
|
||||
|
||||
// sendUnRelay sends an unrelay message to the server
|
||||
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent unrelay message")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops monitoring all peers
|
||||
func (pm *PeerMonitor) Stop() {
|
||||
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||
pm.stopHolepunchMonitor()
|
||||
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
pm.running = false
|
||||
|
||||
// Stop all monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.StopMonitor()
|
||||
}
|
||||
}
|
||||
|
||||
// MarkPeerRelayed marks a peer as currently using relay
|
||||
func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
pm.relayedPeers[siteID] = relayed
|
||||
if relayed {
|
||||
// Reset failure count when marked as relayed
|
||||
pm.holepunchFailures[siteID] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// IsPeerRelayed returns whether a peer is currently using relay
|
||||
func (pm *PeerMonitor) IsPeerRelayed(siteID int) bool {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
return pm.relayedPeers[siteID]
|
||||
}
|
||||
|
||||
// startHolepunchMonitor starts the holepunch connection monitoring
|
||||
// Note: This function assumes the mutex is already held by the caller (called from Start())
|
||||
func (pm *PeerMonitor) startHolepunchMonitor() error {
|
||||
if pm.holepunchTester == nil {
|
||||
return fmt.Errorf("holepunch tester not initialized (sharedBind not provided)")
|
||||
}
|
||||
|
||||
if pm.holepunchStopChan != nil {
|
||||
return fmt.Errorf("holepunch monitor already running")
|
||||
}
|
||||
|
||||
if err := pm.holepunchTester.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start holepunch tester: %w", err)
|
||||
}
|
||||
|
||||
pm.holepunchStopChan = make(chan struct{})
|
||||
|
||||
go pm.runHolepunchMonitor()
|
||||
|
||||
logger.Info("Started holepunch connection monitor")
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopHolepunchMonitor stops the holepunch connection monitoring
|
||||
func (pm *PeerMonitor) stopHolepunchMonitor() {
|
||||
pm.mutex.Lock()
|
||||
stopChan := pm.holepunchStopChan
|
||||
pm.holepunchStopChan = nil
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if stopChan != nil {
|
||||
close(stopChan)
|
||||
}
|
||||
|
||||
if pm.holepunchTester != nil {
|
||||
pm.holepunchTester.Stop()
|
||||
}
|
||||
|
||||
logger.Info("Stopped holepunch connection monitor")
|
||||
}
|
||||
|
||||
// runHolepunchMonitor runs the holepunch monitoring loop
|
||||
func (pm *PeerMonitor) runHolepunchMonitor() {
|
||||
ticker := time.NewTicker(pm.holepunchInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Do initial check immediately
|
||||
pm.checkHolepunchEndpoints()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pm.holepunchStopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
pm.checkHolepunchEndpoints()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkHolepunchEndpoints tests all holepunch endpoints
|
||||
func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
||||
pm.mutex.Lock()
|
||||
// Check if we're still running before doing any work
|
||||
if !pm.running {
|
||||
pm.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
|
||||
for siteID, endpoint := range pm.holepunchEndpoints {
|
||||
endpoints[siteID] = endpoint
|
||||
}
|
||||
timeout := pm.holepunchTimeout
|
||||
maxAttempts := pm.holepunchMaxAttempts
|
||||
pm.mutex.Unlock()
|
||||
|
||||
for siteID, endpoint := range endpoints {
|
||||
logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
|
||||
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
|
||||
|
||||
pm.mutex.Lock()
|
||||
// Check if peer was removed while we were testing
|
||||
if _, stillExists := pm.holepunchEndpoints[siteID]; !stillExists {
|
||||
pm.mutex.Unlock()
|
||||
continue // Peer was removed, skip processing
|
||||
}
|
||||
|
||||
previousStatus, exists := pm.holepunchStatus[siteID]
|
||||
pm.holepunchStatus[siteID] = result.Success
|
||||
isRelayed := pm.relayedPeers[siteID]
|
||||
|
||||
// Track consecutive failures for relay triggering
|
||||
if result.Success {
|
||||
pm.holepunchFailures[siteID] = 0
|
||||
} else {
|
||||
pm.holepunchFailures[siteID]++
|
||||
}
|
||||
failureCount := pm.holepunchFailures[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
// Log status changes
|
||||
if !exists || previousStatus != result.Success {
|
||||
if result.Success {
|
||||
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
|
||||
} else {
|
||||
if result.Error != nil {
|
||||
logger.Warn("Holepunch to site %d (%s) is DISCONNECTED: %v", siteID, endpoint, result.Error)
|
||||
} else {
|
||||
logger.Warn("Holepunch to site %d (%s) is DISCONNECTED", siteID, endpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update API with holepunch status
|
||||
if pm.apiServer != nil {
|
||||
// Update holepunch connection status
|
||||
pm.apiServer.UpdatePeerHolepunchStatus(siteID, result.Success)
|
||||
|
||||
// Get the current WG connection status for this peer
|
||||
pm.mutex.Lock()
|
||||
wgConnected := pm.wgConnectionStatus[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
// Update API - use holepunch endpoint and relay status
|
||||
pm.apiServer.UpdatePeerStatus(siteID, wgConnected, result.RTT, endpoint, isRelayed)
|
||||
}
|
||||
|
||||
// Handle relay logic based on holepunch status
|
||||
// Check if we're still running before sending relay messages
|
||||
pm.mutex.Lock()
|
||||
stillRunning := pm.running
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if !stillRunning {
|
||||
return // Stop processing if shutdown is in progress
|
||||
}
|
||||
|
||||
if !result.Success && !isRelayed && failureCount >= maxAttempts {
|
||||
// Holepunch failed and we're not relayed - trigger relay
|
||||
logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount)
|
||||
if pm.wsClient != nil {
|
||||
pm.sendRelay(siteID)
|
||||
}
|
||||
} else if result.Success && isRelayed {
|
||||
// Holepunch succeeded and we ARE relayed - switch back to direct
|
||||
logger.Info("Holepunch to site %d succeeded while relayed, switching to direct connection", siteID)
|
||||
if pm.wsClient != nil {
|
||||
pm.sendUnRelay(siteID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetHolepunchStatus returns the current holepunch status for all endpoints
|
||||
func (pm *PeerMonitor) GetHolepunchStatus() map[int]bool {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
status := make(map[int]bool, len(pm.holepunchStatus))
|
||||
for siteID, connected := range pm.holepunchStatus {
|
||||
status[siteID] = connected
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
// Close stops monitoring and cleans up resources
|
||||
func (pm *PeerMonitor) Close() {
|
||||
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||
pm.stopHolepunchMonitor()
|
||||
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
logger.Debug("PeerMonitor: Starting cleanup")
|
||||
|
||||
// Stop and close all clients first
|
||||
for siteID, client := range pm.monitors {
|
||||
logger.Debug("PeerMonitor: Stopping client for site %d", siteID)
|
||||
client.StopMonitor()
|
||||
client.Close()
|
||||
delete(pm.monitors, siteID)
|
||||
}
|
||||
|
||||
pm.running = false
|
||||
|
||||
// Clean up netstack resources
|
||||
logger.Debug("PeerMonitor: Cancelling netstack context")
|
||||
if pm.nsCancel != nil {
|
||||
pm.nsCancel() // Signal goroutines to stop
|
||||
}
|
||||
|
||||
// Close the channel endpoint to unblock any pending reads
|
||||
logger.Debug("PeerMonitor: Closing endpoint")
|
||||
if pm.ep != nil {
|
||||
pm.ep.Close()
|
||||
}
|
||||
|
||||
// Wait for packet sender goroutine to finish with timeout
|
||||
logger.Debug("PeerMonitor: Waiting for goroutines to finish")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
pm.nsWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
logger.Debug("PeerMonitor: Goroutines finished cleanly")
|
||||
case <-time.After(2 * time.Second):
|
||||
logger.Warn("PeerMonitor: Timeout waiting for goroutines to finish, proceeding anyway")
|
||||
}
|
||||
|
||||
// Destroy the stack last, after all goroutines are done
|
||||
logger.Debug("PeerMonitor: Destroying stack")
|
||||
if pm.stack != nil {
|
||||
pm.stack.Destroy()
|
||||
pm.stack = nil
|
||||
}
|
||||
|
||||
logger.Debug("PeerMonitor: Cleanup complete")
|
||||
}
|
||||
|
||||
// TestPeer tests connectivity to a specific peer
|
||||
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||
pm.mutex.Lock()
|
||||
client, exists := pm.monitors[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
defer cancel()
|
||||
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
return connected, rtt, nil
|
||||
}
|
||||
|
||||
// TestAllPeers tests connectivity to all peers
|
||||
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
} {
|
||||
pm.mutex.Lock()
|
||||
peers := make(map[int]*Client, len(pm.monitors))
|
||||
for siteID, client := range pm.monitors {
|
||||
peers[siteID] = client
|
||||
}
|
||||
pm.mutex.Unlock()
|
||||
|
||||
results := make(map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
})
|
||||
for siteID, client := range peers {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
cancel()
|
||||
|
||||
results[siteID] = struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
}{
|
||||
Connected: connected,
|
||||
RTT: rtt,
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// initNetstack initializes the gvisor netstack
|
||||
func (pm *PeerMonitor) initNetstack() error {
|
||||
if pm.localIP == "" {
|
||||
return fmt.Errorf("local IP not provided")
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(pm.localIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid local IP: %v", err)
|
||||
}
|
||||
|
||||
// Create gvisor netstack
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
|
||||
pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG)
|
||||
pm.stack = stack.New(stackOpts)
|
||||
|
||||
// Create NIC
|
||||
if err := pm.stack.CreateNIC(1, pm.ep); err != nil {
|
||||
return fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
// Add IP address
|
||||
ipBytes := addr.As4()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||
}
|
||||
|
||||
if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return fmt.Errorf("failed to add protocol address: %v", err)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
pm.stack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
// Register filter rule on MiddleDevice
|
||||
// We want to intercept packets destined to our local IP
|
||||
// But ONLY if they are for ports we are listening on
|
||||
pm.middleDev.AddRule(addr, pm.handlePacket)
|
||||
|
||||
// Start packet sender (Stack -> WG)
|
||||
pm.nsWg.Add(1)
|
||||
go pm.runPacketSender()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handlePacket is called by MiddleDevice when a packet arrives for our IP
|
||||
func (pm *PeerMonitor) handlePacket(packet []byte) bool {
|
||||
// Check if it's UDP
|
||||
proto, ok := util.GetProtocol(packet)
|
||||
if !ok || proto != 17 { // UDP
|
||||
return false
|
||||
}
|
||||
|
||||
// Check destination port
|
||||
port, ok := util.GetDestPort(packet)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if we are listening on this port
|
||||
pm.portsLock.Lock()
|
||||
active := pm.activePorts[uint16(port)]
|
||||
pm.portsLock.Unlock()
|
||||
|
||||
if !active {
|
||||
return false
|
||||
}
|
||||
|
||||
// Inject into netstack
|
||||
version := packet[0] >> 4
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||
case 6:
|
||||
pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||
default:
|
||||
pkb.DecRef()
|
||||
return false
|
||||
}
|
||||
|
||||
pkb.DecRef()
|
||||
return true // Handled
|
||||
}
|
||||
|
||||
// runPacketSender reads packets from netstack and injects them into WireGuard
|
||||
func (pm *PeerMonitor) runPacketSender() {
|
||||
defer pm.nsWg.Done()
|
||||
logger.Debug("PeerMonitor: Packet sender goroutine started")
|
||||
|
||||
// Use a ticker to periodically check for packets without blocking indefinitely
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pm.nsCtx.Done():
|
||||
logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
|
||||
// Drain any remaining packets before exiting
|
||||
for {
|
||||
pkt := pm.ep.Read()
|
||||
if pkt == nil {
|
||||
break
|
||||
}
|
||||
pkt.DecRef()
|
||||
}
|
||||
logger.Debug("PeerMonitor: Packet sender goroutine exiting")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Try to read packets in batches
|
||||
for i := 0; i < 10; i++ {
|
||||
pkt := pm.ep.Read()
|
||||
if pkt == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Extract packet data
|
||||
slices := pkt.AsSlices()
|
||||
if len(slices) > 0 {
|
||||
var totalSize int
|
||||
for _, slice := range slices {
|
||||
totalSize += len(slice)
|
||||
}
|
||||
|
||||
buf := make([]byte, totalSize)
|
||||
pos := 0
|
||||
for _, slice := range slices {
|
||||
copy(buf[pos:], slice)
|
||||
pos += len(slice)
|
||||
}
|
||||
|
||||
// Inject into MiddleDevice (outbound to WG)
|
||||
pm.middleDev.InjectOutbound(buf)
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dial creates a UDP connection using the netstack
|
||||
func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) {
|
||||
if pm.stack == nil {
|
||||
return nil, fmt.Errorf("netstack not initialized")
|
||||
}
|
||||
|
||||
// Parse remote address
|
||||
raddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse local IP
|
||||
localIP, err := netip.ParseAddr(pm.localIP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipBytes := localIP.As4()
|
||||
|
||||
// Create UDP connection
|
||||
// We bind to port 0 (ephemeral)
|
||||
laddr := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4(ipBytes),
|
||||
Port: 0,
|
||||
}
|
||||
|
||||
raddrTcpip := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
|
||||
Port: uint16(raddr.Port),
|
||||
}
|
||||
|
||||
conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get local port
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
port := uint16(localAddr.Port)
|
||||
|
||||
// Register port
|
||||
pm.portsLock.Lock()
|
||||
pm.activePorts[port] = true
|
||||
pm.portsLock.Unlock()
|
||||
|
||||
// Wrap connection to cleanup port on close
|
||||
return &trackedConn{
|
||||
Conn: conn,
|
||||
pm: pm,
|
||||
port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (pm *PeerMonitor) removePort(port uint16) {
|
||||
pm.portsLock.Lock()
|
||||
delete(pm.activePorts, port)
|
||||
pm.portsLock.Unlock()
|
||||
}
|
||||
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
pm *PeerMonitor
|
||||
port uint16
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.pm.removePort(c.port)
|
||||
if c.Conn != nil {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package wgtester
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -26,7 +26,7 @@ const (
|
||||
|
||||
// Client handles checking connectivity to a server
|
||||
type Client struct {
|
||||
conn *net.UDPConn
|
||||
conn net.Conn
|
||||
serverAddr string
|
||||
monitorRunning bool
|
||||
monitorLock sync.Mutex
|
||||
@@ -35,8 +35,12 @@ type Client struct {
|
||||
packetInterval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
dialer Dialer
|
||||
}
|
||||
|
||||
// Dialer is a function that creates a connection
|
||||
type Dialer func(network, addr string) (net.Conn, error)
|
||||
|
||||
// ConnectionStatus represents the current connection state
|
||||
type ConnectionStatus struct {
|
||||
Connected bool
|
||||
@@ -44,13 +48,14 @@ type ConnectionStatus struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new connection test client
|
||||
func NewClient(serverAddr string) (*Client, error) {
|
||||
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
||||
return &Client{
|
||||
serverAddr: serverAddr,
|
||||
shutdownCh: make(chan struct{}),
|
||||
packetInterval: 2 * time.Second,
|
||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||
maxAttempts: 3, // Default max attempts
|
||||
dialer: dialer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -69,6 +74,20 @@ func (c *Client) SetMaxAttempts(attempts int) {
|
||||
c.maxAttempts = attempts
|
||||
}
|
||||
|
||||
// UpdateServerAddr updates the server address and resets the connection
|
||||
func (c *Client) UpdateServerAddr(serverAddr string) {
|
||||
c.connLock.Lock()
|
||||
defer c.connLock.Unlock()
|
||||
|
||||
// Close existing connection if any
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
c.serverAddr = serverAddr
|
||||
}
|
||||
|
||||
// Close cleans up client resources
|
||||
func (c *Client) Close() {
|
||||
c.StopMonitor()
|
||||
@@ -91,12 +110,14 @@ func (c *Client) ensureConnection() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
var err error
|
||||
if c.dialer != nil {
|
||||
c.conn, err = c.dialer("udp", c.serverAddr)
|
||||
} else {
|
||||
// Fallback to standard net.Dial
|
||||
c.conn, err = net.Dial("udp", c.serverAddr)
|
||||
}
|
||||
|
||||
c.conn, err = net.DialUDP("udp", nil, serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -136,14 +157,14 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
|
||||
// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
|
||||
_, err := c.conn.Write(packet)
|
||||
if err != nil {
|
||||
c.connLock.Unlock()
|
||||
logger.Info("Error sending packet: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Debug("Successfully sent monitor packet")
|
||||
// logger.Debug("Successfully sent monitor packet")
|
||||
|
||||
// Set read deadline
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||
142
peers/peer.go
Normal file
142
peers/peer.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package peers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
|
||||
var endpoint string
|
||||
if relay && siteConfig.RelayEndpoint != "" {
|
||||
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
|
||||
} else {
|
||||
endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||
}
|
||||
siteHost, err := util.ResolveDomain(endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||
}
|
||||
|
||||
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
|
||||
allowedIp := strings.Split(siteConfig.ServerIP, "/")
|
||||
if len(allowedIp) > 1 {
|
||||
allowedIp[1] = "32"
|
||||
} else {
|
||||
allowedIp = append(allowedIp, "32")
|
||||
}
|
||||
allowedIpStr := strings.Join(allowedIp, "/")
|
||||
|
||||
// Collect all allowed IPs in a slice
|
||||
var allowedIPs []string
|
||||
allowedIPs = append(allowedIPs, allowedIpStr)
|
||||
|
||||
// Use AllowedIps if available, otherwise fall back to RemoteSubnets for backwards compatibility
|
||||
subnetsToAdd := siteConfig.AllowedIps
|
||||
|
||||
// If we have anything to add, process them
|
||||
if len(subnetsToAdd) > 0 {
|
||||
// Add each subnet
|
||||
for _, subnet := range subnetsToAdd {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet != "" {
|
||||
allowedIPs = append(allowedIPs, subnet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct WireGuard config for this peer
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", util.FixKey(privateKey.String())))
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(siteConfig.PublicKey)))
|
||||
|
||||
// Add each allowed IP separately
|
||||
for _, allowedIP := range allowedIPs {
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
}
|
||||
|
||||
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||
configBuilder.WriteString("persistent_keepalive_interval=5\n")
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Configuring peer with config: %s", config)
|
||||
|
||||
err = dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeer removes a peer from the WireGuard device
|
||||
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
|
||||
// Construct WireGuard config to remove the peer
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("remove=true\n")
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Removing peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddAllowedIP adds a single allowed IP to an existing peer without reconfiguring the entire peer
|
||||
func AddAllowedIP(dev *device.Device, publicKey string, allowedIP string) error {
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("update_only=true\n")
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Adding allowed IP to peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add allowed IP to WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAllowedIP removes a single allowed IP from an existing peer by replacing the allowed IPs list
|
||||
// This requires providing all the allowed IPs that should remain after removal
|
||||
func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs []string) error {
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("update_only=true\n")
|
||||
configBuilder.WriteString("replace_allowed_ips=true\n")
|
||||
|
||||
// Add each remaining allowed IP
|
||||
for _, allowedIP := range remainingAllowedIPs {
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
}
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Removing allowed IP from peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove allowed IP from WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatEndpoint(endpoint string) string {
|
||||
if strings.Contains(endpoint, ":") {
|
||||
return endpoint
|
||||
}
|
||||
return endpoint + ":51820"
|
||||
}
|
||||
63
peers/types.go
Normal file
63
peers/types.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package peers
|
||||
|
||||
// PeerAction represents a request to add, update, or remove a peer
|
||||
type PeerAction struct {
|
||||
Action string `json:"action"` // "add", "update", or "remove"
|
||||
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
|
||||
}
|
||||
|
||||
// UpdatePeerData represents the data needed to update a peer
|
||||
type SiteConfig struct {
|
||||
SiteId int `json:"siteId"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
RelayEndpoint string `json:"relayEndpoint,omitempty"`
|
||||
PublicKey string `json:"publicKey,omitempty"`
|
||||
ServerIP string `json:"serverIP,omitempty"`
|
||||
ServerPort uint16 `json:"serverPort,omitempty"`
|
||||
RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access
|
||||
AllowedIps []string `json:"allowedIps,omitempty"` // optional, array of allowed IPs for the peer
|
||||
Aliases []Alias `json:"aliases,omitempty"` // optional, array of alias configurations
|
||||
}
|
||||
|
||||
type Alias struct {
|
||||
Alias string `json:"alias"` // the alias name
|
||||
AliasAddress string `json:"aliasAddress"` // the alias IP address
|
||||
}
|
||||
|
||||
// RemovePeer represents the data needed to remove a peer
|
||||
type PeerRemove struct {
|
||||
SiteId int `json:"siteId"`
|
||||
}
|
||||
|
||||
type RelayPeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
RelayEndpoint string `json:"relayEndpoint"`
|
||||
}
|
||||
|
||||
type UnRelayPeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
// PeerAdd represents the data needed to add remote subnets to a peer
|
||||
type PeerAdd struct {
|
||||
SiteId int `json:"siteId"`
|
||||
RemoteSubnets []string `json:"remoteSubnets"` // subnets to add
|
||||
Aliases []Alias `json:"aliases,omitempty"` // aliases to add
|
||||
}
|
||||
|
||||
// RemovePeerData represents the data needed to remove remote subnets from a peer
|
||||
type RemovePeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
RemoteSubnets []string `json:"remoteSubnets"` // subnets to remove
|
||||
Aliases []Alias `json:"aliases,omitempty"` // aliases to remove
|
||||
}
|
||||
|
||||
type UpdatePeerData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
OldRemoteSubnets []string `json:"oldRemoteSubnets"` // old list of remote subnets
|
||||
NewRemoteSubnets []string `json:"newRemoteSubnets"` // new list of remote subnets
|
||||
OldAliases []Alias `json:"oldAliases,omitempty"` // old list of aliases
|
||||
NewAliases []Alias `json:"newAliases,omitempty"` // new list of aliases
|
||||
}
|
||||
@@ -99,15 +99,32 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
|
||||
// Continue with empty args if loading fails
|
||||
savedArgs = []string{}
|
||||
}
|
||||
s.elog.Info(1, fmt.Sprintf("Loaded saved service args: %v", savedArgs))
|
||||
|
||||
// Combine service start args with saved args, giving priority to service start args
|
||||
// Note: When the service is started via SCM, args[0] is the service name
|
||||
// When started via s.Start(args...), the args passed are exactly what we provide
|
||||
finalArgs := []string{}
|
||||
|
||||
// Check if we have args passed directly to Execute (from s.Start())
|
||||
if len(args) > 0 {
|
||||
// Skip the first arg which is typically the service name
|
||||
if len(args) > 1 {
|
||||
// The first arg from SCM is the service name, but when we call s.Start(args...),
|
||||
// the args we pass become args[1:] in Execute. However, if started by SCM without
|
||||
// args, args[0] will be the service name.
|
||||
// We need to check if args[0] looks like the service name or a flag
|
||||
if len(args) == 1 && args[0] == serviceName {
|
||||
// Only service name, no actual args
|
||||
s.elog.Info(1, "Only service name in args, checking saved args")
|
||||
} else if len(args) > 1 && args[0] == serviceName {
|
||||
// Service name followed by actual args
|
||||
finalArgs = append(finalArgs, args[1:]...)
|
||||
s.elog.Info(1, fmt.Sprintf("Using service start parameters (after service name): %v", finalArgs))
|
||||
} else {
|
||||
// Args don't start with service name, use them all
|
||||
// This happens when args are passed via s.Start(args...)
|
||||
finalArgs = append(finalArgs, args...)
|
||||
s.elog.Info(1, fmt.Sprintf("Using service start parameters (direct): %v", finalArgs))
|
||||
}
|
||||
s.elog.Info(1, fmt.Sprintf("Using service start parameters: %v", finalArgs))
|
||||
}
|
||||
|
||||
// If no service start parameters, use saved args
|
||||
@@ -116,6 +133,7 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
|
||||
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
|
||||
}
|
||||
|
||||
s.elog.Info(1, fmt.Sprintf("Final args to use: %v", finalArgs))
|
||||
s.args = finalArgs
|
||||
|
||||
// Start the main olm functionality
|
||||
@@ -163,6 +181,9 @@ func (s *olmService) runOlm() {
|
||||
// Create a context that can be cancelled when the service stops
|
||||
s.ctx, s.stop = context.WithCancel(context.Background())
|
||||
|
||||
// Create a separate context for programmatic shutdown (e.g., via API exit)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Setup logging for service mode
|
||||
s.elog.Info(1, "Starting Olm main logic")
|
||||
|
||||
@@ -177,7 +198,8 @@ func (s *olmService) runOlm() {
|
||||
}()
|
||||
|
||||
// Call the main olm function with stored arguments
|
||||
runOlmMainWithArgs(s.ctx, s.args)
|
||||
// Use s.ctx as the signal context since the service manages shutdown
|
||||
runOlmMainWithArgs(ctx, cancel, s.ctx, s.args)
|
||||
}()
|
||||
|
||||
// Wait for either context cancellation or main logic completion
|
||||
@@ -321,12 +343,15 @@ func removeService() error {
|
||||
}
|
||||
|
||||
func startService(args []string) error {
|
||||
// Save the service arguments as backup
|
||||
if len(args) > 0 {
|
||||
err := saveServiceArgs(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save service args: %v", err)
|
||||
}
|
||||
fmt.Printf("Starting service with args: %v\n", args)
|
||||
|
||||
// Always save the service arguments so they can be loaded on service restart
|
||||
err := saveServiceArgs(args)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: failed to save service args: %v\n", err)
|
||||
// Continue anyway, args will still be passed directly
|
||||
} else {
|
||||
fmt.Printf("Saved service args to: %s\n", getServiceArgsPath())
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
@@ -342,6 +367,7 @@ func startService(args []string) error {
|
||||
defer s.Close()
|
||||
|
||||
// Pass arguments directly to the service start call
|
||||
// Note: These args will appear in Execute() after the service name
|
||||
err = s.Start(args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %v", err)
|
||||
|
||||
@@ -20,14 +20,37 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// AuthError represents an authentication/authorization error (401/403)
|
||||
type AuthError struct {
|
||||
StatusCode int
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *AuthError) Error() string {
|
||||
return fmt.Sprintf("authentication error (status %d): %s", e.StatusCode, e.Message)
|
||||
}
|
||||
|
||||
// IsAuthError checks if an error is an authentication error
|
||||
func IsAuthError(err error) bool {
|
||||
_, ok := err.(*AuthError)
|
||||
return ok
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
Data struct {
|
||||
Token string `json:"token"`
|
||||
Token string `json:"token"`
|
||||
ExitNodes []ExitNode `json:"exitNodes"`
|
||||
ServerVersion string `json:"serverVersion"`
|
||||
} `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ExitNode struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
type WSMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
@@ -40,6 +63,7 @@ type Config struct {
|
||||
Endpoint string
|
||||
TlsClientCert string // legacy PKCS12 file path
|
||||
UserToken string // optional user token for websocket authentication
|
||||
OrgID string // optional organization ID for websocket authentication
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
@@ -55,7 +79,8 @@ type Client struct {
|
||||
pingInterval time.Duration
|
||||
pingTimeout time.Duration
|
||||
onConnect func() error
|
||||
onTokenUpdate func(token string)
|
||||
onTokenUpdate func(token string, exitNodes []ExitNode)
|
||||
onAuthError func(statusCode int, message string) // Callback for auth errors
|
||||
writeMux sync.Mutex
|
||||
clientType string // Type of client (e.g., "newt", "olm")
|
||||
tlsConfig TLSConfig
|
||||
@@ -99,17 +124,22 @@ func (c *Client) OnConnect(callback func() error) {
|
||||
c.onConnect = callback
|
||||
}
|
||||
|
||||
func (c *Client) OnTokenUpdate(callback func(token string)) {
|
||||
func (c *Client) OnTokenUpdate(callback func(token string, exitNodes []ExitNode)) {
|
||||
c.onTokenUpdate = callback
|
||||
}
|
||||
|
||||
func (c *Client) OnAuthError(callback func(statusCode int, message string)) {
|
||||
c.onAuthError = callback
|
||||
}
|
||||
|
||||
// NewClient creates a new websocket client
|
||||
func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||
func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||
config := &Config{
|
||||
ID: ID,
|
||||
Secret: secret,
|
||||
Endpoint: endpoint,
|
||||
UserToken: userToken,
|
||||
OrgID: orgId,
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
@@ -191,13 +221,17 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
return c.conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
|
||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
|
||||
stopChan := make(chan struct{})
|
||||
updateChan := make(chan interface{})
|
||||
var dataMux sync.Mutex
|
||||
currentData := data
|
||||
|
||||
go func() {
|
||||
count := 0
|
||||
maxAttempts := 10
|
||||
|
||||
err := c.SendMessage(messageType, data) // Send immediately
|
||||
err := c.SendMessage(messageType, currentData) // Send immediately
|
||||
if err != nil {
|
||||
logger.Error("Failed to send initial message: %v", err)
|
||||
}
|
||||
@@ -212,19 +246,46 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
||||
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||
return
|
||||
}
|
||||
err = c.SendMessage(messageType, data)
|
||||
dataMux.Lock()
|
||||
err = c.SendMessage(messageType, currentData)
|
||||
dataMux.Unlock()
|
||||
if err != nil {
|
||||
logger.Error("Failed to send message: %v", err)
|
||||
}
|
||||
count++
|
||||
case newData := <-updateChan:
|
||||
dataMux.Lock()
|
||||
// Merge newData into currentData if both are maps
|
||||
if currentMap, ok := currentData.(map[string]interface{}); ok {
|
||||
if newMap, ok := newData.(map[string]interface{}); ok {
|
||||
// Update or add keys from newData
|
||||
for key, value := range newMap {
|
||||
currentMap[key] = value
|
||||
}
|
||||
currentData = currentMap
|
||||
} else {
|
||||
// If newData is not a map, replace entirely
|
||||
currentData = newData
|
||||
}
|
||||
} else {
|
||||
// If currentData is not a map, replace entirely
|
||||
currentData = newData
|
||||
}
|
||||
dataMux.Unlock()
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
close(stopChan)
|
||||
}
|
||||
close(stopChan)
|
||||
}, func(newData interface{}) {
|
||||
select {
|
||||
case updateChan <- newData:
|
||||
case <-stopChan:
|
||||
// Channel is closed, ignore update
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a specific message type
|
||||
@@ -234,11 +295,11 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
||||
c.handlers[messageType] = handler
|
||||
}
|
||||
|
||||
func (c *Client) getToken() (string, error) {
|
||||
func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
// Parse the base URL to ensure we have the correct hostname
|
||||
baseURL, err := url.Parse(c.baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse base URL: %w", err)
|
||||
return "", nil, fmt.Errorf("failed to parse base URL: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we have the base URL without trailing slashes
|
||||
@@ -250,7 +311,7 @@ func (c *Client) getToken() (string, error) {
|
||||
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||
tlsConfig, err = c.setupTLS()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||
return "", nil, fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,16 +324,15 @@ func (c *Client) getToken() (string, error) {
|
||||
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
var tokenData map[string]interface{}
|
||||
|
||||
tokenData = map[string]interface{}{
|
||||
tokenData := map[string]interface{}{
|
||||
"olmId": c.config.ID,
|
||||
"secret": c.config.Secret,
|
||||
"orgId": c.config.OrgID,
|
||||
}
|
||||
jsonData, err := json.Marshal(tokenData)
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal token request data: %w", err)
|
||||
return "", nil, fmt.Errorf("failed to marshal token request data: %w", err)
|
||||
}
|
||||
|
||||
// Create a new request
|
||||
@@ -282,13 +342,16 @@ func (c *Client) getToken() (string, error) {
|
||||
bytes.NewBuffer(jsonData),
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
return "", nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
||||
|
||||
// print out the request for debugging
|
||||
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
|
||||
|
||||
// Make the request
|
||||
client := &http.Client{}
|
||||
if tlsConfig != nil {
|
||||
@@ -298,33 +361,43 @@ func (c *Client) getToken() (string, error) {
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to request new token: %w", err)
|
||||
return "", nil, fmt.Errorf("failed to request new token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
|
||||
// Return AuthError for 401/403 status codes
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return "", nil, &AuthError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Message: string(body),
|
||||
}
|
||||
}
|
||||
|
||||
// For other errors (5xx, network issues, etc.), return regular error
|
||||
return "", nil, fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
logger.Error("Failed to decode token response.")
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if !tokenResp.Success {
|
||||
return "", fmt.Errorf("failed to get token: %s", tokenResp.Message)
|
||||
return "", nil, fmt.Errorf("failed to get token: %s", tokenResp.Message)
|
||||
}
|
||||
|
||||
if tokenResp.Data.Token == "" {
|
||||
return "", fmt.Errorf("received empty token from server")
|
||||
return "", nil, fmt.Errorf("received empty token from server")
|
||||
}
|
||||
|
||||
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
||||
|
||||
return tokenResp.Data.Token, nil
|
||||
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
|
||||
}
|
||||
|
||||
func (c *Client) connectWithRetry() {
|
||||
@@ -335,6 +408,18 @@ func (c *Client) connectWithRetry() {
|
||||
default:
|
||||
err := c.establishConnection()
|
||||
if err != nil {
|
||||
// Check if this is an auth error (401/403)
|
||||
if authErr, ok := err.(*AuthError); ok {
|
||||
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
|
||||
// Trigger auth error callback if set (this should terminate the tunnel)
|
||||
if c.onAuthError != nil {
|
||||
c.onAuthError(authErr.StatusCode, authErr.Message)
|
||||
}
|
||||
// Continue retrying after auth error
|
||||
time.Sleep(c.reconnectInterval)
|
||||
continue
|
||||
}
|
||||
// For other errors (5xx, network issues), continue retrying
|
||||
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||
time.Sleep(c.reconnectInterval)
|
||||
continue
|
||||
@@ -346,13 +431,13 @@ func (c *Client) connectWithRetry() {
|
||||
|
||||
func (c *Client) establishConnection() error {
|
||||
// Get token for authentication
|
||||
token, err := c.getToken()
|
||||
token, exitNodes, err := c.getToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
if c.onTokenUpdate != nil {
|
||||
c.onTokenUpdate(token)
|
||||
c.onTokenUpdate(token, exitNodes)
|
||||
}
|
||||
|
||||
// Parse the base URL to determine protocol and hostname
|
||||
|
||||
Reference in New Issue
Block a user