mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 14:06:41 +00:00
Compare commits
155 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d13cc179e8 | ||
|
|
c71828f5a1 | ||
|
|
dc83af6c2e | ||
|
|
35544e1081 | ||
|
|
2ddb4a5645 | ||
|
|
c25fb02f1e | ||
|
|
28583c9507 | ||
|
|
ba41602e4b | ||
|
|
4b8b281d5b | ||
|
|
a07a714d93 | ||
|
|
58ce93f6c3 | ||
|
|
293e507000 | ||
|
|
c948208493 | ||
|
|
2106734aa4 | ||
|
|
51162d6be6 | ||
|
|
45ef6e5279 | ||
|
|
3b2ffe006a | ||
|
|
a497f0873f | ||
|
|
6e4ec246ef | ||
|
|
7270b840cf | ||
|
|
fb007e09a9 | ||
|
|
9ce6450351 | ||
|
|
672fff0ad9 | ||
|
|
22474d92ef | ||
|
|
0e4a657700 | ||
|
|
e24ee0e68b | ||
|
|
cea9ab0932 | ||
|
|
229dc6afce | ||
|
|
e2fe7d53f8 | ||
|
|
e8f1fb507c | ||
|
|
7e410cde28 | ||
|
|
afe0d338be | ||
|
|
a18b367e60 | ||
|
|
91e44e112e | ||
|
|
a38d1ef8a8 | ||
|
|
a32e91de24 | ||
|
|
92b551fa4b | ||
|
|
53c1fa117a | ||
|
|
50525aaf8d | ||
|
|
d8ced86d19 | ||
|
|
2718d15825 | ||
|
|
fff234bdd5 | ||
|
|
0802673048 | ||
|
|
d54b7e3f14 | ||
|
|
650084132b | ||
|
|
9d34c818d7 | ||
|
|
2436a5be15 | ||
|
|
430f2bf7fa | ||
|
|
16362f285d | ||
|
|
34c7f89804 | ||
|
|
ead8fab70a | ||
|
|
50008f3c12 | ||
|
|
24b5122cc1 | ||
|
|
9099b246dc | ||
|
|
30ff3c06eb | ||
|
|
d02ca20c06 | ||
|
|
7afe842a95 | ||
|
|
47d628af73 | ||
|
|
0f1e51f391 | ||
|
|
5d6024ac59 | ||
|
|
6c7ee31330 | ||
|
|
b38357875e | ||
|
|
c230c7be28 | ||
|
|
d7cd746cc9 | ||
|
|
7941479994 | ||
|
|
e3623fd756 | ||
|
|
68c2744ebe | ||
|
|
a9d8d0e5c6 | ||
|
|
7f94fbc1e4 | ||
|
|
542d7e5d61 | ||
|
|
930bf7e0f2 | ||
|
|
d7345c7dbd | ||
|
|
3e2cb70d58 | ||
|
|
d4c5292e8f | ||
|
|
45047343c4 | ||
|
|
c09fb312e8 | ||
|
|
8dfb4b2b20 | ||
|
|
6b17cb08c0 | ||
|
|
aa866493aa | ||
|
|
7b28137cf6 | ||
|
|
a8383f5612 | ||
|
|
2fc385155e | ||
|
|
b7271b77b6 | ||
|
|
1ef6b7ada6 | ||
|
|
ea454d0528 | ||
|
|
a6670ccab3 | ||
|
|
f226e8f7f3 | ||
|
|
75890ca5a6 | ||
|
|
e6cf631dbc | ||
|
|
7fc09f8ed1 | ||
|
|
079843602c | ||
|
|
70bf22c354 | ||
|
|
3d891cfa97 | ||
|
|
78e3bb374a | ||
|
|
a61c7ca1ee | ||
|
|
7696ba2e36 | ||
|
|
235877c379 | ||
|
|
befab0f8d1 | ||
|
|
914d080a57 | ||
|
|
a274b4b38f | ||
|
|
ce3c585514 | ||
|
|
963d8abad5 | ||
|
|
38eb56381f | ||
|
|
43b3822090 | ||
|
|
b0fb370c4d | ||
|
|
99328ee76f | ||
|
|
36fc3ea253 | ||
|
|
a7979259f3 | ||
|
|
ea6fa72bc0 | ||
|
|
f9adde6b1d | ||
|
|
ba25586646 | ||
|
|
952ab63e8d | ||
|
|
5e84f802ed | ||
|
|
f40b0ff820 | ||
|
|
95a4840374 | ||
|
|
27424170e4 | ||
|
|
a8ace6f64a | ||
|
|
3fa1073f49 | ||
|
|
76d86c10ff | ||
|
|
2d34c6c8b2 | ||
|
|
a7f3477bdd | ||
|
|
af0a72d296 | ||
|
|
d1e836e760 | ||
|
|
8dd45c4ca2 | ||
|
|
9db009058b | ||
|
|
29c01deb05 | ||
|
|
7224d9824d | ||
|
|
8afc28fdff | ||
|
|
4ba2fb7b53 | ||
|
|
2e6076923d | ||
|
|
4c001dc751 | ||
|
|
2b8e240752 | ||
|
|
bee490713d | ||
|
|
1cb7fd94ab | ||
|
|
dc9a547950 | ||
|
|
2be0933246 | ||
|
|
c0b1cd6bde | ||
|
|
dd00289f8e | ||
|
|
b23a02ee97 | ||
|
|
80f726cfea | ||
|
|
aa8828186f | ||
|
|
0a990d196d | ||
|
|
00e8050949 | ||
|
|
18ee4c93fb | ||
|
|
8fa2da00b6 | ||
|
|
b851cd73c9 | ||
|
|
4c19d7ef6d | ||
|
|
cbecb9a0ce | ||
|
|
4fc8db08ba | ||
|
|
7ca46e0a75 | ||
|
|
a4ea5143af | ||
|
|
e9257b6423 | ||
|
|
3c9d3a1d2c | ||
|
|
b426f14190 | ||
|
|
d48acfba39 |
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Summary
|
||||
description: A clear and concise summary of the requested feature.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Motivation
|
||||
description: |
|
||||
Why is this feature important?
|
||||
Explain the problem this feature would solve or what use case it would enable.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Proposed Solution
|
||||
description: |
|
||||
How would you like to see this feature implemented?
|
||||
Provide as much detail as possible about the desired behavior, configuration, or changes.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
description: Describe any alternative solutions or workarounds you've thought about.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context, mockups, or screenshots about the feature request here.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before submitting, please:
|
||||
- Check if there is an existing issue for this feature.
|
||||
- Clearly explain the benefit and use case.
|
||||
- Be as specific as possible to help contributors evaluate and implement.
|
||||
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
name: Bug Report
|
||||
description: Create a bug report
|
||||
labels: []
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Describe the Bug
|
||||
description: A clear and concise description of what the bug is.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Environment
|
||||
description: Please fill out the relevant details below for your environment.
|
||||
value: |
|
||||
- OS Type & Version: (e.g., Ubuntu 22.04)
|
||||
- Pangolin Version:
|
||||
- Gerbil Version:
|
||||
- Traefik Version:
|
||||
- Newt Version:
|
||||
- Olm Version: (if applicable)
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: To Reproduce
|
||||
description: |
|
||||
Steps to reproduce the behavior, please provide a clear description of how to reproduce the issue, based on the linked minimal reproduction. Screenshots can be provided in the issue body below.
|
||||
|
||||
If using code blocks, make sure syntax highlighting is correct and double-check that the rendered preview is not broken.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: A clear and concise description of what you expected to happen.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Before posting the issue go through the steps you've written down to make sure the steps provided are detailed and clear.
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Contributors should be able to follow the steps provided in order to reproduce the bug.
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Need help or have questions?
|
||||
url: https://github.com/orgs/fosrl/discussions
|
||||
about: Ask questions, get help, and discuss with other community members
|
||||
- name: Request a Feature
|
||||
url: https://github.com/orgs/fosrl/discussions/new?category=feature-requests
|
||||
about: Feature requests should be opened as discussions so others can upvote and comment
|
||||
4
.github/workflows/cicd.yml
vendored
4
.github/workflows/cicd.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
jobs:
|
||||
release:
|
||||
name: Build and Release
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: amd64-runner
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 1.25
|
||||
|
||||
|
||||
132
.github/workflows/mirror.yaml
vendored
Normal file
132
.github/workflows/mirror.yaml
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
name: Mirror & Sign (Docker Hub to GHCR)
|
||||
|
||||
on:
|
||||
workflow_dispatch: {}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write # for keyless OIDC
|
||||
|
||||
env:
|
||||
SOURCE_IMAGE: docker.io/fosrl/olm
|
||||
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||
|
||||
jobs:
|
||||
mirror-and-dual-sign:
|
||||
runs-on: amd64-runner
|
||||
steps:
|
||||
- name: Install skopeo + jq
|
||||
run: |
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y skopeo jq
|
||||
skopeo --version
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||
|
||||
- name: Input check
|
||||
run: |
|
||||
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
|
||||
echo "Source : ${SOURCE_IMAGE}"
|
||||
echo "Target : ${DEST_IMAGE}"
|
||||
|
||||
# Auth for skopeo (containers-auth)
|
||||
- name: Skopeo login to GHCR
|
||||
run: |
|
||||
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
# Auth for cosign (docker-config)
|
||||
- name: Docker login to GHCR (for cosign)
|
||||
run: |
|
||||
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
|
||||
|
||||
- name: List source tags
|
||||
run: |
|
||||
set -euo pipefail
|
||||
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
|
||||
| jq -r '.Tags[]' | sort -u > src-tags.txt
|
||||
echo "Found source tags: $(wc -l < src-tags.txt)"
|
||||
head -n 20 src-tags.txt || true
|
||||
|
||||
- name: List destination tags (skip existing)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
|
||||
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
|
||||
else
|
||||
: > dst-tags.txt
|
||||
fi
|
||||
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
|
||||
|
||||
- name: Mirror, dual-sign, and verify
|
||||
env:
|
||||
# keyless
|
||||
COSIGN_YES: "true"
|
||||
# key-based
|
||||
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||
# verify
|
||||
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
copied=0; skipped=0; v_ok=0; errs=0
|
||||
|
||||
issuer="https://token.actions.githubusercontent.com"
|
||||
id_regex="^https://github.com/${{ github.repository }}/.+"
|
||||
|
||||
while read -r tag; do
|
||||
[ -z "$tag" ] && continue
|
||||
|
||||
if grep -Fxq "$tag" dst-tags.txt; then
|
||||
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
|
||||
skipped=$((skipped+1))
|
||||
continue
|
||||
fi
|
||||
|
||||
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
|
||||
if ! skopeo copy --all --retry-times 3 \
|
||||
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
|
||||
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
|
||||
errs=$((errs+1)); continue
|
||||
fi
|
||||
copied=$((copied+1))
|
||||
|
||||
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
|
||||
ref="${DEST_IMAGE}@${digest}"
|
||||
|
||||
echo "==> cosign sign (keyless) --recursive ${ref}"
|
||||
if ! cosign sign --recursive "${ref}"; then
|
||||
echo "::warning title=Keyless sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign sign (key) --recursive ${ref}"
|
||||
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
|
||||
echo "::warning title=Key sign failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (public key) ${ref}"
|
||||
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
|
||||
echo "::warning title=Verify(pubkey) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
fi
|
||||
|
||||
echo "==> cosign verify (keyless policy) ${ref}"
|
||||
if ! cosign verify \
|
||||
--certificate-oidc-issuer "${issuer}" \
|
||||
--certificate-identity-regexp "${id_regex}" \
|
||||
"${ref}" -o text; then
|
||||
echo "::warning title=Verify(keyless) failed::${ref}"
|
||||
errs=$((errs+1))
|
||||
else
|
||||
v_ok=$((v_ok+1))
|
||||
fi
|
||||
done < src-tags.txt
|
||||
|
||||
echo "---- Summary ----"
|
||||
echo "Copied : $copied"
|
||||
echo "Skipped : $skipped"
|
||||
echo "Verified OK : $v_ok"
|
||||
echo "Errors : $errs"
|
||||
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 1.25
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,2 @@
|
||||
olm
|
||||
.DS_Store
|
||||
bin/
|
||||
@@ -4,11 +4,7 @@ Contributions are welcome!
|
||||
|
||||
Please see the contribution and local development guide on the docs page before getting started:
|
||||
|
||||
https://docs.fossorial.io/development
|
||||
|
||||
For ideas about what features to work on and our future plans, please see the roadmap:
|
||||
|
||||
https://docs.fossorial.io/roadmap
|
||||
https://docs.pangolin.net/development/contributing
|
||||
|
||||
### Licensing Considerations
|
||||
|
||||
|
||||
2
Makefile
2
Makefile
@@ -10,7 +10,7 @@ docker-build-release:
|
||||
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push .
|
||||
|
||||
local:
|
||||
CGO_ENABLED=0 go build -o olm
|
||||
CGO_ENABLED=0 go build -o bin/olm
|
||||
|
||||
build:
|
||||
docker build -t fosrl/olm:latest .
|
||||
|
||||
@@ -6,7 +6,7 @@ Olm is a [WireGuard](https://www.wireguard.com/) tunnel client designed to secur
|
||||
|
||||
Olm is used with Pangolin and Newt as part of the larger system. See documentation below:
|
||||
|
||||
- [Full Documentation](https://docs.fossorial.io)
|
||||
- [Full Documentation](https://docs.pangolin.net)
|
||||
|
||||
## Key Functions
|
||||
|
||||
@@ -107,7 +107,7 @@ $ cat ~/.config/olm-client/config.json
|
||||
{
|
||||
"id": "spmzu8rbpzj1qq6",
|
||||
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
|
||||
"endpoint": "https://pangolin.fossorial.io",
|
||||
"endpoint": "https://app.pangolin.net",
|
||||
"tlsClientCert": ""
|
||||
}
|
||||
```
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||
|
||||
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||
|
||||
- Description and location of the vulnerability.
|
||||
- Potential impact of the vulnerability.
|
||||
|
||||
482
api/api.go
Normal file
482
api/api.go
Normal file
@@ -0,0 +1,482 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
)
|
||||
|
||||
// ConnectionRequest defines the structure for an incoming connection request
|
||||
type ConnectionRequest struct {
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
UserToken string `json:"userToken,omitempty"`
|
||||
MTU int `json:"mtu,omitempty"`
|
||||
DNS string `json:"dns,omitempty"`
|
||||
DNSProxyIP string `json:"dnsProxyIP,omitempty"`
|
||||
UpstreamDNS []string `json:"upstreamDNS,omitempty"`
|
||||
InterfaceName string `json:"interfaceName,omitempty"`
|
||||
Holepunch bool `json:"holepunch,omitempty"`
|
||||
TlsClientCert string `json:"tlsClientCert,omitempty"`
|
||||
PingInterval string `json:"pingInterval,omitempty"`
|
||||
PingTimeout string `json:"pingTimeout,omitempty"`
|
||||
OrgID string `json:"orgId,omitempty"`
|
||||
}
|
||||
|
||||
// SwitchOrgRequest defines the structure for switching organizations
|
||||
type SwitchOrgRequest struct {
|
||||
OrgID string `json:"orgId"`
|
||||
}
|
||||
|
||||
// PeerStatus represents the status of a peer connection
|
||||
type PeerStatus struct {
|
||||
SiteID int `json:"siteId"`
|
||||
Connected bool `json:"connected"`
|
||||
RTT time.Duration `json:"rtt"`
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
IsRelay bool `json:"isRelay"`
|
||||
PeerIP string `json:"peerAddress,omitempty"`
|
||||
HolepunchConnected bool `json:"holepunchConnected"`
|
||||
}
|
||||
|
||||
// StatusResponse is returned by the status endpoint
|
||||
type StatusResponse struct {
|
||||
Connected bool `json:"connected"`
|
||||
Registered bool `json:"registered"`
|
||||
Terminated bool `json:"terminated"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Agent string `json:"agent,omitempty"`
|
||||
OrgID string `json:"orgId,omitempty"`
|
||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
|
||||
}
|
||||
|
||||
// API represents the HTTP server and its state
|
||||
type API struct {
|
||||
addr string
|
||||
socketPath string
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
onConnect func(ConnectionRequest) error
|
||||
onSwitchOrg func(SwitchOrgRequest) error
|
||||
onDisconnect func() error
|
||||
onExit func() error
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
connectedAt time.Time
|
||||
isConnected bool
|
||||
isRegistered bool
|
||||
isTerminated bool
|
||||
version string
|
||||
agent string
|
||||
orgID string
|
||||
}
|
||||
|
||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||
func NewAPI(addr string) *API {
|
||||
s := &API{
|
||||
addr: addr,
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
|
||||
func NewAPISocket(socketPath string) *API {
|
||||
s := &API{
|
||||
socketPath: socketPath,
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetHandlers sets the callback functions for handling API requests
|
||||
func (s *API) SetHandlers(
|
||||
onConnect func(ConnectionRequest) error,
|
||||
onSwitchOrg func(SwitchOrgRequest) error,
|
||||
onDisconnect func() error,
|
||||
onExit func() error,
|
||||
) {
|
||||
s.onConnect = onConnect
|
||||
s.onSwitchOrg = onSwitchOrg
|
||||
s.onDisconnect = onDisconnect
|
||||
s.onExit = onExit
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
func (s *API) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/connect", s.handleConnect)
|
||||
mux.HandleFunc("/status", s.handleStatus)
|
||||
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||
mux.HandleFunc("/disconnect", s.handleDisconnect)
|
||||
mux.HandleFunc("/exit", s.handleExit)
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
s.server = &http.Server{
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
var err error
|
||||
if s.socketPath != "" {
|
||||
// Use platform-specific socket listener
|
||||
s.listener, err = createSocketListener(s.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create socket listener: %w", err)
|
||||
}
|
||||
logger.Info("Starting HTTP server on socket %s", s.socketPath)
|
||||
} else {
|
||||
// Use TCP listener
|
||||
s.listener, err = net.Listen("tcp", s.addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TCP listener: %w", err)
|
||||
}
|
||||
logger.Info("Starting HTTP server on %s", s.addr)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("HTTP server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the HTTP server
|
||||
func (s *API) Stop() error {
|
||||
logger.Info("Stopping api server")
|
||||
|
||||
// Close the server first, which will also close the listener gracefully
|
||||
if s.server != nil {
|
||||
s.server.Close()
|
||||
}
|
||||
|
||||
// Clean up socket file if using Unix socket
|
||||
if s.socketPath != "" {
|
||||
cleanupSocket(s.socketPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.Connected = connected
|
||||
status.RTT = rtt
|
||||
status.LastSeen = time.Now()
|
||||
status.Endpoint = endpoint
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
func (s *API) RemovePeerStatus(siteID int) { // remove the peer from the status map
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
delete(s.peerStatuses, siteID)
|
||||
}
|
||||
|
||||
// SetConnectionStatus sets the overall connection status
|
||||
func (s *API) SetConnectionStatus(isConnected bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
s.isConnected = isConnected
|
||||
|
||||
if isConnected {
|
||||
s.connectedAt = time.Now()
|
||||
} else {
|
||||
// Clear peer statuses when disconnected
|
||||
s.peerStatuses = make(map[int]*PeerStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *API) SetRegistered(registered bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.isRegistered = registered
|
||||
}
|
||||
|
||||
func (s *API) SetTerminated(terminated bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.isTerminated = terminated
|
||||
}
|
||||
|
||||
// ClearPeerStatuses clears all peer statuses
|
||||
func (s *API) ClearPeerStatuses() {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.peerStatuses = make(map[int]*PeerStatus)
|
||||
}
|
||||
|
||||
// SetVersion sets the olm version
|
||||
func (s *API) SetVersion(version string) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.version = version
|
||||
}
|
||||
|
||||
// SetAgent sets the olm agent
|
||||
func (s *API) SetAgent(agent string) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.agent = agent
|
||||
}
|
||||
|
||||
// SetOrgID sets the organization ID
|
||||
func (s *API) SetOrgID(orgID string) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.orgID = orgID
|
||||
}
|
||||
|
||||
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.Endpoint = endpoint
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
// UpdatePeerHolepunchStatus updates the holepunch connection status of a peer
|
||||
func (s *API) UpdatePeerHolepunchStatus(siteID int, holepunchConnected bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.HolepunchConnected = holepunchConnected
|
||||
}
|
||||
|
||||
// handleConnect handles the /connect endpoint
|
||||
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// if we are already connected, reject new connection requests
|
||||
s.statusMu.RLock()
|
||||
alreadyConnected := s.isConnected
|
||||
s.statusMu.RUnlock()
|
||||
if alreadyConnected {
|
||||
http.Error(w, "Already connected to a server. Disconnect first before connecting again.", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
var req ConnectionRequest
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
|
||||
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Call the connect handler if set
|
||||
if s.onConnect != nil {
|
||||
if err := s.onConnect(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "connection request accepted",
|
||||
})
|
||||
}
|
||||
|
||||
// handleStatus handles the /status endpoint
|
||||
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
s.statusMu.RLock()
|
||||
defer s.statusMu.RUnlock()
|
||||
|
||||
resp := StatusResponse{
|
||||
Connected: s.isConnected,
|
||||
Registered: s.isRegistered,
|
||||
Terminated: s.isTerminated,
|
||||
Version: s.version,
|
||||
Agent: s.agent,
|
||||
OrgID: s.orgID,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
NetworkSettings: network.GetSettings(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// handleHealth handles the /health endpoint
|
||||
func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "ok",
|
||||
})
|
||||
}
|
||||
|
||||
// handleExit handles the /exit endpoint
|
||||
func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received exit request via API")
|
||||
|
||||
// Return a success response first
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "shutdown initiated",
|
||||
})
|
||||
|
||||
// Call the exit handler after responding, in a goroutine with a small delay
|
||||
// to ensure the response is fully sent before shutdown begins
|
||||
if s.onExit != nil {
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if err := s.onExit(); err != nil {
|
||||
logger.Error("Exit handler failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// handleSwitchOrg handles the /switch-org endpoint
|
||||
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req SwitchOrgRequest
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.OrgID == "" {
|
||||
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Call the switch org handler if set
|
||||
if s.onSwitchOrg != nil {
|
||||
if err := s.onSwitchOrg(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "org switch request accepted",
|
||||
})
|
||||
}
|
||||
|
||||
// handleDisconnect handles the /disconnect endpoint
|
||||
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// if we are already disconnected, reject new disconnect requests
|
||||
s.statusMu.RLock()
|
||||
alreadyDisconnected := !s.isConnected
|
||||
s.statusMu.RUnlock()
|
||||
if alreadyDisconnected {
|
||||
http.Error(w, "Not currently connected to a server.", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received disconnect request via API")
|
||||
|
||||
// Call the disconnect handler if set
|
||||
if s.onDisconnect != nil {
|
||||
if err := s.onDisconnect(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "disconnect initiated",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *API) GetStatus() StatusResponse {
|
||||
return StatusResponse{
|
||||
Connected: s.isConnected,
|
||||
Registered: s.isRegistered,
|
||||
Terminated: s.isTerminated,
|
||||
Version: s.version,
|
||||
Agent: s.agent,
|
||||
OrgID: s.orgID,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
NetworkSettings: network.GetSettings(),
|
||||
}
|
||||
}
|
||||
50
api/api_unix.go
Normal file
50
api/api_unix.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// createSocketListener creates a Unix domain socket listener
|
||||
func createSocketListener(socketPath string) (net.Listener, error) {
|
||||
// Ensure the directory exists
|
||||
dir := filepath.Dir(socketPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create socket directory: %w", err)
|
||||
}
|
||||
|
||||
// Remove existing socket file if it exists
|
||||
if err := os.RemoveAll(socketPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
|
||||
}
|
||||
|
||||
// Set socket permissions to allow access
|
||||
if err := os.Chmod(socketPath, 0666); err != nil {
|
||||
listener.Close()
|
||||
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Created Unix socket at %s", socketPath)
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// cleanupSocket removes the Unix socket file
|
||||
func cleanupSocket(socketPath string) {
|
||||
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
|
||||
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
|
||||
} else {
|
||||
logger.Debug("Removed Unix socket at %s", socketPath)
|
||||
}
|
||||
}
|
||||
41
api/api_windows.go
Normal file
41
api/api_windows.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// createSocketListener creates a Windows named pipe listener
|
||||
func createSocketListener(pipePath string) (net.Listener, error) {
|
||||
// Ensure the pipe path has the correct format
|
||||
if pipePath[0] != '\\' {
|
||||
pipePath = `\\.\pipe\` + pipePath
|
||||
}
|
||||
|
||||
// Create a pipe configuration that allows everyone to write
|
||||
config := &winio.PipeConfig{
|
||||
// Set security descriptor to allow everyone full access
|
||||
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
|
||||
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
|
||||
}
|
||||
|
||||
// Create a named pipe listener using go-winio with the configuration
|
||||
listener, err := winio.ListenPipe(pipePath, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
// by the OS once the last handle is closed; only a debug trace is emitted.
func cleanupSocket(pipePath string) {
	logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
}
|
||||
636
config.go
Normal file
636
config.go
Normal file
@@ -0,0 +1,636 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OlmConfig holds all configuration options for the Olm client.
// Values are layered from defaults, the JSON config file, environment
// variables and CLI flags (see LoadConfig); the sources map records,
// per setting, which layer supplied the effective value.
type OlmConfig struct {
	// Connection settings
	Endpoint  string `json:"endpoint"`  // Pangolin server endpoint
	ID        string `json:"id"`        // Olm client identifier
	Secret    string `json:"secret"`    // Olm client secret (masked by ShowConfig)
	OrgID     string `json:"org"`       // organization identifier
	UserToken string `json:"userToken"` // optional user session token

	// Network settings
	MTU           int      `json:"mtu"`
	DNS           string   `json:"dns"`
	UpstreamDNS   []string `json:"upstreamDNS"`
	InterfaceName string   `json:"interface"`

	// Logging
	LogLevel string `json:"logLevel"`

	// HTTP server
	EnableAPI  bool   `json:"enableApi"`
	HTTPAddr   string `json:"httpAddr"`
	SocketPath string `json:"socketPath"` // Unix socket path, or named pipe name on Windows

	// Ping settings (duration strings such as "3s"; parsed by parseDurations)
	PingInterval string `json:"pingInterval"`
	PingTimeout  string `json:"pingTimeout"`

	// Advanced
	Holepunch     bool   `json:"holepunch"`
	TlsClientCert string `json:"tlsClientCert"`
	OverrideDNS   bool   `json:"overrideDNS"`
	DisableRelay  bool   `json:"disableRelay"`
	// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`

	// Parsed values (not in JSON)
	PingIntervalDuration time.Duration `json:"-"`
	PingTimeoutDuration  time.Duration `json:"-"`

	// Source tracking (not in JSON): maps a setting key (e.g. "mtu")
	// to the ConfigSource string that supplied its value.
	sources map[string]string `json:"-"`

	Version string
}
|
||||
|
||||
// ConfigSource tracks where each config value came from.
// Listed in increasing precedence: default < file < environment < cli.
type ConfigSource string

const (
	SourceDefault ConfigSource = "default"     // built-in default value
	SourceFile    ConfigSource = "file"        // loaded from the JSON config file
	SourceEnv     ConfigSource = "environment" // set via environment variable
	SourceCLI     ConfigSource = "cli"         // provided as a command-line argument
)
|
||||
|
||||
// DefaultConfig returns a config with default values
|
||||
func DefaultConfig() *OlmConfig {
|
||||
// Set OS-specific socket path
|
||||
var socketPath string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
socketPath = "olm"
|
||||
default: // darwin, linux, and others
|
||||
socketPath = "/var/run/olm.sock"
|
||||
}
|
||||
|
||||
config := &OlmConfig{
|
||||
MTU: 1280,
|
||||
DNS: "8.8.8.8",
|
||||
UpstreamDNS: []string{"8.8.8.8:53"},
|
||||
LogLevel: "INFO",
|
||||
InterfaceName: "olm",
|
||||
EnableAPI: false,
|
||||
SocketPath: socketPath,
|
||||
PingInterval: "3s",
|
||||
PingTimeout: "5s",
|
||||
Holepunch: false,
|
||||
// DoNotCreateNewClient: false,
|
||||
sources: make(map[string]string),
|
||||
}
|
||||
|
||||
// Track default sources
|
||||
config.sources["mtu"] = string(SourceDefault)
|
||||
config.sources["dns"] = string(SourceDefault)
|
||||
config.sources["upstreamDNS"] = string(SourceDefault)
|
||||
config.sources["logLevel"] = string(SourceDefault)
|
||||
config.sources["interface"] = string(SourceDefault)
|
||||
config.sources["enableApi"] = string(SourceDefault)
|
||||
config.sources["httpAddr"] = string(SourceDefault)
|
||||
config.sources["socketPath"] = string(SourceDefault)
|
||||
config.sources["pingInterval"] = string(SourceDefault)
|
||||
config.sources["pingTimeout"] = string(SourceDefault)
|
||||
config.sources["holepunch"] = string(SourceDefault)
|
||||
config.sources["overrideDNS"] = string(SourceDefault)
|
||||
config.sources["disableRelay"] = string(SourceDefault)
|
||||
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// getOlmConfigPath returns the path to the olm config file
|
||||
func getOlmConfigPath() string {
|
||||
configFile := os.Getenv("CONFIG_FILE")
|
||||
if configFile != "" {
|
||||
return configFile
|
||||
}
|
||||
|
||||
var configDir string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
|
||||
case "windows":
|
||||
configDir = filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "olm-client")
|
||||
default: // linux and others
|
||||
configDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
fmt.Printf("Warning: Failed to create config directory: %v\n", err)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, "config.json")
|
||||
}
|
||||
|
||||
// LoadConfig loads configuration from file, env vars, and CLI args
|
||||
// Priority: CLI args > Env vars > Config file > Defaults
|
||||
// Returns: (config, showVersion, showConfig, error)
|
||||
func LoadConfig(args []string) (*OlmConfig, bool, bool, error) {
|
||||
// Start with defaults
|
||||
config := DefaultConfig()
|
||||
|
||||
// Load from config file (if exists)
|
||||
fileConfig, err := loadConfigFromFile()
|
||||
if err != nil {
|
||||
return nil, false, false, fmt.Errorf("failed to load config file: %w", err)
|
||||
}
|
||||
if fileConfig != nil {
|
||||
mergeConfigs(config, fileConfig)
|
||||
}
|
||||
|
||||
// Override with environment variables
|
||||
loadConfigFromEnv(config)
|
||||
|
||||
// Override with CLI arguments
|
||||
showVersion, showConfig, err := loadConfigFromCLI(config, args)
|
||||
if err != nil {
|
||||
return nil, false, false, err
|
||||
}
|
||||
|
||||
// Parse duration strings
|
||||
if err := config.parseDurations(); err != nil {
|
||||
return nil, false, false, err
|
||||
}
|
||||
|
||||
return config, showVersion, showConfig, nil
|
||||
}
|
||||
|
||||
// loadConfigFromFile loads configuration from the JSON config file
|
||||
func loadConfigFromFile() (*OlmConfig, error) {
|
||||
configPath := getOlmConfigPath()
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil // File doesn't exist, not an error
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config OlmConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// loadConfigFromEnv loads configuration from environment variables
|
||||
func loadConfigFromEnv(config *OlmConfig) {
|
||||
if val := os.Getenv("PANGOLIN_ENDPOINT"); val != "" {
|
||||
config.Endpoint = val
|
||||
config.sources["endpoint"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("OLM_ID"); val != "" {
|
||||
config.ID = val
|
||||
config.sources["id"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("OLM_SECRET"); val != "" {
|
||||
config.Secret = val
|
||||
config.sources["secret"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("ORG"); val != "" {
|
||||
config.OrgID = val
|
||||
config.sources["org"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("USER_TOKEN"); val != "" {
|
||||
config.UserToken = val
|
||||
config.sources["userToken"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("MTU"); val != "" {
|
||||
if mtu, err := strconv.Atoi(val); err == nil {
|
||||
config.MTU = mtu
|
||||
config.sources["mtu"] = string(SourceEnv)
|
||||
} else {
|
||||
fmt.Printf("Invalid MTU value: %s, keeping current value\n", val)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv("DNS"); val != "" {
|
||||
config.DNS = val
|
||||
config.sources["dns"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("UPSTREAM_DNS"); val != "" {
|
||||
config.UpstreamDNS = []string{val}
|
||||
config.sources["upstreamDNS"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("LOG_LEVEL"); val != "" {
|
||||
config.LogLevel = val
|
||||
config.sources["logLevel"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("INTERFACE"); val != "" {
|
||||
config.InterfaceName = val
|
||||
config.sources["interface"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("HTTP_ADDR"); val != "" {
|
||||
config.HTTPAddr = val
|
||||
config.sources["httpAddr"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("PING_INTERVAL"); val != "" {
|
||||
config.PingInterval = val
|
||||
config.sources["pingInterval"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("PING_TIMEOUT"); val != "" {
|
||||
config.PingTimeout = val
|
||||
config.sources["pingTimeout"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("ENABLE_API"); val == "true" {
|
||||
config.EnableAPI = true
|
||||
config.sources["enableApi"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("SOCKET_PATH"); val != "" {
|
||||
config.SocketPath = val
|
||||
config.sources["socketPath"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("HOLEPUNCH"); val == "true" {
|
||||
config.Holepunch = true
|
||||
config.sources["holepunch"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("OVERRIDE_DNS"); val == "true" {
|
||||
config.OverrideDNS = true
|
||||
config.sources["overrideDNS"] = string(SourceEnv)
|
||||
}
|
||||
if val := os.Getenv("DISABLE_RELAY"); val == "true" {
|
||||
config.DisableRelay = true
|
||||
config.sources["disableRelay"] = string(SourceEnv)
|
||||
}
|
||||
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
|
||||
// config.DoNotCreateNewClient = true
|
||||
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
|
||||
// }
|
||||
}
|
||||
|
||||
// loadConfigFromCLI loads configuration from command-line arguments
|
||||
func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
||||
|
||||
// Store original values to detect changes
|
||||
origValues := map[string]interface{}{
|
||||
"endpoint": config.Endpoint,
|
||||
"id": config.ID,
|
||||
"secret": config.Secret,
|
||||
"org": config.OrgID,
|
||||
"userToken": config.UserToken,
|
||||
"mtu": config.MTU,
|
||||
"dns": config.DNS,
|
||||
"upstreamDNS": fmt.Sprintf("%v", config.UpstreamDNS),
|
||||
"logLevel": config.LogLevel,
|
||||
"interface": config.InterfaceName,
|
||||
"httpAddr": config.HTTPAddr,
|
||||
"socketPath": config.SocketPath,
|
||||
"pingInterval": config.PingInterval,
|
||||
"pingTimeout": config.PingTimeout,
|
||||
"enableApi": config.EnableAPI,
|
||||
"holepunch": config.Holepunch,
|
||||
"overrideDNS": config.OverrideDNS,
|
||||
"disableRelay": config.DisableRelay,
|
||||
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||
}
|
||||
|
||||
// Define flags
|
||||
serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server")
|
||||
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
|
||||
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
|
||||
serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID")
|
||||
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
|
||||
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
|
||||
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
|
||||
var upstreamDNSFlag string
|
||||
serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8:53)")
|
||||
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
|
||||
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
|
||||
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
|
||||
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
|
||||
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
||||
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
|
||||
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
|
||||
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
|
||||
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
|
||||
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
|
||||
|
||||
version := serviceFlags.Bool("version", false, "Print the version")
|
||||
showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit")
|
||||
|
||||
// Parse the arguments
|
||||
if err := serviceFlags.Parse(args); err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
// Parse upstream DNS flag if provided
|
||||
if upstreamDNSFlag != "" {
|
||||
config.UpstreamDNS = []string{}
|
||||
for _, dns := range splitComma(upstreamDNSFlag) {
|
||||
if dns != "" {
|
||||
config.UpstreamDNS = append(config.UpstreamDNS, dns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track which values were changed by CLI args
|
||||
if config.Endpoint != origValues["endpoint"].(string) {
|
||||
config.sources["endpoint"] = string(SourceCLI)
|
||||
}
|
||||
if config.ID != origValues["id"].(string) {
|
||||
config.sources["id"] = string(SourceCLI)
|
||||
}
|
||||
if config.Secret != origValues["secret"].(string) {
|
||||
config.sources["secret"] = string(SourceCLI)
|
||||
}
|
||||
if config.OrgID != origValues["org"].(string) {
|
||||
config.sources["org"] = string(SourceCLI)
|
||||
}
|
||||
if config.UserToken != origValues["userToken"].(string) {
|
||||
config.sources["userToken"] = string(SourceCLI)
|
||||
}
|
||||
if config.MTU != origValues["mtu"].(int) {
|
||||
config.sources["mtu"] = string(SourceCLI)
|
||||
}
|
||||
if config.DNS != origValues["dns"].(string) {
|
||||
config.sources["dns"] = string(SourceCLI)
|
||||
}
|
||||
if fmt.Sprintf("%v", config.UpstreamDNS) != origValues["upstreamDNS"].(string) {
|
||||
config.sources["upstreamDNS"] = string(SourceCLI)
|
||||
}
|
||||
if config.LogLevel != origValues["logLevel"].(string) {
|
||||
config.sources["logLevel"] = string(SourceCLI)
|
||||
}
|
||||
if config.InterfaceName != origValues["interface"].(string) {
|
||||
config.sources["interface"] = string(SourceCLI)
|
||||
}
|
||||
if config.HTTPAddr != origValues["httpAddr"].(string) {
|
||||
config.sources["httpAddr"] = string(SourceCLI)
|
||||
}
|
||||
if config.SocketPath != origValues["socketPath"].(string) {
|
||||
config.sources["socketPath"] = string(SourceCLI)
|
||||
}
|
||||
if config.PingInterval != origValues["pingInterval"].(string) {
|
||||
config.sources["pingInterval"] = string(SourceCLI)
|
||||
}
|
||||
if config.PingTimeout != origValues["pingTimeout"].(string) {
|
||||
config.sources["pingTimeout"] = string(SourceCLI)
|
||||
}
|
||||
if config.EnableAPI != origValues["enableApi"].(bool) {
|
||||
config.sources["enableApi"] = string(SourceCLI)
|
||||
}
|
||||
if config.Holepunch != origValues["holepunch"].(bool) {
|
||||
config.sources["holepunch"] = string(SourceCLI)
|
||||
}
|
||||
if config.OverrideDNS != origValues["overrideDNS"].(bool) {
|
||||
config.sources["overrideDNS"] = string(SourceCLI)
|
||||
}
|
||||
if config.DisableRelay != origValues["disableRelay"].(bool) {
|
||||
config.sources["disableRelay"] = string(SourceCLI)
|
||||
}
|
||||
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
|
||||
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
|
||||
// }
|
||||
|
||||
return *version, *showConfig, nil
|
||||
}
|
||||
|
||||
// parseDurations parses the duration strings into time.Duration
|
||||
func (c *OlmConfig) parseDurations() error {
|
||||
var err error
|
||||
|
||||
// Parse ping interval
|
||||
if c.PingInterval != "" {
|
||||
c.PingIntervalDuration, err = time.ParseDuration(c.PingInterval)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", c.PingInterval)
|
||||
c.PingIntervalDuration = 3 * time.Second
|
||||
c.PingInterval = "3s"
|
||||
}
|
||||
} else {
|
||||
c.PingIntervalDuration = 3 * time.Second
|
||||
c.PingInterval = "3s"
|
||||
}
|
||||
|
||||
// Parse ping timeout
|
||||
if c.PingTimeout != "" {
|
||||
c.PingTimeoutDuration, err = time.ParseDuration(c.PingTimeout)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", c.PingTimeout)
|
||||
c.PingTimeoutDuration = 5 * time.Second
|
||||
c.PingTimeout = "5s"
|
||||
}
|
||||
} else {
|
||||
c.PingTimeoutDuration = 5 * time.Second
|
||||
c.PingTimeout = "5s"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeConfigs merges source config into destination (only non-empty values)
// Also tracks that these values came from a file
//
// Non-string fields are only copied when they differ from the built-in
// default (1280, "8.8.8.8", "INFO", ...), since JSON decoding cannot
// distinguish "absent" from "explicitly set to the default".
// NOTE(review): a config file that explicitly sets a value equal to its
// default is therefore ignored and still attributed to SourceDefault —
// confirm this is the intended trade-off.
func mergeConfigs(dest, src *OlmConfig) {
	if src.Endpoint != "" {
		dest.Endpoint = src.Endpoint
		dest.sources["endpoint"] = string(SourceFile)
	}
	if src.ID != "" {
		dest.ID = src.ID
		dest.sources["id"] = string(SourceFile)
	}
	if src.Secret != "" {
		dest.Secret = src.Secret
		dest.sources["secret"] = string(SourceFile)
	}
	if src.OrgID != "" {
		dest.OrgID = src.OrgID
		dest.sources["org"] = string(SourceFile)
	}
	if src.UserToken != "" {
		dest.UserToken = src.UserToken
		dest.sources["userToken"] = string(SourceFile)
	}
	// Defaults used as "not set" sentinels from here on.
	if src.MTU != 0 && src.MTU != 1280 {
		dest.MTU = src.MTU
		dest.sources["mtu"] = string(SourceFile)
	}
	if src.DNS != "" && src.DNS != "8.8.8.8" {
		dest.DNS = src.DNS
		dest.sources["dns"] = string(SourceFile)
	}
	if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8:53]" {
		dest.UpstreamDNS = src.UpstreamDNS
		dest.sources["upstreamDNS"] = string(SourceFile)
	}
	if src.LogLevel != "" && src.LogLevel != "INFO" {
		dest.LogLevel = src.LogLevel
		dest.sources["logLevel"] = string(SourceFile)
	}
	if src.InterfaceName != "" && src.InterfaceName != "olm" {
		dest.InterfaceName = src.InterfaceName
		dest.sources["interface"] = string(SourceFile)
	}
	if src.HTTPAddr != "" && src.HTTPAddr != ":9452" {
		dest.HTTPAddr = src.HTTPAddr
		dest.sources["httpAddr"] = string(SourceFile)
	}
	if src.SocketPath != "" {
		// Check if it's not the default for any OS
		isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
		if !isDefault {
			dest.SocketPath = src.SocketPath
			dest.sources["socketPath"] = string(SourceFile)
		}
	}
	if src.PingInterval != "" && src.PingInterval != "3s" {
		dest.PingInterval = src.PingInterval
		dest.sources["pingInterval"] = string(SourceFile)
	}
	if src.PingTimeout != "" && src.PingTimeout != "5s" {
		dest.PingTimeout = src.PingTimeout
		dest.sources["pingTimeout"] = string(SourceFile)
	}
	if src.TlsClientCert != "" {
		dest.TlsClientCert = src.TlsClientCert
		dest.sources["tlsClientCert"] = string(SourceFile)
	}
	// For booleans, we always take the source value if explicitly set
	// (false in the file is indistinguishable from absent, so only true
	// values are propagated).
	if src.EnableAPI {
		dest.EnableAPI = src.EnableAPI
		dest.sources["enableApi"] = string(SourceFile)
	}
	if src.Holepunch {
		dest.Holepunch = src.Holepunch
		dest.sources["holepunch"] = string(SourceFile)
	}
	if src.OverrideDNS {
		dest.OverrideDNS = src.OverrideDNS
		dest.sources["overrideDNS"] = string(SourceFile)
	}
	if src.DisableRelay {
		dest.DisableRelay = src.DisableRelay
		dest.sources["disableRelay"] = string(SourceFile)
	}
	// if src.DoNotCreateNewClient {
	// 	dest.DoNotCreateNewClient = src.DoNotCreateNewClient
	// 	dest.sources["doNotCreateNewClient"] = string(SourceFile)
	// }
}
|
||||
|
||||
// SaveConfig saves the current configuration to the config file
|
||||
func SaveConfig(config *OlmConfig) error {
|
||||
configPath := getOlmConfigPath()
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
return os.WriteFile(configPath, data, 0644)
|
||||
}
|
||||
|
||||
// ShowConfig prints the configuration and the source of each value
// (default/file/environment/cli) to stdout, grouped by category, with the
// secret masked. Intended for the -show-config CLI flag.
func (c *OlmConfig) ShowConfig() {
	configPath := getOlmConfigPath()

	fmt.Print("\n=== Olm Configuration ===\n\n")
	fmt.Printf("Config File: %s\n", configPath)

	// Check if config file exists
	if _, err := os.Stat(configPath); err == nil {
		fmt.Printf("Config File Status: ✓ exists\n")
	} else {
		fmt.Printf("Config File Status: ✗ not found\n")
	}

	fmt.Println("\n--- Configuration Values ---")
	fmt.Print("(Format: Setting = Value [source])\n\n")

	// Helper to get source or default (settings never recorded in the
	// sources map are reported as defaults).
	getSource := func(key string) string {
		if source, ok := c.sources[key]; ok {
			return source
		}
		return string(SourceDefault)
	}

	// Helper to format value (mask secrets). Long secrets keep their
	// first/last four characters to aid identification; short ones are
	// fully masked.
	formatValue := func(key, value string) string {
		if key == "secret" && value != "" {
			if len(value) > 8 {
				return value[:4] + "****" + value[len(value)-4:]
			}
			return "****"
		}
		if value == "" {
			return "(not set)"
		}
		return value
	}

	// Connection settings
	fmt.Println("Connection:")
	fmt.Printf("  endpoint       = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint"))
	fmt.Printf("  id             = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
	fmt.Printf("  secret         = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
	fmt.Printf("  org            = %s [%s]\n", formatValue("org", c.OrgID), getSource("org"))
	fmt.Printf("  user-token     = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken"))

	// Network settings
	fmt.Println("\nNetwork:")
	fmt.Printf("  mtu            = %d [%s]\n", c.MTU, getSource("mtu"))
	fmt.Printf("  dns            = %s [%s]\n", c.DNS, getSource("dns"))
	fmt.Printf("  upstream-dns   = %v [%s]\n", c.UpstreamDNS, getSource("upstreamDNS"))
	fmt.Printf("  interface      = %s [%s]\n", c.InterfaceName, getSource("interface"))

	// Logging
	fmt.Println("\nLogging:")
	fmt.Printf("  log-level      = %s [%s]\n", c.LogLevel, getSource("logLevel"))

	// API server
	fmt.Println("\nAPI Server:")
	fmt.Printf("  enable-api     = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
	fmt.Printf("  http-addr      = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
	fmt.Printf("  socket-path    = %s [%s]\n", c.SocketPath, getSource("socketPath"))

	// Timing
	fmt.Println("\nTiming:")
	fmt.Printf("  ping-interval  = %s [%s]\n", c.PingInterval, getSource("pingInterval"))
	fmt.Printf("  ping-timeout   = %s [%s]\n", c.PingTimeout, getSource("pingTimeout"))

	// Advanced
	fmt.Println("\nAdvanced:")
	fmt.Printf("  holepunch      = %v [%s]\n", c.Holepunch, getSource("holepunch"))
	fmt.Printf("  override-dns   = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS"))
	fmt.Printf("  disable-relay  = %v [%s]\n", c.DisableRelay, getSource("disableRelay"))
	// fmt.Printf("  do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
	if c.TlsClientCert != "" {
		fmt.Printf("  tls-cert       = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
	}

	// Source legend
	fmt.Println("\n--- Source Legend ---")
	fmt.Println("  default     = Built-in default value")
	fmt.Println("  file        = Loaded from config file")
	fmt.Println("  environment = Set via environment variable")
	fmt.Println("  cli         = Provided as command-line argument")
	fmt.Println("\nPriority: cli > environment > file > default")
	fmt.Println()
}
|
||||
|
||||
// splitComma splits a comma-separated string into a slice of trimmed,
// non-empty strings. The result is always non-nil.
func splitComma(s string) []string {
	pieces := strings.Split(s, ",")
	out := make([]string, 0, len(pieces))
	for _, piece := range pieces {
		if p := strings.TrimSpace(piece); p != "" {
			out = append(out, p)
		}
	}
	return out
}
|
||||
43
create_test_creds.py
Normal file
43
create_test_creds.py
Normal file
@@ -0,0 +1,43 @@
|
||||
|
||||
import requests
|
||||
|
||||
def create_olm(base_url, user_token, olm_name, user_id):
    """Create an OLM credential for a user via the Pangolin API.

    Sends a PUT to /api/v1/user/{user_id}/olm, authenticating with the
    p_session_token cookie, and prints the JSON response body.

    Args:
        base_url: Server base URL, e.g. "http://localhost:3000".
        user_token: Session token placed in the p_session_token cookie.
        olm_name: Display name for the new OLM.
        user_id: ID of the user the OLM belongs to.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    url = f"{base_url}/api/v1/user/{user_id}/olm"
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "User-Agent": "pangolin-cli",
        # NOTE(review): static token — presumably the server's CSRF check
        # is presence-based; confirm against the server implementation.
        "X-CSRF-Token": "x-csrf-protection",
        "Cookie": f"p_session_token={user_token}"
    }
    payload = {"name": olm_name}
    response = requests.put(url, json=payload, headers=headers)
    response.raise_for_status()
    data = response.json()
    print(f"Response Data: {data}")
|
||||
|
||||
def create_client(base_url, user_token, client_name):
    """Create a client via the Pangolin API and print the JSON response.

    Sends a POST to /api/v1/api/clients with the p_session_token cookie.
    NOTE(review): the path contains a doubled "api" segment
    ("/api/v1/api/clients") — confirm this matches the server's routes.

    Args:
        base_url: Server base URL, e.g. "http://localhost:3000".
        user_token: Session token placed in the p_session_token cookie.
        client_name: Display name for the new client.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    url = f"{base_url}/api/v1/api/clients"
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "User-Agent": "pangolin-cli",
        # NOTE(review): static token — presumably presence-based CSRF check.
        "X-CSRF-Token": "x-csrf-protection",
        "Cookie": f"p_session_token={user_token}"
    }
    payload = {"name": client_name}
    response = requests.post(url, json=payload, headers=headers)
    response.raise_for_status()
    data = response.json()
    print(f"Response Data: {data}")
|
||||
|
||||
if __name__ == "__main__":
    # Interactive driver: prompt for connection details, then create an
    # OLM credential for the given user. The client-creation call is
    # currently disabled (and create_client returns None, so the
    # commented assignment would not yield an id anyway).
    base_url = input("Enter base URL (e.g., http://localhost:3000): ")
    user_token = input("Enter user token: ")
    user_id = input("Enter user ID: ")
    olm_name = input("Enter OLM name: ")
    client_name = input("Enter client name: ")

    create_olm(base_url, user_token, olm_name, user_id)
    # client_id = create_client(base_url, user_token, client_name)
|
||||
331
device/middle_device.go
Normal file
331
device/middle_device.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// PacketHandler processes intercepted packets and returns true if packet should be dropped
type PacketHandler func(packet []byte) bool

// FilterRule defines a rule for packet filtering: packets whose
// destination address equals DestIP are passed to Handler.
type FilterRule struct {
	DestIP  netip.Addr    // destination address that triggers this rule
	Handler PacketHandler // invoked for matching packets; true means drop
}
|
||||
|
||||
// MiddleDevice wraps a TUN device with packet filtering capabilities.
// A background pump goroutine (started by NewMiddleDevice) reads from
// the underlying device and forwards results over readCh; injectCh
// carries locally injected packets; closed signals shutdown.
type MiddleDevice struct {
	tun.Device                  // embedded underlying TUN device
	rules      []FilterRule     // filtering rules, guarded by mutex
	mutex      sync.RWMutex     // protects rules
	readCh     chan readResult  // results produced by the pump goroutine
	injectCh   chan []byte      // packets injected via InjectOutbound (buffered)
	closed     chan struct{}    // closed once to signal shutdown
}
|
||||
|
||||
// readResult carries the outcome of one batched Read on the underlying
// TUN device from the pump goroutine to the consumer.
type readResult struct {
	bufs   [][]byte // packet buffers (freshly allocated per read)
	sizes  []int    // per-buffer packet lengths
	offset int      // payload offset within each buffer
	n      int      // number of packets read
	err    error    // read error, if any
}
|
||||
|
||||
// NewMiddleDevice creates a new filtered TUN device wrapper
|
||||
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||
d := &MiddleDevice{
|
||||
Device: device,
|
||||
rules: make([]FilterRule, 0),
|
||||
readCh: make(chan readResult),
|
||||
injectCh: make(chan []byte, 100),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go d.pump()
|
||||
return d
|
||||
}
|
||||
|
||||
// pump continuously reads packet batches from the underlying TUN device
// and hands them to readCh. It exits when the closed channel is signaled
// (checked before the read, after the read, and while sending) or when a
// read error occurs (the error is still delivered on readCh first).
func (d *MiddleDevice) pump() {
	// 16-byte headroom before each packet, matching the offset passed to Read.
	const defaultOffset = 16
	batchSize := d.Device.BatchSize()
	logger.Debug("MiddleDevice: pump started")

	for {
		// Check closed first with priority
		select {
		case <-d.closed:
			logger.Debug("MiddleDevice: pump exiting due to closed channel")
			return
		default:
		}

		// Allocate buffers for reading
		// We allocate new buffers for each read to avoid race conditions
		// since we pass them to the channel
		bufs := make([][]byte, batchSize)
		sizes := make([]int, batchSize)
		for i := range bufs {
			bufs[i] = make([]byte, 2048) // Standard MTU + headroom
		}

		n, err := d.Device.Read(bufs, sizes, defaultOffset)

		// Check closed again after read returns (Read may have been
		// unblocked by the device being closed underneath us).
		select {
		case <-d.closed:
			logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
			return
		default:
		}

		// Now try to send the result; also watch closed so a consumer
		// that went away during shutdown cannot wedge this goroutine.
		select {
		case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
		case <-d.closed:
			logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
			return
		}

		if err != nil {
			logger.Debug("MiddleDevice: pump exiting due to read error: %v", err)
			return
		}
	}
}
|
||||
|
||||
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN).
// The packet is queued on the buffered inject channel; the call blocks if
// the queue is full and returns without delivering once the device is closed.
//
// NOTE(review): the caller's slice is handed over without copying —
// confirm callers do not reuse the buffer after this call.
func (d *MiddleDevice) InjectOutbound(packet []byte) {
	select {
	case d.injectCh <- packet:
	case <-d.closed:
	}
}
|
||||
|
||||
// AddRule adds a packet filtering rule
|
||||
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
d.rules = append(d.rules, FilterRule{
|
||||
DestIP: destIP,
|
||||
Handler: handler,
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveRule removes all rules for a given destination IP
|
||||
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
newRules := make([]FilterRule, 0, len(d.rules))
|
||||
for _, rule := range d.rules {
|
||||
if rule.DestIP != destIP {
|
||||
newRules = append(newRules, rule)
|
||||
}
|
||||
}
|
||||
d.rules = newRules
|
||||
}
|
||||
|
||||
// Close stops the device
|
||||
func (d *MiddleDevice) Close() error {
|
||||
select {
|
||||
case <-d.closed:
|
||||
// Already closed
|
||||
return nil
|
||||
default:
|
||||
logger.Debug("MiddleDevice: Closing, signaling closed channel")
|
||||
close(d.closed)
|
||||
}
|
||||
logger.Debug("MiddleDevice: Closing underlying TUN device")
|
||||
err := d.Device.Close()
|
||||
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// extractDestIP extracts destination IP from packet (fast path)
|
||||
func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||
if len(packet) < 20 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
version := packet[0] >> 4
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
if len(packet) < 20 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
// Destination IP is at bytes 16-19 for IPv4
|
||||
ip := netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
||||
return ip, true
|
||||
case 6:
|
||||
if len(packet) < 40 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
// Destination IP is at bytes 24-39 for IPv6
|
||||
var ip16 [16]byte
|
||||
copy(ip16[:], packet[24:40])
|
||||
ip := netip.AddrFrom16(ip16)
|
||||
return ip, true
|
||||
}
|
||||
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
// Check if already closed first (non-blocking)
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
|
||||
return 0, os.ErrClosed
|
||||
default:
|
||||
}
|
||||
|
||||
// Now block waiting for data
|
||||
select {
|
||||
case res := <-d.readCh:
|
||||
if res.err != nil {
|
||||
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
|
||||
return 0, res.err
|
||||
}
|
||||
|
||||
// Copy packets from result to provided buffers
|
||||
count := 0
|
||||
for i := 0; i < res.n && i < len(bufs); i++ {
|
||||
// Handle offset mismatch if necessary
|
||||
// We assume the pump used defaultOffset (16)
|
||||
// If caller asks for different offset, we need to shift
|
||||
src := res.bufs[i]
|
||||
srcOffset := res.offset
|
||||
srcSize := res.sizes[i]
|
||||
|
||||
// Calculate where the packet data starts and ends in src
|
||||
pktData := src[srcOffset : srcOffset+srcSize]
|
||||
|
||||
// Ensure dest buffer is large enough
|
||||
if len(bufs[i]) < offset+len(pktData) {
|
||||
continue // Skip if buffer too small
|
||||
}
|
||||
|
||||
copy(bufs[i][offset:], pktData)
|
||||
sizes[i] = len(pktData)
|
||||
count++
|
||||
}
|
||||
n = count
|
||||
|
||||
case pkt := <-d.injectCh:
|
||||
if len(bufs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if len(bufs[0]) < offset+len(pkt) {
|
||||
return 0, nil // Buffer too small
|
||||
}
|
||||
copy(bufs[0][offset:], pkt)
|
||||
sizes[0] = len(pkt)
|
||||
n = 1
|
||||
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
|
||||
return 0, os.ErrClosed // Signal that device is closed
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
rules := d.rules
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Process packets and filter out handled ones
|
||||
writeIdx := 0
|
||||
for readIdx := 0; readIdx < n; readIdx++ {
|
||||
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
|
||||
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
// Can't parse, keep packet
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if packet matches any rule
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
// Packet was handled and should be dropped
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
// Keep packet
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return writeIdx, err
|
||||
}
|
||||
|
||||
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
||||
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
d.mutex.RLock()
|
||||
rules := d.rules
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return d.Device.Write(bufs, offset)
|
||||
}
|
||||
|
||||
// Filter packets going down
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
for _, buf := range bufs {
|
||||
if len(buf) <= offset {
|
||||
continue
|
||||
}
|
||||
|
||||
packet := buf[offset:]
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
// Can't parse, keep packet
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if packet matches any rule
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
// Packet was handled and should be dropped
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredBufs) == 0 {
|
||||
return len(bufs), nil // All packets were handled
|
||||
}
|
||||
|
||||
return d.Device.Write(filteredBufs, offset)
|
||||
}
|
||||
102
device/middle_device_test.go
Normal file
102
device/middle_device_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/fosrl/newt/util"
|
||||
)
|
||||
|
||||
func TestExtractDestIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
packet []byte
|
||||
wantIP string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 packet",
|
||||
packet: []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
|
||||
0x0a, 0x1e, 0x1e, 0x1e, // Dest IP: 10.30.30.30
|
||||
},
|
||||
wantIP: "10.30.30.30",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "Too short packet",
|
||||
packet: []byte{0x45, 0x00},
|
||||
wantIP: "",
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotIP, gotOk := extractDestIP(tt.packet)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("extractDestIP() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
return
|
||||
}
|
||||
if tt.wantOk {
|
||||
wantAddr := netip.MustParseAddr(tt.wantIP)
|
||||
if gotIP != wantAddr {
|
||||
t.Errorf("extractDestIP() ip = %v, want %v", gotIP, wantAddr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProtocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
packet []byte
|
||||
wantProto uint8
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "UDP packet",
|
||||
packet: []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, // Protocol: UDP (17) at byte 9
|
||||
0x0a, 0x1e, 0x1e, 0x1e,
|
||||
},
|
||||
wantProto: 17,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
packet: []byte{0x45, 0x00},
|
||||
wantProto: 0,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotProto, gotOk := util.GetProtocol(tt.packet)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
return
|
||||
}
|
||||
if gotProto != tt.wantProto {
|
||||
t.Errorf("GetProtocol() proto = %v, want %v", gotProto, tt.wantProto)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExtractDestIP(b *testing.B) {
|
||||
packet := []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
|
||||
0x0a, 0x1e, 0x1e, 0x1e,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
extractDestIP(packet)
|
||||
}
|
||||
}
|
||||
44
device/tun_unix.go
Normal file
44
device/tun_unix.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build !windows
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
dupTunFd, err := unix.Dup(int(tunFd))
|
||||
if err != nil {
|
||||
logger.Error("Unable to dup tun fd: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(dupTunFd, true)
|
||||
if err != nil {
|
||||
unix.Close(dupTunFd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
|
||||
device, err := tun.CreateTUNFromFile(file, mtuInt)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
// UapiOpen opens the UAPI socket file for the named WireGuard interface.
// Thin wrapper over wireguard-go's ipc.UAPIOpen for the non-Windows build.
func UapiOpen(interfaceName string) (*os.File, error) {
	return ipc.UAPIOpen(interfaceName)
}
|
||||
|
||||
// UapiListen starts listening on the UAPI socket previously opened with
// UapiOpen. Thin wrapper over wireguard-go's ipc.UAPIListen for the
// non-Windows build.
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
	return ipc.UAPIListen(interfaceName, fileUAPI)
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -11,15 +11,15 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
return nil, errors.New("CreateTUNFromFile not supported on Windows")
|
||||
}
|
||||
|
||||
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||
// On Windows, UAPIListen only takes one parameter
|
||||
return ipc.UAPIListen(interfaceName)
|
||||
}
|
||||
457
dns/dns_proxy.go
Normal file
457
dns/dns_proxy.go
Normal file
@@ -0,0 +1,457 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/device"
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
DNSPort = 53
|
||||
)
|
||||
|
||||
// DNSProxy implements a DNS proxy using gvisor netstack
|
||||
type DNSProxy struct {
|
||||
stack *stack.Stack
|
||||
ep *channel.Endpoint
|
||||
proxyIP netip.Addr
|
||||
upstreamDNS []string
|
||||
mtu int
|
||||
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
||||
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
|
||||
recordStore *DNSRecordStore // Local DNS records
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewDNSProxy creates a new DNS proxy
|
||||
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string) (*DNSProxy, error) {
|
||||
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
||||
}
|
||||
|
||||
if len(upstreamDns) == 0 {
|
||||
return nil, fmt.Errorf("at least one upstream DNS server must be specified")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
proxy := &DNSProxy{
|
||||
proxyIP: proxyIP,
|
||||
mtu: mtu,
|
||||
tunDevice: tunDevice,
|
||||
middleDevice: middleDevice,
|
||||
upstreamDNS: upstreamDns,
|
||||
recordStore: NewDNSRecordStore(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create gvisor netstack
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
|
||||
proxy.ep = channel.New(256, uint32(mtu), "")
|
||||
proxy.stack = stack.New(stackOpts)
|
||||
|
||||
// Create NIC
|
||||
if err := proxy.stack.CreateNIC(1, proxy.ep); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
// Add IP address
|
||||
// Parse the proxy IP to get the octets
|
||||
ipBytes := proxyIP.As4()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||
}
|
||||
|
||||
if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to add protocol address: %v", err)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
proxy.stack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// Start starts the DNS proxy: it installs the MiddleDevice filter rule for
// the proxy IP and launches the two worker goroutines (DNS listener and
// packet sender). Always returns nil; kept as an error return for
// interface stability.
func (p *DNSProxy) Start() error {
	// Install packet filter rule so traffic to the proxy IP is diverted
	// into handlePacket instead of the normal WireGuard path.
	p.middleDevice.AddRule(p.proxyIP, p.handlePacket)

	// Start DNS listener and response sender; Stop() waits on this group.
	p.wg.Add(2)
	go p.runDNSListener()
	go p.runPacketSender()

	logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort)
	return nil
}
|
||||
|
||||
// Stop stops the DNS proxy
|
||||
func (p *DNSProxy) Stop() {
|
||||
if p.middleDevice != nil {
|
||||
p.middleDevice.RemoveRule(p.proxyIP)
|
||||
}
|
||||
p.cancel()
|
||||
|
||||
// Close the endpoint first to unblock any pending Read() calls in runPacketSender
|
||||
if p.ep != nil {
|
||||
p.ep.Close()
|
||||
}
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
if p.stack != nil {
|
||||
p.stack.Close()
|
||||
}
|
||||
|
||||
logger.Info("DNS proxy stopped")
|
||||
}
|
||||
|
||||
// GetProxyIP returns the IP address the DNS proxy listens on, as picked
// from the utility subnet at construction time.
func (p *DNSProxy) GetProxyIP() netip.Addr {
	return p.proxyIP
}
|
||||
|
||||
// handlePacket is called by the filter for packets destined to DNS proxy IP
|
||||
func (p *DNSProxy) handlePacket(packet []byte) bool {
|
||||
if len(packet) < 20 {
|
||||
return false // Don't drop, malformed
|
||||
}
|
||||
|
||||
// Quick check for UDP port 53
|
||||
proto, ok := util.GetProtocol(packet)
|
||||
if !ok || proto != 17 { // 17 = UDP
|
||||
return false // Not UDP, don't handle
|
||||
}
|
||||
|
||||
port, ok := util.GetDestPort(packet)
|
||||
if !ok || port != DNSPort {
|
||||
return false // Not DNS port
|
||||
}
|
||||
|
||||
// Inject packet into our netstack
|
||||
version := packet[0] >> 4
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
p.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||
case 6:
|
||||
p.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||
default:
|
||||
pkb.DecRef()
|
||||
return false
|
||||
}
|
||||
|
||||
pkb.DecRef()
|
||||
return true // Drop packet from normal path
|
||||
}
|
||||
|
||||
// runDNSListener listens for DNS queries on the netstack
|
||||
func (p *DNSProxy) runDNSListener() {
|
||||
defer p.wg.Done()
|
||||
|
||||
// Create UDP listener using gonet
|
||||
// Parse the proxy IP to get the octets
|
||||
ipBytes := p.proxyIP.As4()
|
||||
laddr := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4(ipBytes),
|
||||
Port: DNSPort,
|
||||
}
|
||||
|
||||
udpConn, err := gonet.DialUDP(p.stack, laddr, nil, ipv4.ProtocolNumber)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create DNS listener: %v", err)
|
||||
return
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
logger.Debug("DNS proxy listening on netstack")
|
||||
|
||||
// Handle DNS queries
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
udpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
n, remoteAddr, err := udpConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
logger.Error("DNS read error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
query := make([]byte, n)
|
||||
copy(query, buf[:n])
|
||||
|
||||
// Handle query in background
|
||||
go p.handleDNSQuery(udpConn, query, remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream
|
||||
func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) {
|
||||
// Parse the DNS query
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(queryData); err != nil {
|
||||
logger.Error("Failed to parse DNS query: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(msg.Question) == 0 {
|
||||
logger.Debug("DNS query has no questions")
|
||||
return
|
||||
}
|
||||
|
||||
question := msg.Question[0]
|
||||
logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype])
|
||||
|
||||
// Check if we have local records for this query
|
||||
var response *dns.Msg
|
||||
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
|
||||
response = p.checkLocalRecords(msg, question)
|
||||
}
|
||||
|
||||
// If no local records, forward to upstream
|
||||
if response == nil {
|
||||
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
|
||||
response = p.forwardToUpstream(msg)
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
logger.Error("Failed to get DNS response for %s", question.Name)
|
||||
return
|
||||
}
|
||||
|
||||
// Pack and send response
|
||||
responseData, err := response.Pack()
|
||||
if err != nil {
|
||||
logger.Error("Failed to pack DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = udpConn.WriteTo(responseData, clientAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to send DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// checkLocalRecords checks if we have local records for the query
|
||||
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
|
||||
var recordType RecordType
|
||||
if question.Qtype == dns.TypeA {
|
||||
recordType = RecordTypeA
|
||||
} else if question.Qtype == dns.TypeAAAA {
|
||||
recordType = RecordTypeAAAA
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
ips := p.recordStore.GetRecords(question.Name, recordType)
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
|
||||
|
||||
// Create response message
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(query)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add answer records
|
||||
for _, ip := range ips {
|
||||
var rr dns.RR
|
||||
if question.Qtype == dns.TypeA {
|
||||
rr = &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
A: ip.To4(),
|
||||
}
|
||||
} else { // TypeAAAA
|
||||
rr = &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
AAAA: ip.To16(),
|
||||
}
|
||||
}
|
||||
response.Answer = append(response.Answer, rr)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// forwardToUpstream forwards a DNS query to upstream DNS servers
|
||||
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
|
||||
// Try primary DNS server
|
||||
response, err := p.queryUpstream(p.upstreamDNS[0], query, 2*time.Second)
|
||||
if err != nil && len(p.upstreamDNS) > 1 {
|
||||
// Try secondary DNS server
|
||||
logger.Debug("Primary DNS failed, trying secondary: %v", err)
|
||||
response, err = p.queryUpstream(p.upstreamDNS[1], query, 2*time.Second)
|
||||
if err != nil {
|
||||
logger.Error("Both DNS servers failed: %v", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
// queryUpstream sends a DNS query to upstream server using miekg/dns
|
||||
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
||||
client := &dns.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
response, _, err := client.Exchange(query, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// runPacketSender sends packets from netstack back to TUN
|
||||
func (p *DNSProxy) runPacketSender() {
|
||||
defer p.wg.Done()
|
||||
|
||||
// MessageTransportHeaderSize is the offset used by WireGuard device
|
||||
// for reading/writing packets to the TUN interface
|
||||
const offset = 16
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read packets from netstack endpoint
|
||||
pkt := p.ep.Read()
|
||||
if pkt == nil {
|
||||
// No packet available, small sleep to avoid busy loop
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract packet data as slices
|
||||
slices := pkt.AsSlices()
|
||||
if len(slices) > 0 {
|
||||
// Flatten all slices into a single packet buffer
|
||||
var totalSize int
|
||||
for _, slice := range slices {
|
||||
totalSize += len(slice)
|
||||
}
|
||||
|
||||
// Allocate buffer with offset space for WireGuard transport header
|
||||
// The first 'offset' bytes are reserved for the transport header
|
||||
buf := make([]byte, offset+totalSize)
|
||||
|
||||
// Copy packet data after the offset
|
||||
pos := offset
|
||||
for _, slice := range slices {
|
||||
copy(buf[pos:], slice)
|
||||
pos += len(slice)
|
||||
}
|
||||
|
||||
// Write packet to TUN device
|
||||
// offset=16 indicates packet data starts at position 16 in the buffer
|
||||
_, err := p.tunDevice.Write([][]byte{buf}, offset)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write DNS response to TUN: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
// AddDNSRecord adds a DNS record to the local store.
// domain should be a domain name (e.g., "example.com" or "example.com.");
// ip should be a valid IPv4 or IPv6 address. Delegates to the record store,
// which returns an error for an unparseable address.
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
	return p.recordStore.AddRecord(domain, ip)
}
|
||||
|
||||
// RemoveDNSRecord removes a DNS record from the local store.
// If ip is nil, removes all records for the domain. Delegates to the
// record store.
func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
	p.recordStore.RemoveRecord(domain, ip)
}
|
||||
|
||||
// GetDNSRecords returns all IP addresses for a domain and record type.
// Delegates to the record store, which returns a defensive copy.
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
	return p.recordStore.GetRecords(domain, recordType)
}
|
||||
|
||||
// ClearDNSRecords removes all DNS records from the local store.
func (p *DNSProxy) ClearDNSRecords() {
	p.recordStore.Clear()
}
|
||||
|
||||
func PickIPFromSubnet(subnet string) (netip.Addr, error) {
|
||||
// given a subnet in CIDR notation, pick the first usable IP
|
||||
prefix, err := netip.ParsePrefix(subnet)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid subnet: %w", err)
|
||||
}
|
||||
|
||||
// Pick the first usable IP address from the subnet
|
||||
ip := prefix.Addr().Next()
|
||||
if !ip.IsValid() {
|
||||
return netip.Addr{}, fmt.Errorf("no valid IP address found in subnet: %s", subnet)
|
||||
}
|
||||
|
||||
return ip, nil
|
||||
}
|
||||
166
dns/dns_records.go
Normal file
166
dns/dns_records.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// RecordType represents the type of DNS record
|
||||
type RecordType uint16
|
||||
|
||||
const (
|
||||
RecordTypeA RecordType = RecordType(dns.TypeA)
|
||||
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
|
||||
)
|
||||
|
||||
// DNSRecordStore manages local DNS records for A and AAAA queries
|
||||
type DNSRecordStore struct {
|
||||
mu sync.RWMutex
|
||||
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
|
||||
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
|
||||
}
|
||||
|
||||
// NewDNSRecordStore creates a new DNS record store with empty A and AAAA
// tables. The zero value is not usable (nil maps), so always construct
// through this function.
func NewDNSRecordStore() *DNSRecordStore {
	return &DNSRecordStore{
		aRecords:    make(map[string][]net.IP),
		aaaaRecords: make(map[string][]net.IP),
	}
}
|
||||
|
||||
// AddRecord adds a DNS record mapping (A or AAAA)
|
||||
// domain should be in FQDN format (e.g., "example.com.")
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
if ip.To4() != nil {
|
||||
// IPv4 address
|
||||
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
||||
} else if ip.To16() != nil {
|
||||
// IPv6 address
|
||||
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
||||
} else {
|
||||
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRecord removes a specific DNS record mapping
|
||||
// If ip is nil, removes all records for the domain
|
||||
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
if ip == nil {
|
||||
// Remove all records for this domain
|
||||
delete(s.aRecords, domain)
|
||||
delete(s.aaaaRecords, domain)
|
||||
return
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
// Remove specific IPv4 address
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
s.aRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aRecords[domain]) == 0 {
|
||||
delete(s.aRecords, domain)
|
||||
}
|
||||
}
|
||||
} else if ip.To16() != nil {
|
||||
// Remove specific IPv6 address
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
s.aaaaRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aaaaRecords[domain]) == 0 {
|
||||
delete(s.aaaaRecords, domain)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRecords returns all IP addresses for a domain and record type
|
||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
var records []net.IP
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
}
|
||||
case RecordTypeAAAA:
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// HasRecord checks if a domain has any records of the specified type
|
||||
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
_, ok := s.aRecords[domain]
|
||||
return ok
|
||||
case RecordTypeAAAA:
|
||||
_, ok := s.aaaaRecords[domain]
|
||||
return ok
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Clear removes all records from the store
|
||||
func (s *DNSRecordStore) Clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.aRecords = make(map[string][]net.IP)
|
||||
s.aaaaRecords = make(map[string][]net.IP)
|
||||
}
|
||||
|
||||
// removeIP is a helper function to remove a specific IP from a slice
|
||||
func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
|
||||
result := make([]net.IP, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
if !ip.Equal(toRemove) {
|
||||
result = append(result, ip)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
68
dns/override/dns_override_darwin.go
Normal file
68
dns/override/dns_override_darwin.go
Normal file
@@ -0,0 +1,68 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/dns"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
|
||||
// Uses scutil for DNS configuration
|
||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
if dnsProxy == nil {
|
||||
return fmt.Errorf("DNS proxy is nil")
|
||||
}
|
||||
|
||||
var err error
|
||||
configurator, err = platform.NewDarwinDNSConfigurator()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Darwin DNS configurator: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Using Darwin scutil DNS configurator")
|
||||
|
||||
// Get current DNS servers before changing
|
||||
currentDNS, err := configurator.GetCurrentDNS()
|
||||
if err != nil {
|
||||
logger.Warn("Could not get current DNS: %v", err)
|
||||
} else {
|
||||
logger.Info("Current DNS servers: %v", currentDNS)
|
||||
}
|
||||
|
||||
// Set new DNS servers to point to our proxy
|
||||
newDNS := []netip.Addr{
|
||||
dnsProxy.GetProxyIP(),
|
||||
}
|
||||
|
||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||
originalDNS, err := configurator.SetDNS(newDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Original DNS servers backed up: %v", originalDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride restores the original DNS configuration
|
||||
func RestoreDNSOverride() error {
|
||||
if configurator == nil {
|
||||
logger.Debug("No DNS configurator to restore")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Restoring original DNS configuration")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
return fmt.Errorf("failed to restore DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("DNS configuration restored successfully")
|
||||
return nil
|
||||
}
|
||||
105
dns/override/dns_override_unix.go
Normal file
105
dns/override/dns_override_unix.go
Normal file
@@ -0,0 +1,105 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/dns"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
|
||||
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
|
||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
if dnsProxy == nil {
|
||||
return fmt.Errorf("DNS proxy is nil")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
|
||||
managerType := platform.DetectDNSManager(interfaceName)
|
||||
logger.Info("Detected DNS manager: %s", managerType.String())
|
||||
|
||||
// Create configurator based on detected manager
|
||||
switch managerType {
|
||||
case platform.SystemdResolvedManager:
|
||||
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
|
||||
if err == nil {
|
||||
logger.Info("Using systemd-resolved DNS configurator")
|
||||
return setDNS(dnsProxy, configurator)
|
||||
}
|
||||
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
|
||||
|
||||
case platform.NetworkManagerManager:
|
||||
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
|
||||
if err == nil {
|
||||
logger.Info("Using NetworkManager DNS configurator")
|
||||
return setDNS(dnsProxy, configurator)
|
||||
}
|
||||
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
|
||||
|
||||
case platform.ResolvconfManager:
|
||||
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
|
||||
if err == nil {
|
||||
logger.Info("Using resolvconf DNS configurator")
|
||||
return setDNS(dnsProxy, configurator)
|
||||
}
|
||||
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
|
||||
}
|
||||
|
||||
// Fall back to direct file manipulation
|
||||
configurator, err = platform.NewFileDNSConfigurator()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file DNS configurator: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Using file-based DNS configurator")
|
||||
return setDNS(dnsProxy, configurator)
|
||||
}
|
||||
|
||||
// setDNS is a helper function to set DNS and log the results
|
||||
func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
|
||||
// Get current DNS servers before changing
|
||||
currentDNS, err := conf.GetCurrentDNS()
|
||||
if err != nil {
|
||||
logger.Warn("Could not get current DNS: %v", err)
|
||||
} else {
|
||||
logger.Info("Current DNS servers: %v", currentDNS)
|
||||
}
|
||||
|
||||
// Set new DNS servers to point to our proxy
|
||||
newDNS := []netip.Addr{
|
||||
dnsProxy.GetProxyIP(),
|
||||
}
|
||||
|
||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||
originalDNS, err := conf.SetDNS(newDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Original DNS servers backed up: %v", originalDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride restores the original DNS configuration
|
||||
func RestoreDNSOverride() error {
|
||||
if configurator == nil {
|
||||
logger.Debug("No DNS configurator to restore")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Restoring original DNS configuration")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
return fmt.Errorf("failed to restore DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("DNS configuration restored successfully")
|
||||
return nil
|
||||
}
|
||||
68
dns/override/dns_override_windows.go
Normal file
68
dns/override/dns_override_windows.go
Normal file
@@ -0,0 +1,68 @@
|
||||
//go:build windows
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/dns"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
||||
// Uses registry-based configuration (automatically extracts interface GUID)
|
||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
if dnsProxy == nil {
|
||||
return fmt.Errorf("DNS proxy is nil")
|
||||
}
|
||||
|
||||
var err error
|
||||
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Windows DNS configurator: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Using Windows registry DNS configurator for interface: %s", interfaceName)
|
||||
|
||||
// Get current DNS servers before changing
|
||||
currentDNS, err := configurator.GetCurrentDNS()
|
||||
if err != nil {
|
||||
logger.Warn("Could not get current DNS: %v", err)
|
||||
} else {
|
||||
logger.Info("Current DNS servers: %v", currentDNS)
|
||||
}
|
||||
|
||||
// Set new DNS servers to point to our proxy
|
||||
newDNS := []netip.Addr{
|
||||
dnsProxy.GetProxyIP(),
|
||||
}
|
||||
|
||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||
originalDNS, err := configurator.SetDNS(newDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Original DNS servers backed up: %v", originalDNS)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDNSOverride restores the original DNS configuration
|
||||
func RestoreDNSOverride() error {
|
||||
if configurator == nil {
|
||||
logger.Debug("No DNS configurator to restore")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Restoring original DNS configuration")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
return fmt.Errorf("failed to restore DNS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("DNS configuration restored successfully")
|
||||
return nil
|
||||
}
|
||||
268
dns/platform/darwin.go
Normal file
268
dns/platform/darwin.go
Normal file
@@ -0,0 +1,268 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
const (
	// Absolute paths of the macOS tools this configurator shells out to.
	scutilPath      = "/usr/sbin/scutil"
	dscacheutilPath = "/usr/bin/dscacheutil"

	// Dynamic-store keys used with scutil: the Olm override entry, the
	// global IPv4 state (which names the primary service), and the
	// per-service DNS key format.
	dnsStateKeyFormat    = "State:/Network/Service/Olm-%s/DNS"
	globalIPv4State      = "State:/Network/Global/IPv4"
	primaryServiceFormat = "State:/Network/Service/%s/DNS"

	// Dictionary keys written into the DNS state entry, plus the scutil
	// value-type markers for arrays ("* ") and numbers ("# ").
	keySupplementalMatchDomains         = "SupplementalMatchDomains"
	keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
	keyServerAddresses                  = "ServerAddresses"
	keyServerPort                       = "ServerPort"
	arraySymbol                         = "* "
	digitSymbol                         = "# "
)
|
||||
|
||||
// DarwinDNSConfigurator manages DNS settings on macOS using scutil.
// It remembers every dynamic-store key it creates so RestoreDNS can remove
// them again, and keeps a snapshot of the servers seen before the override.
type DarwinDNSConfigurator struct {
	createdKeys   map[string]struct{} // scutil State: keys written by this instance
	originalState *DNSState           // servers in effect before SetDNS ran
}
|
||||
|
||||
// NewDarwinDNSConfigurator creates a new macOS DNS configurator.
// The error return is always nil in the current implementation; it exists to
// match the signatures of the other platform constructors.
func NewDarwinDNSConfigurator() (*DarwinDNSConfigurator, error) {
	return &DarwinDNSConfigurator{
		createdKeys: make(map[string]struct{}),
	}, nil
}
|
||||
|
||||
// Name returns the configurator identifier; it is recorded into the saved
// DNS state by SetDNS.
func (d *DarwinDNSConfigurator) Name() string {
	return "darwin-scutil"
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (d *DarwinDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := d.GetCurrentDNS()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current DNS: %w", err)
|
||||
}
|
||||
|
||||
// Store original state
|
||||
d.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: d.Name(),
|
||||
}
|
||||
|
||||
// Set new DNS servers
|
||||
if err := d.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := d.flushDNSCache(); err != nil {
|
||||
// Non-fatal, just log
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (d *DarwinDNSConfigurator) RestoreDNS() error {
|
||||
// Remove all created keys
|
||||
for key := range d.createdKeys {
|
||||
if err := d.removeKey(key); err != nil {
|
||||
return fmt.Errorf("remove key %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := d.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (d *DarwinDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
primaryServiceKey, err := d.getPrimaryServiceKey()
|
||||
if err != nil || primaryServiceKey == "" {
|
||||
return nil, fmt.Errorf("get primary service: %w", err)
|
||||
}
|
||||
|
||||
dnsKey := fmt.Sprintf(primaryServiceFormat, primaryServiceKey)
|
||||
cmd := fmt.Sprintf("show %s\n", dnsKey)
|
||||
|
||||
output, err := d.runScutil(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("run scutil: %w", err)
|
||||
}
|
||||
|
||||
servers := d.parseServerAddresses(output)
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// applyDNSServers installs the scutil override entry that points the system
// at the DNS proxy. Note that only servers[0] is written: addDNSState takes a
// single server address, so any additional entries are ignored here.
func (d *DarwinDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
	if len(servers) == 0 {
		return fmt.Errorf("no DNS servers provided")
	}

	key := fmt.Sprintf(dnsStateKeyFormat, "Override")

	// Use SupplementalMatchDomains with empty string to match ALL domains
	// This is the key to making DNS override work on macOS
	// Setting SupplementalMatchDomainsNoSearch to 0 enables search domain behavior
	err := d.addDNSState(key, "\"\"", servers[0], 53, true)
	if err != nil {
		return fmt.Errorf("set DNS servers: %w", err)
	}

	// Remember the key so RestoreDNS can delete it later.
	d.createdKeys[key] = struct{}{}
	return nil
}
|
||||
|
||||
// addDNSState writes a DNS dictionary into the dynamic store under the given
// state key via scutil. domains is inserted verbatim into the scutil script
// (callers pass `""` to match every domain), dnsServer/port identify the
// resolver, and enableSearch toggles SupplementalMatchDomainsNoSearch
// (0 = search-domain behavior enabled).
func (d *DarwinDNSConfigurator) addDNSState(state, domains string, dnsServer netip.Addr, port int, enableSearch bool) error {
	// The NoSearch flag is encoded as a "0"/"1" numeric value for scutil.
	noSearch := "1"
	if enableSearch {
		noSearch = "0"
	}

	// Build the scutil command following NetBird's approach:
	// initialize a dictionary, add each key with its type marker, then set it.
	var commands strings.Builder
	commands.WriteString("d.init\n")
	commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomains, arraySymbol, domains))
	commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomainsNoSearch, digitSymbol, noSearch))
	commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerAddresses, arraySymbol, dnsServer.String()))
	commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerPort, digitSymbol, strconv.Itoa(port)))
	commands.WriteString(fmt.Sprintf("set %s\n", state))

	if _, err := d.runScutil(commands.String()); err != nil {
		return fmt.Errorf("applying state for domains %s, error: %w", domains, err)
	}

	logger.Info("Added DNS override with server %s:%d for domains: %s", dnsServer.String(), port, domains)
	return nil
}
|
||||
|
||||
// removeKey removes a DNS configuration key
|
||||
func (d *DarwinDNSConfigurator) removeKey(key string) error {
|
||||
cmd := fmt.Sprintf("remove %s\n", key)
|
||||
|
||||
if _, err := d.runScutil(cmd); err != nil {
|
||||
return fmt.Errorf("remove key: %w", err)
|
||||
}
|
||||
|
||||
delete(d.createdKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPrimaryServiceKey gets the primary network service key
|
||||
func (d *DarwinDNSConfigurator) getPrimaryServiceKey() (string, error) {
|
||||
cmd := fmt.Sprintf("show %s\n", globalIPv4State)
|
||||
|
||||
output, err := d.runScutil(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("run scutil: %w", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, "PrimaryService") {
|
||||
parts := strings.Split(line, ":")
|
||||
if len(parts) >= 2 {
|
||||
return strings.TrimSpace(parts[1]), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("scan output: %w", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("primary service not found")
|
||||
}
|
||||
|
||||
// parseServerAddresses parses DNS server addresses from scutil output
|
||||
func (d *DarwinDNSConfigurator) parseServerAddresses(output []byte) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
inServerArray := false
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if strings.HasPrefix(line, "ServerAddresses : <array> {") {
|
||||
inServerArray = true
|
||||
continue
|
||||
}
|
||||
|
||||
if line == "}" {
|
||||
inServerArray = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inServerArray {
|
||||
// Line format: "0 : 8.8.8.8"
|
||||
parts := strings.Split(line, " : ")
|
||||
if len(parts) >= 2 {
|
||||
if addr, err := netip.ParseAddr(parts[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// flushDNSCache flushes the system DNS cache
|
||||
func (d *DarwinDNSConfigurator) flushDNSCache() error {
|
||||
logger.Debug("Flushing dscacheutil cache")
|
||||
cmd := exec.Command(dscacheutilPath, "-flushcache")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("flush cache: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Flushing mDNSResponder cache")
|
||||
|
||||
cmd = exec.Command("killall", "-HUP", "mDNSResponder")
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Non-fatal, mDNSResponder might not be running
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runScutil executes an scutil command
|
||||
func (d *DarwinDNSConfigurator) runScutil(commands string) ([]byte, error) {
|
||||
// Wrap commands with open/quit
|
||||
wrapped := fmt.Sprintf("open\n%squit\n", commands)
|
||||
|
||||
logger.Debug("Running scutil with commands:\n%s\n", wrapped)
|
||||
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader(wrapped)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scutil command failed: %w, output: %s", err, output)
|
||||
}
|
||||
|
||||
logger.Debug("scutil output:\n%s\n", output)
|
||||
|
||||
return output, nil
|
||||
}
|
||||
158
dns/platform/detect_unix.go
Normal file
158
dns/platform/detect_unix.go
Normal file
@@ -0,0 +1,158 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// defaultResolvConfPath is the file inspected for DNS manager signatures.
const defaultResolvConfPath = "/etc/resolv.conf"

// DNSManagerType represents the type of DNS manager detected on the host.
type DNSManagerType int

const (
	// UnknownManager indicates we couldn't determine the DNS manager
	UnknownManager DNSManagerType = iota
	// SystemdResolvedManager indicates systemd-resolved is managing DNS
	SystemdResolvedManager
	// NetworkManagerManager indicates NetworkManager is managing DNS
	NetworkManagerManager
	// ResolvconfManager indicates resolvconf is managing DNS
	ResolvconfManager
	// FileManager indicates direct file management (no DNS manager)
	FileManager
)
|
||||
|
||||
// DetectDNSManagerFromFile reads /etc/resolv.conf to determine which DNS manager is in use
|
||||
// This provides a hint based on comments in the file, similar to Netbird's approach
|
||||
func DetectDNSManagerFromFile() DNSManagerType {
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return UnknownManager
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
if len(text) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If we hit a non-comment line, default to file-based
|
||||
if text[0] != '#' {
|
||||
return FileManager
|
||||
}
|
||||
|
||||
// Check for DNS manager signatures in comments
|
||||
if strings.Contains(text, "NetworkManager") {
|
||||
return NetworkManagerManager
|
||||
}
|
||||
|
||||
if strings.Contains(text, "systemd-resolved") {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
return ResolvconfManager
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return UnknownManager
|
||||
}
|
||||
|
||||
// No indicators found, assume file-based management
|
||||
return FileManager
|
||||
}
|
||||
|
||||
// String returns a human-readable name for the DNS manager type
|
||||
func (d DNSManagerType) String() string {
|
||||
switch d {
|
||||
case SystemdResolvedManager:
|
||||
return "systemd-resolved"
|
||||
case NetworkManagerManager:
|
||||
return "NetworkManager"
|
||||
case ResolvconfManager:
|
||||
return "resolvconf"
|
||||
case FileManager:
|
||||
return "file"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// DetectDNSManager combines file detection with runtime availability checks
|
||||
// to determine the best DNS configurator to use
|
||||
func DetectDNSManager(interfaceName string) DNSManagerType {
|
||||
// First check what the file suggests
|
||||
fileHint := DetectDNSManagerFromFile()
|
||||
|
||||
// Verify the hint with runtime checks
|
||||
switch fileHint {
|
||||
case SystemdResolvedManager:
|
||||
// Verify systemd-resolved is actually running
|
||||
if IsSystemdResolvedAvailable() {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
logger.Warn("dns platform: Found systemd-resolved but it is not running. Falling back to file...")
|
||||
os.Exit(0)
|
||||
return FileManager
|
||||
|
||||
case NetworkManagerManager:
|
||||
// Verify NetworkManager is actually running
|
||||
if IsNetworkManagerAvailable() {
|
||||
// Check if NetworkManager is delegating to systemd-resolved
|
||||
if !IsNetworkManagerDNSModeSupported() {
|
||||
logger.Info("NetworkManager is delegating DNS to systemd-resolved, using systemd-resolved configurator")
|
||||
if IsSystemdResolvedAvailable() {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
}
|
||||
return NetworkManagerManager
|
||||
}
|
||||
logger.Warn("dns platform: Found network manager but it is not running. Falling back to file...")
|
||||
return FileManager
|
||||
|
||||
case ResolvconfManager:
|
||||
// Verify resolvconf is available
|
||||
if IsResolvconfAvailable() {
|
||||
return ResolvconfManager
|
||||
}
|
||||
// If resolvconf is mentioned but not available, fall back to file
|
||||
return FileManager
|
||||
|
||||
case FileManager:
|
||||
// File suggests direct file management
|
||||
// But we should still check if a manager is available that wasn't mentioned
|
||||
if IsSystemdResolvedAvailable() && interfaceName != "" {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
if IsNetworkManagerAvailable() && interfaceName != "" {
|
||||
return NetworkManagerManager
|
||||
}
|
||||
if IsResolvconfAvailable() && interfaceName != "" {
|
||||
return ResolvconfManager
|
||||
}
|
||||
return FileManager
|
||||
|
||||
default:
|
||||
// Unknown - do runtime detection
|
||||
if IsSystemdResolvedAvailable() && interfaceName != "" {
|
||||
return SystemdResolvedManager
|
||||
}
|
||||
if IsNetworkManagerAvailable() && interfaceName != "" {
|
||||
return NetworkManagerManager
|
||||
}
|
||||
if IsResolvconfAvailable() && interfaceName != "" {
|
||||
return ResolvconfManager
|
||||
}
|
||||
return FileManager
|
||||
}
|
||||
}
|
||||
192
dns/platform/file.go
Normal file
192
dns/platform/file.go
Normal file
@@ -0,0 +1,192 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
	// resolvConfPath is the file rewritten to point at the DNS proxy.
	resolvConfPath = "/etc/resolv.conf"
	// resolvConfBackupPath holds the pristine copy used by RestoreDNS.
	resolvConfBackupPath = "/etc/resolv.conf.olm.backup"
	// resolvConfHeader is written at the top of the generated file so a user
	// can tell who owns it and where the original went.
	resolvConfHeader = "# Generated by Olm DNS Manager\n# Original file backed up to " + resolvConfBackupPath + "\n\n"
)
|
||||
|
||||
// FileDNSConfigurator manages DNS settings by directly modifying
// /etc/resolv.conf. It is the fallback used when no DNS manager
// (systemd-resolved, NetworkManager, resolvconf) can be driven.
type FileDNSConfigurator struct {
	originalState *DNSState // snapshot of the servers seen before SetDNS ran
}
|
||||
|
||||
// NewFileDNSConfigurator creates a new file-based DNS configurator.
// The error return is always nil in the current implementation; it mirrors
// the other platform constructors.
func NewFileDNSConfigurator() (*FileDNSConfigurator, error) {
	return &FileDNSConfigurator{}, nil
}
|
||||
|
||||
// Name returns the configurator identifier; it is recorded into the saved
// DNS state by SetDNS.
func (f *FileDNSConfigurator) Name() string {
	return "file-resolv.conf"
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (f *FileDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := f.GetCurrentDNS()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current DNS: %w", err)
|
||||
}
|
||||
|
||||
// Backup original resolv.conf if not already backed up
|
||||
if !f.isBackupExists() {
|
||||
if err := f.backupResolvConf(); err != nil {
|
||||
return nil, fmt.Errorf("backup resolv.conf: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
f.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: f.Name(),
|
||||
}
|
||||
|
||||
// Write new resolv.conf
|
||||
if err := f.writeResolvConf(servers); err != nil {
|
||||
return nil, fmt.Errorf("write resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (f *FileDNSConfigurator) RestoreDNS() error {
|
||||
if !f.isBackupExists() {
|
||||
return fmt.Errorf("no backup file exists")
|
||||
}
|
||||
|
||||
// Copy backup back to original location
|
||||
if err := copyFile(resolvConfBackupPath, resolvConfPath); err != nil {
|
||||
return fmt.Errorf("restore from backup: %w", err)
|
||||
}
|
||||
|
||||
// Remove backup file
|
||||
if err := os.Remove(resolvConfBackupPath); err != nil {
|
||||
return fmt.Errorf("remove backup file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (f *FileDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
content, err := os.ReadFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return f.parseNameservers(string(content)), nil
|
||||
}
|
||||
|
||||
// backupResolvConf creates a backup of the current resolv.conf
|
||||
func (f *FileDNSConfigurator) backupResolvConf() error {
|
||||
// Get file info for permissions
|
||||
info, err := os.Stat(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
if err := copyFile(resolvConfPath, resolvConfBackupPath); err != nil {
|
||||
return fmt.Errorf("copy file: %w", err)
|
||||
}
|
||||
|
||||
// Preserve permissions
|
||||
if err := os.Chmod(resolvConfBackupPath, info.Mode()); err != nil {
|
||||
return fmt.Errorf("chmod backup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeResolvConf writes a new resolv.conf with the specified DNS servers
|
||||
func (f *FileDNSConfigurator) writeResolvConf(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Get file info for permissions
|
||||
info, err := os.Stat(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
var content strings.Builder
|
||||
content.WriteString(resolvConfHeader)
|
||||
|
||||
// Write nameservers
|
||||
for _, server := range servers {
|
||||
content.WriteString("nameserver ")
|
||||
content.WriteString(server.String())
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(resolvConfPath, []byte(content.String()), info.Mode()); err != nil {
|
||||
return fmt.Errorf("write resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isBackupExists checks if a backup file exists
|
||||
func (f *FileDNSConfigurator) isBackupExists() bool {
|
||||
_, err := os.Stat(resolvConfBackupPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// parseNameservers extracts nameserver entries from resolv.conf content
|
||||
func (f *FileDNSConfigurator) parseNameservers(content string) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// Skip comments and empty lines
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Look for nameserver lines
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
if addr, err := netip.ParseAddr(fields[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// copyFile copies a file from src to dst
|
||||
func copyFile(src, dst string) error {
|
||||
content, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read source: %w", err)
|
||||
}
|
||||
|
||||
// Get source file permissions
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat source: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(dst, content, info.Mode()); err != nil {
|
||||
return fmt.Errorf("write destination: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
294
dns/platform/network_manager.go
Normal file
294
dns/platform/network_manager.go
Normal file
@@ -0,0 +1,294 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbus "github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
const (
	// NetworkManager D-Bus constants: bus name, object paths, and the
	// DnsManager Mode / daemon Version properties queried over D-Bus.
	networkManagerDest                       = "org.freedesktop.NetworkManager"
	networkManagerDbusObjectNode             = "/org/freedesktop/NetworkManager"
	networkManagerDbusDNSManagerInterface    = "org.freedesktop.NetworkManager.DnsManager"
	networkManagerDbusDNSManagerObjectNode   = networkManagerDbusObjectNode + "/DnsManager"
	networkManagerDbusDNSManagerModeProperty = networkManagerDbusDNSManagerInterface + ".Mode"
	networkManagerDbusVersionProperty        = "org.freedesktop.NetworkManager.Version"

	// Locations under /etc/NetworkManager for the drop-in DNS config and
	// the dispatcher script file names used by this configurator.
	networkManagerDispatcherDir  = "/etc/NetworkManager/dispatcher.d"
	networkManagerConfDir        = "/etc/NetworkManager/conf.d"
	networkManagerDNSConfFile    = "olm-dns.conf"
	networkManagerDispatcherFile = "01-olm-dns"
)
|
||||
|
||||
// NetworkManagerDNSConfigurator manages DNS settings using NetworkManager
// configuration files. This approach works with unmanaged interfaces by
// modifying NetworkManager's global DNS settings instead of per-connection
// profiles.
type NetworkManagerDNSConfigurator struct {
	ifaceName     string    // tunnel interface the override is associated with
	originalState *DNSState // servers seen in /etc/resolv.conf before SetDNS
	confPath      string    // conf.d drop-in written by applyDNSServers
	dispatchPath  string    // dispatcher script path; not referenced by the code shown here — presumably reserved, verify against callers
}
|
||||
|
||||
// NewNetworkManagerDNSConfigurator creates a new NetworkManager DNS configurator
|
||||
func NewNetworkManagerDNSConfigurator(ifaceName string) (*NetworkManagerDNSConfigurator, error) {
|
||||
if ifaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
// Check that NetworkManager conf.d directory exists
|
||||
if _, err := os.Stat(networkManagerConfDir); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("NetworkManager conf.d directory not found: %s", networkManagerConfDir)
|
||||
}
|
||||
|
||||
return &NetworkManagerDNSConfigurator{
|
||||
ifaceName: ifaceName,
|
||||
confPath: networkManagerConfDir + "/" + networkManagerDNSConfFile,
|
||||
dispatchPath: networkManagerDispatcherDir + "/" + networkManagerDispatcherFile,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator identifier; it is recorded into the saved
// DNS state by SetDNS.
func (n *NetworkManagerDNSConfigurator) Name() string {
	return "network-manager"
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (n *NetworkManagerDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := n.GetCurrentDNS()
|
||||
if err != nil {
|
||||
// If we can't get current DNS, proceed anyway
|
||||
originalServers = []netip.Addr{}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
n.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: n.Name(),
|
||||
}
|
||||
|
||||
// Apply new DNS servers
|
||||
if err := n.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (n *NetworkManagerDNSConfigurator) RestoreDNS() error {
|
||||
// Remove our configuration file
|
||||
if err := os.Remove(n.confPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove DNS config file: %w", err)
|
||||
}
|
||||
|
||||
// Reload NetworkManager to apply the change
|
||||
if err := n.reloadNetworkManager(); err != nil {
|
||||
return fmt.Errorf("reload NetworkManager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers by reading /etc/resolv.conf
|
||||
func (n *NetworkManagerDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
content, err := os.ReadFile("/etc/resolv.conf")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
var servers []netip.Addr
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
if addr, err := netip.ParseAddr(fields[1]); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// applyDNSServers applies DNS server configuration via NetworkManager config file
|
||||
func (n *NetworkManagerDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Build DNS server list
|
||||
var dnsServers []string
|
||||
for _, server := range servers {
|
||||
dnsServers = append(dnsServers, server.String())
|
||||
}
|
||||
|
||||
// Create NetworkManager configuration file that sets global DNS
|
||||
// This overrides DNS for all connections
|
||||
configContent := fmt.Sprintf(`# Generated by Olm DNS Manager - DO NOT EDIT
|
||||
# This file configures NetworkManager to use Olm's DNS proxy
|
||||
|
||||
[global-dns-domain-*]
|
||||
servers=%s
|
||||
`, strings.Join(dnsServers, ","))
|
||||
|
||||
// Write the configuration file
|
||||
if err := os.WriteFile(n.confPath, []byte(configContent), 0644); err != nil {
|
||||
return fmt.Errorf("write DNS config file: %w", err)
|
||||
}
|
||||
|
||||
// Reload NetworkManager to apply the new configuration
|
||||
if err := n.reloadNetworkManager(); err != nil {
|
||||
// Try to clean up
|
||||
os.Remove(n.confPath)
|
||||
return fmt.Errorf("reload NetworkManager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadNetworkManager tells NetworkManager to reload its configuration
|
||||
func (n *NetworkManagerDNSConfigurator) reloadNetworkManager() error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Call Reload method with flags=0 (reload everything)
|
||||
// See: https://networkmanager.dev/docs/api/latest/gdbus-org.freedesktop.NetworkManager.html#gdbus-method-org-freedesktop-NetworkManager.Reload
|
||||
err = obj.CallWithContext(ctx, networkManagerDest+".Reload", 0, uint32(0)).Store()
|
||||
if err != nil {
|
||||
return fmt.Errorf("call Reload: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsNetworkManagerAvailable checks if NetworkManager is available and responsive
|
||||
func IsNetworkManagerAvailable() bool {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Try to ping NetworkManager
|
||||
if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsNetworkManagerDNSModeSupported checks if NetworkManager's DNS mode is one we can work with
|
||||
// Some DNS modes delegate to other systems (like systemd-resolved) which we should use directly
|
||||
func IsNetworkManagerDNSModeSupported() bool {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
||||
|
||||
modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty)
|
||||
if err != nil {
|
||||
// If we can't get the mode, assume it's not supported
|
||||
return false
|
||||
}
|
||||
|
||||
mode, ok := modeVariant.Value().(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// If NetworkManager is delegating DNS to systemd-resolved, we should use
|
||||
// systemd-resolved directly for better control
|
||||
switch mode {
|
||||
case "systemd-resolved":
|
||||
// NetworkManager is delegating to systemd-resolved
|
||||
// We should use systemd-resolved configurator instead
|
||||
return false
|
||||
case "dnsmasq", "unbound":
|
||||
// NetworkManager is using a local resolver that it controls
|
||||
// We can configure DNS through NetworkManager
|
||||
return true
|
||||
case "default", "none", "":
|
||||
// NetworkManager is managing DNS directly or not at all
|
||||
// We can configure DNS through NetworkManager
|
||||
return true
|
||||
default:
|
||||
// Unknown mode, try to use it
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// GetNetworkManagerDNSMode returns the current DNS mode of NetworkManager
|
||||
func GetNetworkManagerDNSMode() (string, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
||||
|
||||
modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get DNS mode property: %w", err)
|
||||
}
|
||||
|
||||
mode, ok := modeVariant.Value().(string)
|
||||
if !ok {
|
||||
return "", errors.New("DNS mode is not a string")
|
||||
}
|
||||
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
// GetNetworkManagerVersion returns the version of NetworkManager
|
||||
func GetNetworkManagerVersion() (string, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
|
||||
|
||||
versionVariant, err := obj.GetProperty(networkManagerDbusVersionProperty)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get version property: %w", err)
|
||||
}
|
||||
|
||||
version, ok := versionVariant.Value().(string)
|
||||
if !ok {
|
||||
return "", errors.New("version is not a string")
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
192
dns/platform/resolvconf.go
Normal file
192
dns/platform/resolvconf.go
Normal file
@@ -0,0 +1,192 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
	"bytes"
	"fmt"
	"net/netip"
	"os"
	"os/exec"
	"strings"
)
|
||||
|
||||
// resolvconfCommand is the executable used to register nameserver records.
const resolvconfCommand = "resolvconf"

// ResolvconfDNSConfigurator manages DNS settings using the resolvconf
// utility, registering per-interface records that resolvconf merges into
// /etc/resolv.conf.
type ResolvconfDNSConfigurator struct {
	ifaceName     string    // network interface the DNS records are attached to
	implType      string    // "openresolv" or "resolvconf", detected at construction
	originalState *DNSState // servers captured by SetDNS, kept for restore
}
|
||||
|
||||
// NewResolvconfDNSConfigurator creates a new resolvconf DNS configurator
|
||||
func NewResolvconfDNSConfigurator(ifaceName string) (*ResolvconfDNSConfigurator, error) {
|
||||
if ifaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
// Detect resolvconf implementation type
|
||||
implType, err := detectResolvconfType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("detect resolvconf type: %w", err)
|
||||
}
|
||||
|
||||
return &ResolvconfDNSConfigurator{
|
||||
ifaceName: ifaceName,
|
||||
implType: implType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (r *ResolvconfDNSConfigurator) Name() string {
|
||||
return fmt.Sprintf("resolvconf-%s", r.implType)
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (r *ResolvconfDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := r.GetCurrentDNS()
|
||||
if err != nil {
|
||||
// If we can't get current DNS, proceed anyway
|
||||
originalServers = []netip.Addr{}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
r.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: r.Name(),
|
||||
}
|
||||
|
||||
// Apply new DNS servers
|
||||
if err := r.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (r *ResolvconfDNSConfigurator) RestoreDNS() error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch r.implType {
|
||||
case "openresolv":
|
||||
// Force delete with -f
|
||||
cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||
default:
|
||||
cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
|
||||
}
|
||||
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("delete resolvconf config: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (r *ResolvconfDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
// resolvconf doesn't provide a direct way to query per-interface DNS
|
||||
// We can try to read /etc/resolv.conf but it's merged from all sources
|
||||
content, err := exec.Command(resolvconfCommand, "-l").CombinedOutput()
|
||||
if err != nil {
|
||||
// Fall back to reading resolv.conf
|
||||
return readResolvConfServers()
|
||||
}
|
||||
|
||||
// Parse the output (format varies by implementation)
|
||||
return parseResolvconfOutput(string(content)), nil
|
||||
}
|
||||
|
||||
// applyDNSServers applies DNS server configuration via resolvconf
|
||||
func (r *ResolvconfDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Build resolv.conf content
|
||||
var content bytes.Buffer
|
||||
content.WriteString("# Generated by Olm DNS Manager\n\n")
|
||||
|
||||
for _, server := range servers {
|
||||
content.WriteString("nameserver ")
|
||||
content.WriteString(server.String())
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// Apply via resolvconf
|
||||
var cmd *exec.Cmd
|
||||
switch r.implType {
|
||||
case "openresolv":
|
||||
// OpenResolv supports exclusive mode with -x
|
||||
cmd = exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||
default:
|
||||
cmd = exec.Command(resolvconfCommand, "-a", r.ifaceName)
|
||||
}
|
||||
|
||||
cmd.Stdin = &content
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("apply resolvconf config: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectResolvconfType detects which resolvconf implementation is being used
|
||||
func detectResolvconfType() (string, error) {
|
||||
cmd := exec.Command(resolvconfCommand, "--version")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("detect resolvconf type: %w", err)
|
||||
}
|
||||
|
||||
if strings.Contains(string(out), "openresolv") {
|
||||
return "openresolv", nil
|
||||
}
|
||||
|
||||
return "resolvconf", nil
|
||||
}
|
||||
|
||||
// parseResolvconfOutput extracts DNS server addresses from "resolvconf -l"
// style output (or any resolv.conf-formatted text).
//
// Only lines whose first field is exactly "nameserver" are considered;
// comments, blank lines, and unparsable addresses are skipped silently.
func parseResolvconfOutput(output string) []netip.Addr {
	var servers []netip.Addr

	for _, line := range strings.Split(output, "\n") {
		line = strings.TrimSpace(line)

		// Skip comments and empty lines.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		fields := strings.Fields(line)
		// Require the exact keyword: the previous prefix match also
		// accepted tokens such as "nameserverfoo".
		if len(fields) < 2 || fields[0] != "nameserver" {
			continue
		}
		if addr, err := netip.ParseAddr(fields[1]); err == nil {
			servers = append(servers, addr)
		}
	}

	return servers
}
|
||||
|
||||
// readResolvConfServers reads DNS servers from /etc/resolv.conf
|
||||
func readResolvConfServers() ([]netip.Addr, error) {
|
||||
cmd := exec.Command("cat", "/etc/resolv.conf")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
return parseResolvconfOutput(string(out)), nil
|
||||
}
|
||||
|
||||
// IsResolvconfAvailable checks if resolvconf is available
|
||||
func IsResolvconfAvailable() bool {
|
||||
cmd := exec.Command(resolvconfCommand, "--version")
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
286
dns/platform/systemd.go
Normal file
286
dns/platform/systemd.go
Normal file
@@ -0,0 +1,286 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
dbus "github.com/godbus/dbus/v5"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
	// D-Bus identity and object path of the systemd-resolved service.
	systemdResolvedDest   = "org.freedesktop.resolve1"
	systemdDbusObjectNode = "/org/freedesktop/resolve1"

	// Manager-level interface and methods (operate on the whole resolver).
	systemdDbusManagerIface      = "org.freedesktop.resolve1.Manager"
	systemdDbusGetLinkMethod     = systemdDbusManagerIface + ".GetLink"
	systemdDbusFlushCachesMethod = systemdDbusManagerIface + ".FlushCaches"

	// Per-link interface and methods (operate on one network interface).
	systemdDbusLinkInterface         = "org.freedesktop.resolve1.Link"
	systemdDbusSetDNSMethod          = systemdDbusLinkInterface + ".SetDNS"
	systemdDbusSetDefaultRouteMethod = systemdDbusLinkInterface + ".SetDefaultRoute"
	systemdDbusSetDomainsMethod      = systemdDbusLinkInterface + ".SetDomains"
	systemdDbusSetDNSSECMethod       = systemdDbusLinkInterface + ".SetDNSSEC"
	systemdDbusSetDNSOverTLSMethod   = systemdDbusLinkInterface + ".SetDNSOverTLS"
	systemdDbusRevertMethod          = systemdDbusLinkInterface + ".Revert"

	// RootZone is the root DNS zone that matches all queries.
	RootZone = "."
)
|
||||
|
||||
// systemdDbusDNSInput maps to the (iay) D-Bus input of the SetDNS method.
type systemdDbusDNSInput struct {
	Family  int32  // address family (AF_INET or AF_INET6)
	Address []byte // raw address bytes (4 for IPv4, 16 for IPv6, via netip.Addr.AsSlice)
}

// systemdDbusDomainsInput maps to the (sb) D-Bus input of the SetDomains method.
type systemdDbusDomainsInput struct {
	Domain    string // DNS domain, e.g. "." for the root zone
	MatchOnly bool   // true = match-only (routing) domain, not a search domain
}

// SystemdResolvedDNSConfigurator manages DNS settings using the
// systemd-resolved D-Bus API.
type SystemdResolvedDNSConfigurator struct {
	ifaceName      string          // interface whose per-link DNS is managed
	dbusLinkObject dbus.ObjectPath // resolved's Link object path for this interface
	originalState  *DNSState       // snapshot taken by SetDNS, kept for restore
}
|
||||
|
||||
// NewSystemdResolvedDNSConfigurator creates a new systemd-resolved DNS configurator
|
||||
func NewSystemdResolvedDNSConfigurator(ifaceName string) (*SystemdResolvedDNSConfigurator, error) {
|
||||
// Get network interface
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get interface: %w", err)
|
||||
}
|
||||
|
||||
// Connect to D-Bus
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
|
||||
|
||||
// Get the link object for this interface
|
||||
var linkPath string
|
||||
if err := obj.Call(systemdDbusGetLinkMethod, 0, iface.Index).Store(&linkPath); err != nil {
|
||||
return nil, fmt.Errorf("get link: %w", err)
|
||||
}
|
||||
|
||||
return &SystemdResolvedDNSConfigurator{
|
||||
ifaceName: ifaceName,
|
||||
dbusLinkObject: dbus.ObjectPath(linkPath),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (s *SystemdResolvedDNSConfigurator) Name() string {
|
||||
return "systemd-resolved"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (s *SystemdResolvedDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := s.GetCurrentDNS()
|
||||
if err != nil {
|
||||
// If we can't get current DNS, proceed anyway
|
||||
originalServers = []netip.Addr{}
|
||||
}
|
||||
|
||||
// Store original state
|
||||
s.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: s.Name(),
|
||||
}
|
||||
|
||||
// Apply new DNS servers
|
||||
if err := s.applyDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("apply DNS servers: %w", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (s *SystemdResolvedDNSConfigurator) RestoreDNS() error {
|
||||
// Call Revert method to restore systemd-resolved defaults
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := obj.CallWithContext(ctx, systemdDbusRevertMethod, 0).Store(); err != nil {
|
||||
return fmt.Errorf("revert DNS settings: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache after reverting
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers.
//
// systemd-resolved's D-Bus API doesn't offer a simple query for the current
// per-link DNS servers — that would require parsing `resolvectl status`
// output or reading /run/systemd/resolve/ files. This is therefore a
// placeholder that always returns an empty list, which callers (SetDNS)
// treat as "nothing to remember".
func (s *SystemdResolvedDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
	return []netip.Addr{}, nil
}
|
||||
|
||||
// applyDNSServers applies DNS server configuration via systemd-resolved
|
||||
func (s *SystemdResolvedDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
// Convert servers to systemd-resolved format
|
||||
var dnsInputs []systemdDbusDNSInput
|
||||
for _, server := range servers {
|
||||
family := unix.AF_INET
|
||||
if server.Is6() {
|
||||
family = unix.AF_INET6
|
||||
}
|
||||
|
||||
dnsInputs = append(dnsInputs, systemdDbusDNSInput{
|
||||
Family: int32(family),
|
||||
Address: server.AsSlice(),
|
||||
})
|
||||
}
|
||||
|
||||
// Connect to D-Bus
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Call SetDNS method to set the DNS servers
|
||||
if err := obj.CallWithContext(ctx, systemdDbusSetDNSMethod, 0, dnsInputs).Store(); err != nil {
|
||||
return fmt.Errorf("set DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Set this interface as the default route for DNS
|
||||
// This ensures all DNS queries prefer this interface
|
||||
if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethod, true); err != nil {
|
||||
return fmt.Errorf("set default route: %w", err)
|
||||
}
|
||||
|
||||
// Set the root zone "." as a match-only domain
|
||||
// This captures ALL DNS queries and routes them through this interface
|
||||
domainsInput := []systemdDbusDomainsInput{
|
||||
{
|
||||
Domain: RootZone,
|
||||
MatchOnly: true,
|
||||
},
|
||||
}
|
||||
if err := s.callLinkMethod(systemdDbusSetDomainsMethod, domainsInput); err != nil {
|
||||
return fmt.Errorf("set domains: %w", err)
|
||||
}
|
||||
|
||||
// Disable DNSSEC - we don't support it and it may be enabled by default
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSSECMethod, "no"); err != nil {
|
||||
// Log warning but don't fail - this is optional
|
||||
fmt.Printf("warning: failed to disable DNSSEC: %v\n", err)
|
||||
}
|
||||
|
||||
// Disable DNSOverTLS - we don't support it and it may be enabled by default
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethod, "no"); err != nil {
|
||||
// Log warning but don't fail - this is optional
|
||||
fmt.Printf("warning: failed to disable DNSOverTLS: %v\n", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache to ensure new settings take effect immediately
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callLinkMethod is a helper to call methods on the link object
|
||||
func (s *SystemdResolvedDNSConfigurator) callLinkMethod(method string, value any) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if value != nil {
|
||||
if err := obj.CallWithContext(ctx, method, 0, value).Store(); err != nil {
|
||||
return fmt.Errorf("call %s: %w", method, err)
|
||||
}
|
||||
} else {
|
||||
if err := obj.CallWithContext(ctx, method, 0).Store(); err != nil {
|
||||
return fmt.Errorf("call %s: %w", method, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// flushDNSCache flushes the systemd-resolved DNS cache
|
||||
func (s *SystemdResolvedDNSConfigurator) flushDNSCache() error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to system bus: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, 0).Store(); err != nil {
|
||||
return fmt.Errorf("flush caches: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsSystemdResolvedAvailable checks if systemd-resolved is available and responsive
|
||||
func IsSystemdResolvedAvailable() bool {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Try to ping systemd-resolved
|
||||
if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
41
dns/platform/types.go
Normal file
41
dns/platform/types.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package dns
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// DNSConfigurator provides an interface for managing system DNS settings
// across different platforms and implementations.
type DNSConfigurator interface {
	// SetDNS overrides the system DNS servers with the specified ones.
	// It returns the original DNS servers that were replaced.
	SetDNS(servers []netip.Addr) ([]netip.Addr, error)

	// RestoreDNS restores the original DNS servers.
	RestoreDNS() error

	// GetCurrentDNS returns the currently configured DNS servers.
	GetCurrentDNS() ([]netip.Addr, error)

	// Name returns the name of this configurator implementation.
	Name() string
}

// DNSConfig contains the configuration for a DNS override.
type DNSConfig struct {
	// Servers is the list of DNS servers to use.
	Servers []netip.Addr

	// SearchDomains is an optional list of search domains.
	SearchDomains []string
}

// DNSState represents the saved state of a DNS configuration, captured
// before an override so it can be restored afterwards.
type DNSState struct {
	// OriginalServers are the DNS servers before the override.
	OriginalServers []netip.Addr

	// OriginalSearchDomains are the search domains before the override.
	OriginalSearchDomains []string

	// ConfiguratorName is the name of the configurator that saved this state.
	ConfiguratorName string
}
|
||||
343
dns/platform/windows.go
Normal file
343
dns/platform/windows.go
Normal file
@@ -0,0 +1,343 @@
|
||||
//go:build windows
|
||||
|
||||
package dns
|
||||
|
||||
import (
	"errors"
	"fmt"
	"io"
	"net"
	"net/netip"
	"strings"
	"syscall"
	"unsafe"

	"golang.org/x/sys/windows"
	"golang.org/x/sys/windows/registry"
)
|
||||
|
||||
var (
	// dnsapi.dll exposes the resolver-cache control entry points.
	dnsapi                  = syscall.NewLazyDLL("dnsapi.dll")
	dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
)

const (
	// Registry path holding per-interface TCP/IP settings, keyed by adapter GUID.
	interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
	// Static DNS servers (the value this configurator writes); empty reverts to DHCP.
	interfaceConfigNameServer = "NameServer"
	// DNS servers handed out by DHCP (read as a fallback).
	interfaceConfigDhcpNameServer = "DhcpNameServer"
)
|
||||
|
||||
// WindowsDNSConfigurator manages DNS settings on Windows by editing the
// interface's TCP/IP parameters in the registry.
type WindowsDNSConfigurator struct {
	guid          string    // adapter GUID identifying the interface's registry key
	originalState *DNSState // snapshot taken by SetDNS, required by RestoreDNS
}
|
||||
|
||||
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
|
||||
// Accepts an interface name and extracts the GUID internally
|
||||
func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) {
|
||||
if interfaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
guid, err := getInterfaceGUIDString(interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
|
||||
}
|
||||
|
||||
return &WindowsDNSConfigurator{
|
||||
guid: guid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newWindowsDNSConfiguratorFromGUID creates a configurator from a GUID string
|
||||
// This is an internal function for use by DetectBestConfigurator
|
||||
func newWindowsDNSConfiguratorFromGUID(guid string) (*WindowsDNSConfigurator, error) {
|
||||
if guid == "" {
|
||||
return nil, fmt.Errorf("interface GUID is required")
|
||||
}
|
||||
|
||||
return &WindowsDNSConfigurator{
|
||||
guid: guid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the configurator name
|
||||
func (w *WindowsDNSConfigurator) Name() string {
|
||||
return "windows-registry"
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS servers and returns the original servers
|
||||
func (w *WindowsDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
|
||||
// Get current DNS settings before overriding
|
||||
originalServers, err := w.GetCurrentDNS()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current DNS: %w", err)
|
||||
}
|
||||
|
||||
// Store original state
|
||||
w.originalState = &DNSState{
|
||||
OriginalServers: originalServers,
|
||||
ConfiguratorName: w.Name(),
|
||||
}
|
||||
|
||||
// Set new DNS servers
|
||||
if err := w.setDNSServers(servers); err != nil {
|
||||
return nil, fmt.Errorf("set DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := w.flushDNSCache(); err != nil {
|
||||
// Non-fatal, just log
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return originalServers, nil
|
||||
}
|
||||
|
||||
// RestoreDNS restores the original DNS configuration
|
||||
func (w *WindowsDNSConfigurator) RestoreDNS() error {
|
||||
if w.originalState == nil {
|
||||
return fmt.Errorf("no original state to restore")
|
||||
}
|
||||
|
||||
// Clear the static DNS setting
|
||||
if err := w.clearDNSServers(); err != nil {
|
||||
return fmt.Errorf("clear DNS servers: %w", err)
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
if err := w.flushDNSCache(); err != nil {
|
||||
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDNS returns the currently configured DNS servers
|
||||
func (w *WindowsDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
|
||||
regKey, err := w.getInterfaceRegistryKey(registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closeKey(regKey)
|
||||
|
||||
// Try to get static DNS first
|
||||
nameServer, _, err := regKey.GetStringValue(interfaceConfigNameServer)
|
||||
if err == nil && nameServer != "" {
|
||||
return w.parseServerList(nameServer), nil
|
||||
}
|
||||
|
||||
// Fall back to DHCP DNS
|
||||
dhcpNameServer, _, err := regKey.GetStringValue(interfaceConfigDhcpNameServer)
|
||||
if err == nil && dhcpNameServer != "" {
|
||||
return w.parseServerList(dhcpNameServer), nil
|
||||
}
|
||||
|
||||
return []netip.Addr{}, nil
|
||||
}
|
||||
|
||||
// setDNSServers sets the DNS servers in the registry
|
||||
func (w *WindowsDNSConfigurator) setDNSServers(servers []netip.Addr) error {
|
||||
if len(servers) == 0 {
|
||||
return fmt.Errorf("no DNS servers provided")
|
||||
}
|
||||
|
||||
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closeKey(regKey)
|
||||
|
||||
// Build comma-separated or space-separated list of servers
|
||||
var serverList string
|
||||
for i, server := range servers {
|
||||
if i > 0 {
|
||||
serverList += ","
|
||||
}
|
||||
serverList += server.String()
|
||||
}
|
||||
|
||||
if err := regKey.SetStringValue(interfaceConfigNameServer, serverList); err != nil {
|
||||
return fmt.Errorf("set NameServer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearDNSServers clears the static DNS server setting
|
||||
func (w *WindowsDNSConfigurator) clearDNSServers() error {
|
||||
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer closeKey(regKey)
|
||||
|
||||
// Set empty string to revert to DHCP
|
||||
if err := regKey.SetStringValue(interfaceConfigNameServer, ""); err != nil {
|
||||
return fmt.Errorf("clear NameServer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInterfaceRegistryKey opens the registry key for the network interface
|
||||
func (w *WindowsDNSConfigurator) getInterfaceRegistryKey(access uint32) (registry.Key, error) {
|
||||
regKeyPath := interfaceConfigPath + `\` + w.guid
|
||||
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, access)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
|
||||
}
|
||||
|
||||
return regKey, nil
|
||||
}
|
||||
|
||||
// parseServerList parses a comma or space-separated list of DNS servers
|
||||
func (w *WindowsDNSConfigurator) parseServerList(serverList string) []netip.Addr {
|
||||
var servers []netip.Addr
|
||||
|
||||
// Split by comma or space
|
||||
parts := splitByDelimiters(serverList, []rune{',', ' '})
|
||||
|
||||
for _, part := range parts {
|
||||
if addr, err := netip.ParseAddr(part); err == nil {
|
||||
servers = append(servers, addr)
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// flushDNSCache flushes the Windows DNS resolver cache via the
// DnsFlushResolverCache export of dnsapi.dll.
//
// On a recovered panic the deferred handler logs a warning and the function
// returns its zero value (nil error), so a missing export degrades to a
// no-op rather than crashing.
func (w *WindowsDNSConfigurator) flushDNSCache() error {
	// dnsFlushResolverCacheFn.Call() may panic if the func is not found
	defer func() {
		if rec := recover(); rec != nil {
			fmt.Printf("warning: DnsFlushResolverCache panicked: %v\n", rec)
		}
	}()

	ret, _, err := dnsFlushResolverCacheFn.Call()
	if ret == 0 {
		// Proc.Call always returns a non-nil error value; Errno(0)
		// ("operation completed successfully") carries no information,
		// so only wrap real errno values.
		if err != nil && !errors.Is(err, syscall.Errno(0)) {
			return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
		}
		return fmt.Errorf("DnsFlushResolverCache failed")
	}

	return nil
}
|
||||
|
||||
// splitByDelimiters splits s on any of the given delimiter runes, dropping
// empty segments (so runs of delimiters act as a single separator).
func splitByDelimiters(s string, delimiters []rune) []string {
	isDelim := func(r rune) bool {
		for _, d := range delimiters {
			if r == d {
				return true
			}
		}
		return false
	}

	var parts []string
	start := -1 // byte offset of the current token, -1 while between tokens
	for i, r := range s {
		switch {
		case isDelim(r):
			if start >= 0 {
				parts = append(parts, s[start:i])
				start = -1
			}
		case start < 0:
			start = i
		}
	}
	if start >= 0 {
		parts = append(parts, s[start:])
	}

	return parts
}
|
||||
|
||||
// closeKey closes an open registry key, logging (but not propagating) any
// close error.
func closeKey(closer io.Closer) {
	err := closer.Close()
	if err == nil {
		return
	}
	fmt.Printf("warning: failed to close registry key: %v\n", err)
}
|
||||
|
||||
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
|
||||
// This is required for registry-based DNS configuration on Windows
|
||||
func getInterfaceGUIDString(interfaceName string) (string, error) {
|
||||
if interfaceName == "" {
|
||||
return "", fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
|
||||
}
|
||||
|
||||
luid, err := indexToLUID(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to convert index to LUID: %w", err)
|
||||
}
|
||||
|
||||
// Convert LUID to GUID using Windows API
|
||||
guid, err := luidToGUID(luid)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to convert LUID to GUID: %w", err)
|
||||
}
|
||||
|
||||
return guid, nil
|
||||
}
|
||||
|
||||
// indexToLUID converts a Windows interface index to a LUID
|
||||
func indexToLUID(index uint32) (uint64, error) {
|
||||
var luid uint64
|
||||
|
||||
// Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function
|
||||
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid")
|
||||
|
||||
// Call the Windows API
|
||||
ret, _, err := convertInterfaceIndexToLuid.Call(
|
||||
uintptr(index),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err)
|
||||
}
|
||||
|
||||
return luid, nil
|
||||
}
|
||||
|
||||
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
|
||||
// using the Windows ConvertInterface* APIs
|
||||
func luidToGUID(luid uint64) (string, error) {
|
||||
var guid windows.GUID
|
||||
|
||||
// Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function
|
||||
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid")
|
||||
|
||||
// Call the Windows API
|
||||
// NET_LUID is a 64-bit value on Windows
|
||||
ret, _, err := convertLuidToGuid.Call(
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
uintptr(unsafe.Pointer(&guid)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err)
|
||||
}
|
||||
|
||||
// Format the GUID as a string with curly braces
|
||||
guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}",
|
||||
guid.Data1, guid.Data2, guid.Data3,
|
||||
guid.Data4[0], guid.Data4[1], guid.Data4[2], guid.Data4[3],
|
||||
guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7])
|
||||
|
||||
return guidStr, nil
|
||||
}
|
||||
279
get-olm.sh
Normal file
279
get-olm.sh
Normal file
@@ -0,0 +1,279 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Get Olm - Cross-platform installation script
|
||||
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/olm/refs/heads/main/get-olm.sh | bash
|
||||
|
||||
set -e
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# GitHub repository info
|
||||
REPO="fosrl/olm"
|
||||
GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
|
||||
|
||||
# Function to print colored output
# Print an informational message prefixed with a green [INFO] tag.
print_status() {
    echo -e "${GREEN}[INFO]${NC} $1"
}
|
||||
|
||||
# Print a warning message prefixed with a yellow [WARN] tag.
print_warning() {
    echo -e "${YELLOW}[WARN]${NC} $1"
}
|
||||
|
||||
# Print an error message prefixed with a red [ERROR] tag.
print_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}
|
||||
|
||||
# Function to get latest version from GitHub API.
# Prints the latest release version (leading "v" stripped) on stdout;
# all diagnostics go to stderr. Exits the script on any failure.
get_latest_version() {
    # Declare locals separately from assignment so a failing command
    # substitution is not masked by `local`'s own exit status (SC2155).
    local latest_info version

    if command -v curl >/dev/null 2>&1; then
        latest_info=$(curl -fsSL "$GITHUB_API_URL" 2>/dev/null)
    elif command -v wget >/dev/null 2>&1; then
        latest_info=$(wget -qO- "$GITHUB_API_URL" 2>/dev/null)
    else
        print_error "Neither curl nor wget is available. Please install one of them." >&2
        exit 1
    fi

    if [ -z "$latest_info" ]; then
        print_error "Failed to fetch latest version information" >&2
        exit 1
    fi

    # Extract the tag_name from the JSON response (works without jq).
    version=$(echo "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/')

    if [ -z "$version" ]; then
        print_error "Could not parse version from GitHub API response" >&2
        exit 1
    fi

    # Remove 'v' prefix if present (pure-shell equivalent of sed 's/^v//').
    version=${version#v}

    echo "$version"
}
|
||||
|
||||
# Echo "<os>_<arch>" for the current machine and exit on unsupported
# platforms. OS: linux/darwin/windows/freebsd. Arch: amd64, arm64,
# arm32/arm32v6 (Linux only), riscv64 (Linux only).
detect_platform() {
    local os arch machine

    case "$(uname -s)" in
        Linux*)               os="linux" ;;
        Darwin*)              os="darwin" ;;
        MINGW*|MSYS*|CYGWIN*) os="windows" ;;
        FreeBSD*)             os="freebsd" ;;
        *)
            print_error "Unsupported operating system: $(uname -s)"
            exit 1
            ;;
    esac

    machine="$(uname -m)"
    case "$machine" in
        x86_64|amd64)  arch="amd64" ;;
        arm64|aarch64) arch="arm64" ;;
        armv7l|armv6l)
            if [ "$os" = "linux" ]; then
                # Linux ships two distinct 32-bit ARM builds.
                if [ "$machine" = "armv6l" ]; then
                    arch="arm32v6"
                else
                    arch="arm32"
                fi
            else
                arch="arm64" # Default for non-Linux ARM
            fi
            ;;
        riscv64)
            if [ "$os" != "linux" ]; then
                print_error "RISC-V architecture only supported on Linux"
                exit 1
            fi
            arch="riscv64"
            ;;
        *)
            print_error "Unsupported architecture: $machine"
            exit 1
            ;;
    esac

    echo "${os}_${arch}"
}
|
||||
|
||||
# Echo the directory olm should be installed into. Windows installs go to
# ~/bin; Unix-like systems prefer system-wide directories (so the binary is
# reachable via sudo), falling back to ~/.local/bin.
get_install_dir() {
    local platform="$1"

    if [[ "$platform" == *"windows"* ]]; then
        echo "$HOME/bin"
        return
    fi

    # Preference order: /usr/local/bin, /usr/bin, then the user directory.
    if [ -d "/usr/local/bin" ]; then
        echo "/usr/local/bin"
    elif [ -d "/usr/bin" ]; then
        echo "/usr/bin"
    else
        echo "$HOME/.local/bin"
    fi
}
|
||||
|
||||
# Return 0 (true) when installing into the given directory requires sudo,
# i.e. it is a system directory the current user cannot write to; return 1
# otherwise.
need_sudo() {
    local target="$1"

    case "$target" in
        /usr/local/bin|/usr/bin)
            [ ! -w "$target" ] && return 0
            ;;
    esac
    return 1
}
|
||||
|
||||
# Download the platform-specific binary from BASE_URL and install it as
# ${install_dir}/olm (olm.exe on Windows), elevating with sudo when the
# target directory is not writable. Warns when a user-local install
# directory is missing from $PATH.
install_olm() {
    local platform="$1"
    local install_dir="$2"
    local exe_suffix=""

    # Windows binaries carry an .exe suffix.
    if [[ "$platform" == *"windows"* ]]; then
        exe_suffix=".exe"
    fi

    local asset="olm_${platform}${exe_suffix}"
    local url="${BASE_URL}/${asset}"
    local tmp="/tmp/olm${exe_suffix}"
    local dest="${install_dir}/olm${exe_suffix}"

    print_status "Downloading olm from ${url}"

    if command -v curl >/dev/null 2>&1; then
        curl -fsSL "$url" -o "$tmp"
    elif command -v wget >/dev/null 2>&1; then
        wget -q "$url" -O "$tmp"
    else
        print_error "Neither curl nor wget is available. Please install one of them."
        exit 1
    fi

    # Decide whether installation needs elevated privileges.
    local sudo_cmd=""
    if need_sudo "$install_dir"; then
        print_status "Administrator privileges required for system-wide installation"
        if ! command -v sudo >/dev/null 2>&1; then
            print_error "sudo is required for system-wide installation but not available"
            exit 1
        fi
        sudo_cmd="sudo"
    fi

    # Create the target directory and move the binary into place.
    if [ -n "$sudo_cmd" ]; then
        $sudo_cmd mkdir -p "$install_dir"
        $sudo_cmd mv "$tmp" "$dest"
        $sudo_cmd chmod +x "$dest"
    else
        mkdir -p "$install_dir"
        mv "$tmp" "$dest"
        chmod +x "$dest"
    fi

    print_status "olm installed to ${dest}"

    # System directories are assumed to already be on PATH; only warn for
    # user-local locations.
    if [[ "$install_dir" != "/usr/local/bin" && "$install_dir" != "/usr/bin" ]]; then
        if ! echo "$PATH" | grep -q "$install_dir"; then
            print_warning "Install directory ${install_dir} is not in your PATH."
            print_warning "Add it to your PATH by adding this line to your shell profile:"
            print_warning "  export PATH=\"${install_dir}:\$PATH\""
        fi
    fi
}
|
||||
|
||||
# Confirm the installed binary exists and is executable, printing its
# version. Reads the global PLATFORM to decide on the Windows .exe suffix.
verify_installation() {
    local install_dir="$1"
    local suffix=""

    case "$PLATFORM" in
        *windows*) suffix=".exe" ;;
    esac

    local bin="${install_dir}/olm${suffix}"

    if [ ! -f "$bin" ] || [ ! -x "$bin" ]; then
        print_error "Installation failed. Binary not found or not executable."
        return 1
    fi

    print_status "Installation successful!"
    print_status "olm version: $("$bin" --version 2>/dev/null || echo "unknown")"
    return 0
}
|
||||
|
||||
# Drive the installation end-to-end: resolve the latest release, detect the
# platform, pick an install directory, install, then verify.
main() {
    print_status "Installing latest version of olm..."

    print_status "Fetching latest version from GitHub..."
    VERSION=$(get_latest_version)
    print_status "Latest version: v${VERSION}"

    # All download URLs hang off the release matching VERSION.
    BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"

    PLATFORM=$(detect_platform)
    print_status "Detected platform: ${PLATFORM}"

    INSTALL_DIR=$(get_install_dir "$PLATFORM")
    print_status "Install directory: ${INSTALL_DIR}"

    # Tell the user when the install is going system-wide.
    case "$INSTALL_DIR" in
        /usr/local/bin|/usr/bin)
            print_status "Installing system-wide for sudo access"
            ;;
    esac

    install_olm "$PLATFORM" "$INSTALL_DIR"

    if ! verify_installation "$INSTALL_DIR"; then
        exit 1
    fi

    print_status "olm is ready to use!"
    case "$INSTALL_DIR" in
        /usr/local/bin|/usr/bin)
            print_status "olm is installed system-wide and accessible via sudo"
            ;;
    esac
    if [[ "$PLATFORM" == *"windows"* ]]; then
        print_status "Run 'olm --help' to get started"
    else
        print_status "Run 'olm --help' or 'sudo olm --help' to get started"
    fi
}
|
||||
|
||||
# Run main function
|
||||
main "$@"
|
||||
27
go.mod
27
go.mod
@@ -3,20 +3,31 @@ module github.com/fosrl/olm
|
||||
go 1.25
|
||||
|
||||
require (
|
||||
github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a
|
||||
github.com/Microsoft/go-winio v0.6.2
|
||||
github.com/fosrl/newt v0.0.0
|
||||
github.com/godbus/dbus/v5 v5.2.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/miekg/dns v1.1.68
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.41.0
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
||||
golang.org/x/sys v0.35.0
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
golang.org/x/net v0.42.0 // indirect
|
||||
golang.org/x/crypto v0.44.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect
|
||||
)
|
||||
|
||||
replace github.com/fosrl/newt => ../newt
|
||||
|
||||
36
go.sum
36
go.sum
@@ -1,34 +1,46 @@
|
||||
github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a h1:bUGN4piHlcqgfdRLrwqiLZZxgcitzBzNDQS1+CHSmJI=
|
||||
github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a/go.mod h1:PbiPYp1hbL07awrmbqTSTz7lTenieTHN6cIkUVCGD3I=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
|
||||
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
|
||||
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
|
||||
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c=
|
||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
// ConnectionRequest defines the structure for an incoming connection request
// received on the /connect endpoint. handleConnect requires all three fields
// to be non-empty.
type ConnectionRequest struct {
	ID       string `json:"id"`
	Secret   string `json:"secret"`
	Endpoint string `json:"endpoint"`
}
||||
|
||||
// PeerStatus represents the status of a single peer connection, keyed by
// site ID in HTTPServer.peerStatuses and reported by the /status endpoint.
type PeerStatus struct {
	SiteID    int           `json:"siteId"`
	Connected bool          `json:"connected"`
	RTT       time.Duration `json:"rtt"`                // last recorded round-trip time
	LastSeen  time.Time     `json:"lastSeen"`           // set to time.Now() on every UpdatePeerStatus call
	Endpoint  string        `json:"endpoint,omitempty"` // current peer endpoint, if known
	IsRelay   bool          `json:"isRelay"`            // whether the connection goes through a relay
}
|
||||
|
||||
// StatusResponse is the JSON body returned by the /status endpoint.
type StatusResponse struct {
	Status       string              `json:"status"` // "connected" or "disconnected"
	Connected    bool                `json:"connected"`
	TunnelIP     string              `json:"tunnelIP,omitempty"`
	Version      string              `json:"version,omitempty"`
	PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` // per-site peer status, keyed by site ID
}
|
||||
|
||||
// HTTPServer represents the local HTTP control server and its state.
// Mutable status fields are guarded by statusMu.
type HTTPServer struct {
	addr           string                 // listen address; set once at construction
	server         *http.Server           // underlying HTTP server; created in Start
	connectionChan chan ConnectionRequest // delivers /connect requests to the main goroutine (buffered, capacity 1)
	statusMu       sync.RWMutex           // guards the fields below
	peerStatuses   map[int]*PeerStatus    // per-site peer status, keyed by site ID
	connectedAt    time.Time              // when the current connection was established
	isConnected    bool                   // overall connection state
	tunnelIP       string                 // tunnel IP reported by /status
	version        string                 // olm version reported by /status
}
|
||||
|
||||
// NewHTTPServer creates a new HTTP server
|
||||
func NewHTTPServer(addr string) *HTTPServer {
|
||||
s := &HTTPServer{
|
||||
addr: addr,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start registers the /connect and /status handlers and begins serving on
// s.addr in a background goroutine. It returns immediately; listen/serve
// errors are logged asynchronously rather than returned.
func (s *HTTPServer) Start() error {
	mux := http.NewServeMux()
	mux.HandleFunc("/connect", s.handleConnect)
	mux.HandleFunc("/status", s.handleStatus)

	s.server = &http.Server{
		Addr:    s.addr,
		Handler: mux,
	}

	logger.Info("Starting HTTP server on %s", s.addr)
	go func() {
		// ErrServerClosed is the expected result of Stop; anything else is a
		// real failure worth logging.
		if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			logger.Error("HTTP server error: %v", err)
		}
	}()

	return nil
}
|
||||
|
||||
// Stop immediately closes the HTTP server and all active connections.
// (Close, unlike Shutdown, does not drain in-flight requests.)
func (s *HTTPServer) Stop() error {
	logger.Info("Stopping HTTP server")
	return s.server.Close()
}
|
||||
|
||||
// GetConnectionChannel returns the receive side of the channel on which
// validated /connect requests are delivered.
func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest {
	return s.connectionChan
}
|
||||
|
||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||
func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.Connected = connected
|
||||
status.RTT = rtt
|
||||
status.LastSeen = time.Now()
|
||||
status.Endpoint = endpoint
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
// SetConnectionStatus sets the overall connection status
|
||||
func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
s.isConnected = isConnected
|
||||
|
||||
if isConnected {
|
||||
s.connectedAt = time.Now()
|
||||
} else {
|
||||
// Clear peer statuses when disconnected
|
||||
s.peerStatuses = make(map[int]*PeerStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// SetTunnelIP sets the tunnel IP address reported by the /status endpoint.
func (s *HTTPServer) SetTunnelIP(tunnelIP string) {
	s.statusMu.Lock()
	defer s.statusMu.Unlock()
	s.tunnelIP = tunnelIP
}
|
||||
|
||||
// SetVersion sets the olm version string reported by the /status endpoint.
func (s *HTTPServer) SetVersion(version string) {
	s.statusMu.Lock()
	defer s.statusMu.Unlock()
	s.version = version
}
|
||||
|
||||
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||
func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
|
||||
status, exists := s.peerStatuses[siteID]
|
||||
if !exists {
|
||||
status = &PeerStatus{
|
||||
SiteID: siteID,
|
||||
}
|
||||
s.peerStatuses[siteID] = status
|
||||
}
|
||||
|
||||
status.Endpoint = endpoint
|
||||
status.IsRelay = isRelay
|
||||
}
|
||||
|
||||
// handleConnect handles the /connect endpoint
|
||||
func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req ConnectionRequest
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
|
||||
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send the request to the main goroutine
|
||||
s.connectionChan <- req
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "connection request accepted",
|
||||
})
|
||||
}
|
||||
|
||||
// handleStatus handles the /status endpoint
|
||||
func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
s.statusMu.RLock()
|
||||
defer s.statusMu.RUnlock()
|
||||
|
||||
resp := StatusResponse{
|
||||
Connected: s.isConnected,
|
||||
TunnelIP: s.tunnelIP,
|
||||
Version: s.version,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
}
|
||||
|
||||
if s.isConnected {
|
||||
resp.Status = "connected"
|
||||
} else {
|
||||
resp.Status = "disconnected"
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
126
namespace.sh
Normal file
126
namespace.sh
Normal file
@@ -0,0 +1,126 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Configuration
NS_NAME="isolated_ns"        # Name of the namespace
VETH_HOST="veth_host"        # Interface name on host side
VETH_NS="veth_ns"            # Interface name inside namespace
HOST_IP="192.168.15.1"       # Gateway IP for the namespace (host side)
NS_IP="192.168.15.2"         # IP address for the namespace
SUBNET_CIDR="24"             # Subnet mask
DNS_SERVER="8.8.8.8"         # DNS to use inside namespace

# Detect the main physical interface (gateway to internet).
# Use `print $5; exit` rather than `printf $5`: printf would treat the
# interface name as a printf FORMAT string (breaking on names containing
# '%'), and without `exit` awk would also process any additional output
# lines from `ip route get`.
PHY_IFACE=$(ip route get 8.8.8.8 | awk '{print $5; exit}')
|
||||
|
||||
# Abort unless running as root; everything below manipulates network
# namespaces, interfaces, and iptables rules.
check_root() {
    [ "$EUID" -eq 0 ] && return
    echo "Error: This script must be run as root."
    exit 1
}
|
||||
|
||||
# Create the isolated namespace, wire it to the host with a veth pair, and
# NAT its traffic out the physical interface. Fails if the namespace
# already exists.
setup_ns() {
    echo "Bringing up namespace '$NS_NAME'..."

    if ip netns list | grep -q "$NS_NAME"; then
        echo "Namespace $NS_NAME already exists. Run 'down' first."
        exit 1
    fi

    # Namespace plus a veth pair: one end stays on the host, the peer moves
    # into the namespace.
    ip netns add "$NS_NAME"
    ip link add "$VETH_HOST" type veth peer name "$VETH_NS"
    ip link set "$VETH_NS" netns "$NS_NAME"

    # Host-side addressing.
    ip addr add "${HOST_IP}/${SUBNET_CIDR}" dev "$VETH_HOST"
    ip link set "$VETH_HOST" up

    # Namespace-side addressing; loopback must be up (crucial for many apps).
    ip netns exec "$NS_NAME" ip addr add "${NS_IP}/${SUBNET_CIDR}" dev "$VETH_NS"
    ip netns exec "$NS_NAME" ip link set "$VETH_NS" up
    ip netns exec "$NS_NAME" ip link set lo up

    # Default route out of the namespace points at the host end.
    ip netns exec "$NS_NAME" ip route add default via "$HOST_IP"

    # Host must forward and masquerade namespace traffic. `-C` checks for
    # each rule first so repeated runs don't stack duplicates.
    echo 1 > /proc/sys/net/ipv4/ip_forward
    iptables -t nat -C POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null || \
        iptables -t nat -A POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE
    iptables -C FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null || \
        iptables -A FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT
    iptables -C FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null || \
        iptables -A FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT

    # DNS: ip netns picks up /etc/netns/<name>/resolv.conf automatically.
    mkdir -p "/etc/netns/$NS_NAME"
    echo "nameserver $DNS_SERVER" > "/etc/netns/$NS_NAME/resolv.conf"

    echo "Namespace $NS_NAME is UP."
    echo "To enter shell: sudo ip netns exec $NS_NAME bash"
}
|
||||
|
||||
# Undo setup_ns: delete the namespace, any leftover host-side veth, the
# iptables rules, and the per-namespace DNS config.
teardown_ns() {
    echo "Tearing down namespace '$NS_NAME'..."

    # Deleting the namespace destroys its veth end; the host end usually
    # disappears with the peer.
    if ip netns list | grep -q "$NS_NAME"; then
        ip netns del "$NS_NAME"
    else
        echo "Namespace $NS_NAME does not exist."
    fi

    # Remove the host-side interface if it survived.
    if ip link show "$VETH_HOST" > /dev/null 2>&1; then
        ip link delete "$VETH_HOST"
    fi

    # Delete exactly the rules setup_ns added; ignore "no such rule" errors.
    iptables -t nat -D POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null
    iptables -D FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null
    iptables -D FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null

    # Remove the per-namespace DNS configuration.
    rm -rf "/etc/netns/$NS_NAME"

    echo "Namespace $NS_NAME is DOWN."
}
|
||||
|
||||
# Ping an external address from inside the namespace to verify end-to-end
# connectivity (route, forwarding, and NAT).
test_connectivity() {
    echo "Testing connectivity inside $NS_NAME..."
    ip netns exec "$NS_NAME" ping -c 3 8.8.8.8
}
|
||||
|
||||
# Main execution logic
|
||||
check_root
|
||||
|
||||
case "$1" in
|
||||
up)
|
||||
setup_ns
|
||||
;;
|
||||
down)
|
||||
teardown_ns
|
||||
;;
|
||||
test)
|
||||
test_connectivity
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {up|down|test}"
|
||||
exit 1
|
||||
esac
|
||||
88
olm.iss
Normal file
88
olm.iss
Normal file
@@ -0,0 +1,88 @@
|
||||
; Script generated by the Inno Setup Script Wizard.
; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES!

#define MyAppName "olm"
#define MyAppVersion "1.0.0"
#define MyAppPublisher "Fossorial Inc."
#define MyAppURL "https://pangolin.net"
#define MyAppExeName "olm.exe"

[Setup]
; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications.
; (To generate a new GUID, click Tools | Generate GUID inside the IDE.)
AppId={{44A24E4C-B616-476F-ADE7-8D56B930959E}
AppName={#MyAppName}
AppVersion={#MyAppVersion}
;AppVerName={#MyAppName} {#MyAppVersion}
AppPublisher={#MyAppPublisher}
AppPublisherURL={#MyAppURL}
AppSupportURL={#MyAppURL}
AppUpdatesURL={#MyAppURL}
DefaultDirName={autopf}\{#MyAppName}
UninstallDisplayIcon={app}\{#MyAppExeName}
; "ArchitecturesAllowed=x64compatible" specifies that Setup cannot run
; on anything but x64 and Windows 11 on Arm.
ArchitecturesAllowed=x64compatible
; "ArchitecturesInstallIn64BitMode=x64compatible" requests that the
; install be done in "64-bit mode" on x64 or Windows 11 on Arm,
; meaning it should use the native 64-bit Program Files directory and
; the 64-bit view of the registry.
ArchitecturesInstallIn64BitMode=x64compatible
DefaultGroupName={#MyAppName}
DisableProgramGroupPage=yes
; Uncomment the following line to run in non administrative install mode (install for current user only).
;PrivilegesRequired=lowest
OutputBaseFilename=mysetup
SolidCompression=yes
WizardStyle=modern
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
RestartIfNeededByRun=no
ChangesEnvironment=true

[Languages]
Name: "english"; MessagesFile: "compiler:Default.isl"

[Files]
; The 'DestName' parameter ensures that 'olm_windows_amd64.exe' is installed as 'olm.exe'
Source: "C:\Users\Administrator\Downloads\olm_windows_amd64.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}"; Flags: ignoreversion
Source: "C:\Users\Administrator\Downloads\wintun.dll"; DestDir: "{app}"; Flags: ignoreversion
; NOTE: Don't use "Flags: ignoreversion" on any shared system files

[Icons]
Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"

[Registry]
; Add the application's installation directory to the system PATH environment variable.
; HKLM (HKEY_LOCAL_MACHINE) is used for system-wide changes.
; The 'Path' variable is located under 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'.
; ValueType: expandsz allows for environment variables (like %ProgramFiles%) in the path.
; ValueData: "{olddata};{app}" appends the current application directory to the existing PATH.
; Flags: uninsdeletevalue ensures the entry is removed upon uninstallation.
; Check: NeedsAddPath (defined in [Code]) prevents appending a duplicate entry
; when {app} is already present on the PATH.
Root: HKLM; Subkey: "SYSTEM\CurrentControlSet\Control\Session Manager\Environment"; \
    ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \
    Flags: uninsdeletevalue; Check: NeedsAddPath(ExpandConstant('{app}'))

[Code]
{ NeedsAddPath returns True when Path does not already appear, case-insensitively
  and as a full semicolon-delimited entry, in the system PATH registry value. }
function NeedsAddPath(Path: string): boolean;
var
  OrigPath: string;
begin
  if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
    'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
    'Path', OrigPath)
  then begin
    // Path variable doesn't exist at all, so we definitely need to add it.
    Result := True;
    exit;
  end;

  // Perform a case-insensitive check to see if the path is already present.
  // We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
  if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
    Result := False
  else
    Result := True;
end;
|
||||
1070
olm/olm.go
Normal file
1070
olm/olm.go
Normal file
File diff suppressed because it is too large
Load Diff
66
olm/types.go
Normal file
66
olm/types.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/olm/peers"
|
||||
)
|
||||
|
||||
// WgData is the WireGuard configuration payload delivered to the client,
// containing the per-site peer configurations and local addressing.
type WgData struct {
	Sites         []peers.SiteConfig `json:"sites"`         // one entry per remote site/peer
	TunnelIP      string             `json:"tunnelIP"`      // address assigned to the local tunnel interface
	UtilitySubnet string             `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
}
|
||||
|
||||
// GlobalConfig holds process-wide settings for the olm client: logging,
// the local API server, identity strings, and lifecycle callbacks.
type GlobalConfig struct {
	// Logging
	LogLevel string

	// HTTP server
	EnableAPI  bool   // enable the local API server
	HTTPAddr   string // listen address for the HTTP API
	SocketPath string // unix socket path for the API
	Version    string // client version string
	Agent      string // agent identifier string

	// Callbacks
	OnRegistered func() // invoked on registration
	OnConnected  func() // invoked on connection
	OnTerminated func() // invoked on termination
	OnAuthError  func(statusCode int, message string) // Called when auth fails (401/403)
	OnExit       func()                               // Called when exit is requested via API
}
|
||||
|
||||
// TunnelConfig holds per-tunnel settings: how to reach and authenticate
// with the server, local network parameters, and advanced toggles.
type TunnelConfig struct {
	// Connection settings
	Endpoint  string // server endpoint to connect to
	ID        string // tunnel/client identifier
	Secret    string // secret used for authentication
	UserToken string // user authentication token

	// Network settings
	MTU           int      // tunnel interface MTU
	DNS           string   // DNS address for the tunnel
	UpstreamDNS   []string // upstream DNS resolvers
	InterfaceName string   // name of the tunnel interface

	// Advanced
	Holepunch     bool   // enable NAT hole punching
	TlsClientCert string // optional TLS client certificate

	// Parsed values (not in JSON)
	PingIntervalDuration time.Duration // parsed ping interval
	PingTimeoutDuration  time.Duration // parsed ping timeout

	OrgID string // organization identifier
	// DoNotCreateNewClient bool

	// Pre-opened file descriptors — presumably supplied by an embedding
	// host application (e.g. mobile); TODO confirm against callers.
	FileDescriptorTun  uint32
	FileDescriptorUAPI uint32

	EnableUAPI bool // expose the WireGuard UAPI

	OverrideDNS bool // override system DNS with the tunnel DNS

	DisableRelay bool // disable relay fallback
}
|
||||
98
olm/util.go
Normal file
98
olm/util.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
)
|
||||
|
||||
// formatEndpoint normalizes a "host:port" endpoint string. Its main job is
// repairing the malformed form where a bare IPv6 literal precedes the port
// without brackets (e.g. "2001:db8::1:51820" -> "[2001:db8::1]:51820").
// Endpoints that already parse, or that don't match that shape, are
// returned unchanged; an empty input yields an empty string.
func formatEndpoint(endpoint string) string {
	if endpoint == "" {
		return ""
	}

	// Already a well-formed host:port (covers "[::1]:8080" and "1.2.3.4:8080").
	if _, _, err := net.SplitHostPort(endpoint); err == nil {
		return endpoint
	}

	// Treat everything before the final colon as a candidate IPv6 literal.
	// idx must be > 0: there has to be a colon and a non-empty host part.
	idx := strings.LastIndex(endpoint, ":")
	if idx <= 0 {
		return endpoint
	}
	host, port := endpoint[:idx], endpoint[idx+1:]
	if ip := net.ParseIP(host); ip != nil && ip.To4() == nil {
		// Genuine IPv6 literal: re-emit with the required brackets.
		return fmt.Sprintf("[%s]:%s", host, port)
	}

	// Not the malformed IPv6 case; leave it alone.
	return endpoint
}
|
||||
|
||||
func sendPing(olm *websocket.Client) error {
|
||||
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
||||
"timestamp": time.Now().Unix(),
|
||||
"userToken": olm.GetConfig().UserToken,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send ping message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Debug("Sent ping message")
|
||||
return nil
|
||||
}
|
||||
|
||||
// keepSendingPing sends one ping immediately and then a ping every minute,
// until the package-level stopPing channel is signalled. Send failures are
// logged but do not stop the loop. Intended to run in its own goroutine.
func keepSendingPing(olm *websocket.Client) {
	// Send ping immediately on startup
	if err := sendPing(olm); err != nil {
		logger.Error("Failed to send initial ping: %v", err)
	} else {
		logger.Info("Sent initial ping message")
	}

	// Set up ticker for one minute intervals
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-stopPing:
			// Shutdown requested by the owner of stopPing.
			logger.Info("Stopping ping messages")
			return
		case <-ticker.C:
			if err := sendPing(olm); err != nil {
				logger.Error("Failed to send periodic ping: %v", err)
			}
		}
	}
}
|
||||
|
||||
// GetNetworkSettingsJSON returns the current network settings serialized as
// JSON, delegating to the newt network package.
func GetNetworkSettingsJSON() (string, error) {
	return network.GetJSON()
}
|
||||
|
||||
// GetNetworkSettingsIncrementor returns the network package's incrementor
// value — presumably a change counter callers can poll to detect settings
// updates; verify semantics in the newt network package.
func GetNetworkSettingsIncrementor() int {
	return network.GetIncrementor()
}
|
||||
|
||||
// stringSlicesEqual reports whether a and b contain the same elements in
// the same order. Two nil/empty slices compare equal.
func stringSlicesEqual(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i, v := range a {
		if v != b[i] {
			return false
		}
	}
	return true
}
|
||||
@@ -1,324 +0,0 @@
|
||||
package peermonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/olm/wgtester"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
// PeerMonitorCallback is the function type for connection status change
// callbacks. It receives the affected site ID, whether the peer is now
// considered connected, and the measured round-trip time.
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
|
||||
|
||||
// WireGuardConfig holds the WireGuard configuration for a peer, used when
// reconfiguring the device during relay failover.
type WireGuardConfig struct {
	SiteID       int    // site identifier for this peer
	PublicKey    string // peer's WireGuard public key
	ServerIP     string // peer's server IP (used as the /32 allowed_ip on failover)
	Endpoint     string // current endpoint for the peer
	PrimaryRelay string // The primary relay endpoint
}
|
||||
|
||||
// PeerMonitor handles monitoring the connection status to multiple
// WireGuard peers. All maps and settings are guarded by mutex.
type PeerMonitor struct {
	monitors          map[int]*wgtester.Client // connectivity testers, keyed by site ID
	configs           map[int]*WireGuardConfig // per-site WireGuard config, keyed by site ID
	callback          PeerMonitorCallback      // invoked on every status change
	mutex             sync.Mutex               // guards all fields
	running           bool                     // true between Start and Stop/Close
	interval          time.Duration            // probe interval applied to each client
	timeout           time.Duration            // response timeout applied to each client
	maxAttempts       int                      // attempts used by TestConnection
	privateKey        string                   // local WireGuard private key (used in HandleFailover)
	wsClient          *websocket.Client        // server connection for relay messages (may be nil)
	device            *device.Device           // WireGuard device reconfigured on failover
	handleRelaySwitch bool                     // Whether to handle relay switching
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback.
// Defaults: 1s check interval, 2.5s timeout, 8 max attempts. wsClient may
// be nil, in which case relay messages are not sent on disconnect.
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
	return &PeerMonitor{
		monitors:          make(map[int]*wgtester.Client),
		configs:           make(map[int]*WireGuardConfig),
		callback:          callback,
		interval:          1 * time.Second, // Default check interval
		timeout:           2500 * time.Millisecond,
		maxAttempts:       8,
		privateKey:        privateKey,
		wsClient:          wsClient,
		device:            device,
		handleRelaySwitch: handleRelaySwitch,
	}
}
|
||||
|
||||
// SetInterval changes how frequently peers are checked. The new interval
// is pushed to every existing monitor client; peers added later pick it up
// from pm.interval in AddPeer.
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	pm.interval = interval

	// Update interval for all existing monitors
	for _, client := range pm.monitors {
		client.SetPacketInterval(interval)
	}
}
|
||||
|
||||
// SetTimeout changes the timeout for waiting for responses. The new value
// is pushed to every existing monitor client; peers added later pick it up
// from pm.timeout in AddPeer.
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	pm.timeout = timeout

	// Update timeout for all existing monitors
	for _, client := range pm.monitors {
		client.SetTimeout(timeout)
	}
}
|
||||
|
||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection.
// The new value is pushed to every existing monitor client; peers added
// later pick it up from pm.maxAttempts in AddPeer.
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	pm.maxAttempts = attempts

	// Update max attempts for all existing monitors
	for _, client := range pm.monitors {
		client.SetMaxAttempts(attempts)
	}
}
|
||||
|
||||
// AddPeer adds a new peer to monitor at the given endpoint. If the peer is
// already monitored it is first torn down and replaced. When the monitor
// is running, monitoring of the new peer starts immediately; otherwise it
// starts on the next call to Start.
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	// Check if we're already monitoring this peer
	if _, exists := pm.monitors[siteID]; exists {
		// Update the endpoint instead of creating a new monitor
		pm.removePeerUnlocked(siteID)
	}

	client, err := wgtester.NewClient(endpoint)
	if err != nil {
		return err
	}

	// Configure the client with our settings
	client.SetPacketInterval(pm.interval)
	client.SetTimeout(pm.timeout)
	client.SetMaxAttempts(pm.maxAttempts)

	// Store the client and config
	pm.monitors[siteID] = client
	pm.configs[siteID] = wgConfig

	// If monitor is already running, start monitoring this peer
	if pm.running {
		siteIDCopy := siteID // Create a copy for the closure
		err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
			pm.handleConnectionStatusChange(siteIDCopy, status)
		})
	}

	return err
}
|
||||
|
||||
// removePeerUnlocked stops monitoring a peer and removes it from the monitor.
// This function assumes the mutex is already held by the caller.
// Missing siteIDs are silently ignored.
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
	client, exists := pm.monitors[siteID]
	if !exists {
		return
	}

	// Stop the probe loop and release the client's resources before
	// dropping both map entries.
	client.StopMonitor()
	client.Close()
	delete(pm.monitors, siteID)
	delete(pm.configs, siteID)
}
|
||||
|
||||
// RemovePeer stops monitoring a peer and removes it from the monitor.
// It is the lock-acquiring wrapper around removePeerUnlocked.
func (pm *PeerMonitor) RemovePeer(siteID int) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	pm.removePeerUnlocked(siteID)
}
|
||||
|
||||
// Start begins monitoring all peers currently registered with the monitor.
// Calling Start while already running is a no-op. Per-peer start failures
// are logged and skipped so one bad peer does not block the rest.
func (pm *PeerMonitor) Start() {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	if pm.running {
		return // Already running
	}

	pm.running = true

	// Start monitoring all peers
	for siteID, client := range pm.monitors {
		siteIDCopy := siteID // Create a copy for the closure
		err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
			pm.handleConnectionStatusChange(siteIDCopy, status)
		})
		if err != nil {
			logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
			continue
		}
		logger.Info("Started monitoring peer %d\n", siteID)
	}
}
|
||||
|
||||
// handleConnectionStatusChange is called when a peer's connection status
// changes. It forwards the status to the user callback and, on disconnect,
// asks the server (via sendRelay) to switch the peer to the relay.
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) {
	// Call the user-provided callback first
	if pm.callback != nil {
		pm.callback(siteID, status.Connected, status.RTT)
	}

	// If disconnected, handle failover
	if !status.Connected {
		// Send relay message to the server
		if pm.wsClient != nil {
			pm.sendRelay(siteID)
		}
	}
}
|
||||
|
||||
// HandleFailover handles failover to the relay server when a peer is
// disconnected: it rewrites the peer's WireGuard configuration so its
// endpoint points at relayEndpoint (port 21820) with a 1s keepalive.
// Unknown siteIDs are silently ignored.
func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
	pm.mutex.Lock()
	config, exists := pm.configs[siteID]
	pm.mutex.Unlock()

	if !exists {
		return
	}

	// Configure WireGuard to use the relay
	wgConfig := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s:21820
persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, relayEndpoint)

	err := pm.device.IpcSet(wgConfig)
	if err != nil {
		logger.Error("Failed to configure WireGuard device: %v\n", err)
		return
	}

	logger.Info("Adjusted peer %d to point to relay!\n", siteID)
}
|
||||
|
||||
// sendRelay sends a relay message to the server
|
||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||
if !pm.handleRelaySwitch {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent relay message")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops monitoring all peers but keeps the clients and configs so
// monitoring can be resumed with Start. Calling Stop when not running is
// a no-op.
func (pm *PeerMonitor) Stop() {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	if !pm.running {
		return
	}

	pm.running = false

	// Stop all monitors
	for _, client := range pm.monitors {
		client.StopMonitor()
	}
}
|
||||
|
||||
// Close stops monitoring and cleans up resources: every client is stopped,
// closed, and removed. Unlike Stop, the monitor cannot be resumed without
// re-adding peers.
func (pm *PeerMonitor) Close() {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	// Stop and close all clients
	for siteID, client := range pm.monitors {
		client.StopMonitor()
		client.Close()
		delete(pm.monitors, siteID)
	}

	pm.running = false
}
|
||||
|
||||
// TestPeer tests connectivity to a specific peer, returning whether it is
// reachable and the measured RTT. The overall deadline is timeout *
// maxAttempts. Errors if the siteID is unknown.
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
	pm.mutex.Lock()
	client, exists := pm.monitors[siteID]
	pm.mutex.Unlock()

	if !exists {
		return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
	}

	// Budget enough time for every attempt; the lock is not held during
	// the (potentially slow) network test.
	ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
	defer cancel()

	connected, rtt := client.TestConnection(ctx)
	return connected, rtt, nil
}
|
||||
|
||||
// TestAllPeers tests connectivity to all peers sequentially, returning a
// map from site ID to (connected, RTT). A snapshot of the clients is taken
// under the lock so the slow network tests run unlocked.
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
	Connected bool
	RTT       time.Duration
} {
	pm.mutex.Lock()
	peers := make(map[int]*wgtester.Client, len(pm.monitors))
	for siteID, client := range pm.monitors {
		peers[siteID] = client
	}
	pm.mutex.Unlock()

	results := make(map[int]struct {
		Connected bool
		RTT       time.Duration
	})
	for siteID, client := range peers {
		// Same per-peer deadline as TestPeer: timeout * maxAttempts.
		ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
		connected, rtt := client.TestConnection(ctx)
		cancel()

		results[siteID] = struct {
			Connected bool
			RTT       time.Duration
		}{
			Connected: connected,
			RTT:       rtt,
		}
	}

	return results
}
|
||||
882
peers/manager.go
Normal file
882
peers/manager.go
Normal file
@@ -0,0 +1,882 @@
|
||||
package peers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/api"
|
||||
olmDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/dns"
|
||||
"github.com/fosrl/olm/peers/monitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// PeerManagerConfig contains the configuration for creating a PeerManager.
type PeerManagerConfig struct {
	Device        *device.Device // WireGuard device to configure peers on
	DNSProxy      *dns.DNSProxy  // DNS proxy receiving alias records
	InterfaceName string         // tunnel interface name (used for routes)
	PrivateKey    wgtypes.Key    // local WireGuard private key
	// For peer monitoring
	MiddleDev  *olmDevice.MiddleDevice
	LocalIP    string
	SharedBind *bind.SharedBind
	// WSClient is optional - if nil, relay messages won't be sent
	WSClient  *websocket.Client
	APIServer *api.API
}
|
||||
|
||||
// PeerManager tracks per-site WireGuard peers, their routes, DNS aliases,
// and allowed-IP ownership (ensuring each allowed IP is configured on
// exactly one peer in WireGuard at a time). Guarded by mu.
type PeerManager struct {
	mu            sync.RWMutex
	device        *device.Device
	peers         map[int]SiteConfig   // current config per site ID
	peerMonitor   *monitor.PeerMonitor // internal connectivity monitor
	dnsProxy      *dns.DNSProxy
	interfaceName string
	privateKey    wgtypes.Key
	// allowedIPOwners tracks which peer currently "owns" each allowed IP in WireGuard
	// key is the CIDR string, value is the siteId that has it configured in WG
	allowedIPOwners map[string]int
	// allowedIPClaims tracks all peers that claim each allowed IP
	// key is the CIDR string, value is a set of siteIds that want this IP
	allowedIPClaims map[string]map[int]bool
	APIServer       *api.API
}
|
||||
|
||||
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
// built from the monitoring-related fields of config.
func NewPeerManager(config PeerManagerConfig) *PeerManager {
	pm := &PeerManager{
		device:          config.Device,
		peers:           make(map[int]SiteConfig),
		dnsProxy:        config.DNSProxy,
		interfaceName:   config.InterfaceName,
		privateKey:      config.PrivateKey,
		allowedIPOwners: make(map[string]int),
		allowedIPClaims: make(map[string]map[int]bool),
		APIServer:       config.APIServer,
	}

	// Create the peer monitor
	pm.peerMonitor = monitor.NewPeerMonitor(
		config.WSClient,
		config.MiddleDev,
		config.LocalIP,
		config.SharedBind,
		config.APIServer,
	)

	return pm
}
|
||||
|
||||
// GetPeer returns the stored configuration for siteId and whether the peer
// exists.
func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
	pm.mu.RLock()
	defer pm.mu.RUnlock()
	peer, ok := pm.peers[siteId]
	return peer, ok
}
|
||||
|
||||
func (pm *PeerManager) GetAllPeers() []SiteConfig {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
peers := make([]SiteConfig, 0, len(pm.peers))
|
||||
for _, peer := range pm.peers {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// AddPeer registers a new site peer: it builds the allowed-IP list from the
// site's remote subnets and aliases, claims those IPs (only IPs this peer
// wins ownership of are configured in WireGuard), installs routes and DNS
// alias records, starts connectivity monitoring, and kicks off an async
// initial holepunch test. Route/monitoring failures are logged, not fatal.
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
	pm.mu.Lock()
	defer pm.mu.Unlock()

	// build the allowed IPs list from the remote subnets and aliases and add them to the peer
	allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
	allowedIPs = append(allowedIPs, siteConfig.RemoteSubnets...)
	for _, alias := range siteConfig.Aliases {
		allowedIPs = append(allowedIPs, alias.AliasAddress+"/32")
	}
	siteConfig.AllowedIps = allowedIPs

	// Register claims for all allowed IPs and determine which ones this peer will own
	ownedIPs := make([]string, 0, len(allowedIPs))
	for _, ip := range allowedIPs {
		pm.claimAllowedIP(siteConfig.SiteId, ip)
		// Check if this peer became the owner
		if pm.allowedIPOwners[ip] == siteConfig.SiteId {
			ownedIPs = append(ownedIPs, ip)
		}
	}

	// Create a config with only the owned IPs for WireGuard
	wgConfig := siteConfig
	wgConfig.AllowedIps = ownedIPs

	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
		return err
	}

	if err := network.AddRouteForServerIP(siteConfig.ServerIP, pm.interfaceName); err != nil {
		logger.Error("Failed to add route for server IP: %v", err)
	}
	if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
		logger.Error("Failed to add routes for remote subnets: %v", err)
	}
	// Publish a DNS record for each alias with a parseable address.
	for _, alias := range siteConfig.Aliases {
		address := net.ParseIP(alias.AliasAddress)
		if address == nil {
			continue
		}
		pm.dnsProxy.AddDNSRecord(alias.Alias, address)
	}

	// Monitor probes the server IP (without its CIDR suffix) on ServerPort+1.
	monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
	monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port

	err := pm.peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, siteConfig.Endpoint) // always use the real site endpoint for hole punch monitoring
	if err != nil {
		logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
	} else {
		logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
	}

	pm.peers[siteConfig.SiteId] = siteConfig

	// Perform rapid initial holepunch test (outside of lock to avoid blocking)
	// This quickly determines if holepunch is viable and triggers relay if not
	go pm.performRapidInitialTest(siteConfig.SiteId, siteConfig.Endpoint)

	return nil
}
|
||||
|
||||
// RemovePeer tears down the peer for siteId: removes it from WireGuard,
// deletes routes no longer needed by any other peer, removes its DNS alias
// records, releases its allowed-IP claims (promoting other claimants to
// ownership and reconfiguring them), and stops monitoring. Errors if the
// site is unknown.
func (pm *PeerManager) RemovePeer(siteId int) error {
	pm.mu.Lock()
	defer pm.mu.Unlock()

	peer, exists := pm.peers[siteId]
	if !exists {
		return fmt.Errorf("peer with site ID %d not found", siteId)
	}

	if err := RemovePeer(pm.device, siteId, peer.PublicKey); err != nil {
		return err
	}

	if err := network.RemoveRouteForServerIP(peer.ServerIP, pm.interfaceName); err != nil {
		logger.Error("Failed to remove route for server IP: %v", err)
	}

	// Only remove routes for subnets that aren't used by other peers
	for _, subnet := range peer.RemoteSubnets {
		subnetStillInUse := false
		for otherSiteId, otherPeer := range pm.peers {
			if otherSiteId == siteId {
				continue // Skip the peer being removed
			}
			for _, otherSubnet := range otherPeer.RemoteSubnets {
				if otherSubnet == subnet {
					subnetStillInUse = true
					break
				}
			}
			if subnetStillInUse {
				break
			}
		}
		if !subnetStillInUse {
			if err := network.RemoveRoutes([]string{subnet}); err != nil {
				logger.Error("Failed to remove route for remote subnet %s: %v", subnet, err)
			}
		}
	}

	// For aliases: drop the DNS record of each parseable alias address.
	for _, alias := range peer.Aliases {
		address := net.ParseIP(alias.AliasAddress)
		if address == nil {
			continue
		}
		pm.dnsProxy.RemoveDNSRecord(alias.Alias, address)
	}

	// Release all IP claims and promote other peers as needed
	// Collect promotions first to avoid modifying while iterating
	type promotion struct {
		newOwner int
		cidr     string
	}
	var promotions []promotion

	for _, ip := range peer.AllowedIps {
		newOwner, promoted := pm.releaseAllowedIP(siteId, ip)
		if promoted && newOwner >= 0 {
			promotions = append(promotions, promotion{newOwner: newOwner, cidr: ip})
		}
	}

	// Apply promotions - update WireGuard config for newly promoted peers
	// Group by peer to avoid multiple config updates
	promotedPeers := make(map[int]bool)
	for _, p := range promotions {
		promotedPeers[p.newOwner] = true
		logger.Info("Promoted peer %d to owner of IP %s", p.newOwner, p.cidr)
	}

	for promotedPeerId := range promotedPeers {
		if promotedPeer, exists := pm.peers[promotedPeerId]; exists {
			// Build the list of IPs this peer now owns
			ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
			wgConfig := promotedPeer
			wgConfig.AllowedIps = ownedIPs
			if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
				logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
			}
		}
	}

	// Stop monitoring this peer
	pm.peerMonitor.RemovePeer(siteId)
	logger.Info("Stopped monitoring for site %d", siteId)

	// NOTE(review): assumes APIServer is non-nil — verify construction sites.
	pm.APIServer.RemovePeerStatus(siteId)

	delete(pm.peers, siteId)
	return nil
}
|
||||
|
||||
// UpdatePeer applies a new configuration to an existing peer: it rebuilds
// the allowed-IP list, reconciles allowed-IP claims/ownership (promoting
// other peers for released IPs), reconfigures WireGuard, adds/removes
// routes and DNS alias records for the subnet and alias deltas, and
// updates the monitoring endpoints. Errors if the site is unknown.
func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
	pm.mu.Lock()
	defer pm.mu.Unlock()

	oldPeer, exists := pm.peers[siteConfig.SiteId]
	if !exists {
		return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId)
	}

	// If public key changed, remove old peer first
	if siteConfig.PublicKey != oldPeer.PublicKey {
		if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil {
			logger.Error("Failed to remove old peer: %v", err)
		}
	}

	// Build the new allowed IPs list
	newAllowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
	newAllowedIPs = append(newAllowedIPs, siteConfig.RemoteSubnets...)
	for _, alias := range siteConfig.Aliases {
		newAllowedIPs = append(newAllowedIPs, alias.AliasAddress+"/32")
	}
	siteConfig.AllowedIps = newAllowedIPs

	// Handle allowed IP claim changes
	oldAllowedIPs := make(map[string]bool)
	for _, ip := range oldPeer.AllowedIps {
		oldAllowedIPs[ip] = true
	}
	newAllowedIPsSet := make(map[string]bool)
	for _, ip := range newAllowedIPs {
		newAllowedIPsSet[ip] = true
	}

	// Track peers that need WireGuard config updates due to promotions
	peersToUpdate := make(map[int]bool)

	// Release claims for removed IPs and handle promotions
	for ip := range oldAllowedIPs {
		if !newAllowedIPsSet[ip] {
			newOwner, promoted := pm.releaseAllowedIP(siteConfig.SiteId, ip)
			if promoted && newOwner >= 0 {
				peersToUpdate[newOwner] = true
				logger.Info("Promoted peer %d to owner of IP %s", newOwner, ip)
			}
		}
	}

	// Add claims for new IPs
	for ip := range newAllowedIPsSet {
		if !oldAllowedIPs[ip] {
			pm.claimAllowedIP(siteConfig.SiteId, ip)
		}
	}

	// Build the list of IPs this peer owns for WireGuard config
	ownedIPs := pm.getOwnedAllowedIPs(siteConfig.SiteId)
	wgConfig := siteConfig
	wgConfig.AllowedIps = ownedIPs

	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
		return err
	}

	// Update WireGuard config for any promoted peers
	for promotedPeerId := range peersToUpdate {
		if promotedPeer, exists := pm.peers[promotedPeerId]; exists {
			promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
			promotedWgConfig := promotedPeer
			promotedWgConfig.AllowedIps = promotedOwnedIPs
			if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
				logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
			}
		}
	}

	// Handle remote subnet route changes
	// Calculate added and removed subnets
	oldSubnets := make(map[string]bool)
	for _, s := range oldPeer.RemoteSubnets {
		oldSubnets[s] = true
	}
	newSubnets := make(map[string]bool)
	for _, s := range siteConfig.RemoteSubnets {
		newSubnets[s] = true
	}

	var addedSubnets []string
	var removedSubnets []string

	for s := range newSubnets {
		if !oldSubnets[s] {
			addedSubnets = append(addedSubnets, s)
		}
	}
	for s := range oldSubnets {
		if !newSubnets[s] {
			removedSubnets = append(removedSubnets, s)
		}
	}

	// Remove routes for removed subnets (only if no other peer needs them)
	for _, subnet := range removedSubnets {
		subnetStillInUse := false
		for otherSiteId, otherPeer := range pm.peers {
			if otherSiteId == siteConfig.SiteId {
				continue // Skip the current peer (already updated)
			}
			for _, otherSubnet := range otherPeer.RemoteSubnets {
				if otherSubnet == subnet {
					subnetStillInUse = true
					break
				}
			}
			if subnetStillInUse {
				break
			}
		}
		if !subnetStillInUse {
			if err := network.RemoveRoutes([]string{subnet}); err != nil {
				logger.Error("Failed to remove route for subnet %s: %v", subnet, err)
			}
		}
	}

	// Add routes for added subnets
	if len(addedSubnets) > 0 {
		if err := network.AddRoutes(addedSubnets, pm.interfaceName); err != nil {
			logger.Error("Failed to add routes: %v", err)
		}
	}

	// Update aliases
	// Remove old aliases
	for _, alias := range oldPeer.Aliases {
		address := net.ParseIP(alias.AliasAddress)
		if address == nil {
			continue
		}
		pm.dnsProxy.RemoveDNSRecord(alias.Alias, address)
	}
	// Add new aliases
	for _, alias := range siteConfig.Aliases {
		address := net.ParseIP(alias.AliasAddress)
		if address == nil {
			continue
		}
		pm.dnsProxy.AddDNSRecord(alias.Alias, address)
	}

	pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)

	monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
	monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
	pm.peerMonitor.UpdatePeerEndpoint(siteConfig.SiteId, monitorPeer) // +1 for monitor port

	pm.peers[siteConfig.SiteId] = siteConfig
	return nil
}
|
||||
|
||||
// claimAllowedIP registers a peer's claim to an allowed IP.
|
||||
// If no other peer owns it in WireGuard, this peer becomes the owner.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) claimAllowedIP(siteId int, cidr string) {
|
||||
// Add to claims
|
||||
if pm.allowedIPClaims[cidr] == nil {
|
||||
pm.allowedIPClaims[cidr] = make(map[int]bool)
|
||||
}
|
||||
pm.allowedIPClaims[cidr][siteId] = true
|
||||
|
||||
// If no owner yet, this peer becomes the owner
|
||||
if _, hasOwner := pm.allowedIPOwners[cidr]; !hasOwner {
|
||||
pm.allowedIPOwners[cidr] = siteId
|
||||
}
|
||||
}
|
||||
|
||||
// releaseAllowedIP removes a peer's claim to an allowed IP.
|
||||
// If this peer was the owner, it promotes another claimant to owner.
|
||||
// Returns the new owner's siteId (or -1 if no new owner) and whether promotion occurred.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) releaseAllowedIP(siteId int, cidr string) (newOwner int, promoted bool) {
|
||||
// Remove from claims
|
||||
if claims, exists := pm.allowedIPClaims[cidr]; exists {
|
||||
delete(claims, siteId)
|
||||
if len(claims) == 0 {
|
||||
delete(pm.allowedIPClaims, cidr)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this peer was the owner
|
||||
owner, isOwned := pm.allowedIPOwners[cidr]
|
||||
if !isOwned || owner != siteId {
|
||||
return -1, false // Not the owner, nothing to promote
|
||||
}
|
||||
|
||||
// This peer was the owner, need to find a new owner
|
||||
delete(pm.allowedIPOwners, cidr)
|
||||
|
||||
// Find another claimant to promote
|
||||
if claims, exists := pm.allowedIPClaims[cidr]; exists && len(claims) > 0 {
|
||||
for claimantId := range claims {
|
||||
pm.allowedIPOwners[cidr] = claimantId
|
||||
return claimantId, true
|
||||
}
|
||||
}
|
||||
|
||||
return -1, false
|
||||
}
|
||||
|
||||
// getOwnedAllowedIPs returns the list of allowed IPs that a peer currently owns in WireGuard.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) getOwnedAllowedIPs(siteId int) []string {
|
||||
var owned []string
|
||||
for cidr, owner := range pm.allowedIPOwners {
|
||||
if owner == siteId {
|
||||
owned = append(owned, cidr)
|
||||
}
|
||||
}
|
||||
return owned
|
||||
}
|
||||
|
||||
// addAllowedIp adds an IP (subnet) to the allowed IPs list of a peer
|
||||
// and updates WireGuard configuration if this peer owns the IP.
|
||||
// Must be called with lock held.
|
||||
func (pm *PeerManager) addAllowedIp(siteId int, ip string) error {
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
// Check if IP already exists in AllowedIps
|
||||
for _, allowedIp := range peer.AllowedIps {
|
||||
if allowedIp == ip {
|
||||
return nil // Already exists
|
||||
}
|
||||
}
|
||||
|
||||
// Register our claim to this IP
|
||||
pm.claimAllowedIP(siteId, ip)
|
||||
|
||||
peer.AllowedIps = append(peer.AllowedIps, ip)
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
// Only update WireGuard if we own this IP
|
||||
if pm.allowedIPOwners[ip] == siteId {
|
||||
if err := AddAllowedIP(pm.device, peer.PublicKey, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeAllowedIp removes an IP (subnet) from the allowed IPs list of a peer
// and updates WireGuard configuration. If this peer owned the IP, it promotes
// another peer that also claims this IP. Must be called with lock held.
//
// Note the ordering: the local bookkeeping (AllowedIps, claims, ownership)
// is updated first, then WireGuard is told the peer's new owned-IP set, and
// only afterwards is a promoted peer's config extended with the CIDR.
func (pm *PeerManager) removeAllowedIp(siteId int, cidr string) error {
	peer, exists := pm.peers[siteId]
	if !exists {
		return fmt.Errorf("peer with site ID %d not found", siteId)
	}

	found := false

	// Remove from AllowedIps, keeping everything else in order.
	newAllowedIps := make([]string, 0, len(peer.AllowedIps))
	for _, allowedIp := range peer.AllowedIps {
		if allowedIp == cidr {
			found = true
			continue
		}
		newAllowedIps = append(newAllowedIps, allowedIp)
	}

	if !found {
		return nil // Not found; nothing to change
	}

	peer.AllowedIps = newAllowedIps
	pm.peers[siteId] = peer

	// Release our claim and check if we need to promote another peer
	newOwner, promoted := pm.releaseAllowedIP(siteId, cidr)

	// Build the list of IPs this peer currently owns for the replace operation
	ownedIPs := pm.getOwnedAllowedIPs(siteId)
	// Also include the server IP which is always owned (ownership map may not
	// track it, so it is prepended explicitly if missing).
	serverIP := strings.Split(peer.ServerIP, "/")[0] + "/32"
	hasServerIP := false
	for _, ip := range ownedIPs {
		if ip == serverIP {
			hasServerIP = true
			break
		}
	}
	if !hasServerIP {
		ownedIPs = append([]string{serverIP}, ownedIPs...)
	}

	// Update WireGuard for this peer using replace_allowed_ips: the new set
	// is the full owned list, which no longer contains cidr.
	if err := RemoveAllowedIP(pm.device, peer.PublicKey, ownedIPs); err != nil {
		return err
	}

	// If another peer was promoted to owner, add the IP to their WireGuard
	// config. A failure here is logged but not returned: the local state has
	// already been updated and the removal itself succeeded.
	if promoted && newOwner >= 0 {
		if newOwnerPeer, exists := pm.peers[newOwner]; exists {
			if err := AddAllowedIP(pm.device, newOwnerPeer.PublicKey, cidr); err != nil {
				logger.Error("Failed to promote peer %d for IP %s: %v", newOwner, cidr, err)
			} else {
				logger.Info("Promoted peer %d to owner of IP %s", newOwner, cidr)
			}
		}
	}

	return nil
}
|
||||
|
||||
// AddRemoteSubnet adds an IP (subnet) to the allowed IPs list of a peer
|
||||
func (pm *PeerManager) AddRemoteSubnet(siteId int, cidr string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
// Check if IP already exists in RemoteSubnets
|
||||
for _, subnet := range peer.RemoteSubnets {
|
||||
if subnet == cidr {
|
||||
return nil // Already exists
|
||||
}
|
||||
}
|
||||
|
||||
peer.RemoteSubnets = append(peer.RemoteSubnets, cidr)
|
||||
pm.peers[siteId] = peer // Save before calling addAllowedIp which reads from pm.peers
|
||||
|
||||
// Add to allowed IPs
|
||||
if err := pm.addAllowedIp(siteId, cidr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add route
|
||||
if err := network.AddRoutes([]string{cidr}, pm.interfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRemoteSubnet removes an IP (subnet) from the allowed IPs list of a peer
|
||||
func (pm *PeerManager) RemoveRemoteSubnet(siteId int, ip string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
found := false
|
||||
|
||||
// Remove from RemoteSubnets
|
||||
newSubnets := make([]string, 0, len(peer.RemoteSubnets))
|
||||
for _, subnet := range peer.RemoteSubnets {
|
||||
if subnet == ip {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newSubnets = append(newSubnets, subnet)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil // Not found
|
||||
}
|
||||
|
||||
peer.RemoteSubnets = newSubnets
|
||||
pm.peers[siteId] = peer // Save before calling removeAllowedIp which reads from pm.peers
|
||||
|
||||
// Remove from allowed IPs (this also handles promotion of other peers)
|
||||
if err := pm.removeAllowedIp(siteId, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if any other peer still has this subnet before removing the route
|
||||
subnetStillInUse := false
|
||||
for otherSiteId, otherPeer := range pm.peers {
|
||||
if otherSiteId == siteId {
|
||||
continue // Skip the current peer (already updated above)
|
||||
}
|
||||
for _, subnet := range otherPeer.RemoteSubnets {
|
||||
if subnet == ip {
|
||||
subnetStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if subnetStillInUse {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Only remove route if no other peer needs it
|
||||
if !subnetStillInUse {
|
||||
if err := network.RemoveRoutes([]string{ip}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddAlias adds an alias to a peer
|
||||
func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
peer.Aliases = append(peer.Aliases, alias)
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
address := net.ParseIP(alias.AliasAddress)
|
||||
if address != nil {
|
||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
||||
}
|
||||
|
||||
// Add an allowed IP for the alias
|
||||
if err := pm.addAllowedIp(siteId, alias.AliasAddress+"/32"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAlias removes an alias from a peer
|
||||
func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
peer, exists := pm.peers[siteId]
|
||||
if !exists {
|
||||
return fmt.Errorf("peer with site ID %d not found", siteId)
|
||||
}
|
||||
|
||||
var aliasToRemove *Alias
|
||||
newAliases := make([]Alias, 0, len(peer.Aliases))
|
||||
for _, a := range peer.Aliases {
|
||||
if a.Alias == aliasName {
|
||||
aliasToRemove = &a
|
||||
continue
|
||||
}
|
||||
newAliases = append(newAliases, a)
|
||||
}
|
||||
|
||||
if aliasToRemove != nil {
|
||||
address := net.ParseIP(aliasToRemove.AliasAddress)
|
||||
if address != nil {
|
||||
pm.dnsProxy.RemoveDNSRecord(aliasName, address)
|
||||
}
|
||||
}
|
||||
|
||||
peer.Aliases = newAliases
|
||||
pm.peers[siteId] = peer
|
||||
|
||||
// Check if any other alias is still using this IP address before removing from allowed IPs
|
||||
ipStillInUse := false
|
||||
aliasIP := aliasToRemove.AliasAddress + "/32"
|
||||
for _, a := range newAliases {
|
||||
if a.AliasAddress+"/32" == aliasIP {
|
||||
ipStillInUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Only remove the allowed IP if no other alias is using it
|
||||
if !ipStillInUse {
|
||||
if err := pm.removeAllowedIp(siteId, aliasIP); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayPeer handles failover to the relay server when a peer is disconnected
|
||||
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
|
||||
pm.mu.Lock()
|
||||
peer, exists := pm.peers[siteId]
|
||||
if exists {
|
||||
// Store the relay endpoint
|
||||
peer.RelayEndpoint = relayEndpoint
|
||||
pm.peers[siteId] = peer
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for IPv6 and format the endpoint correctly
|
||||
formattedEndpoint := relayEndpoint
|
||||
if strings.Contains(relayEndpoint, ":") {
|
||||
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
|
||||
}
|
||||
|
||||
// Update only the endpoint for this peer (update_only preserves other settings)
|
||||
wgConfig := fmt.Sprintf(`public_key=%s
|
||||
update_only=true
|
||||
endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint)
|
||||
|
||||
err := pm.device.IpcSet(wgConfig)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure WireGuard device: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the peer as relayed in the monitor
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.MarkPeerRelayed(siteId, true)
|
||||
}
|
||||
|
||||
logger.Info("Adjusted peer %d to point to relay!\n", siteId)
|
||||
}
|
||||
|
||||
// performRapidInitialTest performs a rapid holepunch test for a newly added peer.
|
||||
// If the test fails, it immediately requests relay to minimize connection delay.
|
||||
// This runs in a goroutine to avoid blocking AddPeer.
|
||||
func (pm *PeerManager) performRapidInitialTest(siteId int, endpoint string) {
|
||||
if pm.peerMonitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Perform rapid test - this takes ~1-2 seconds max
|
||||
holepunchViable := pm.peerMonitor.RapidTestPeer(siteId, endpoint)
|
||||
|
||||
if !holepunchViable {
|
||||
// Holepunch failed rapid test, request relay immediately
|
||||
logger.Info("Rapid test failed for site %d, requesting relay", siteId)
|
||||
if err := pm.peerMonitor.RequestRelay(siteId); err != nil {
|
||||
logger.Error("Failed to request relay for site %d: %v", siteId, err)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Rapid test passed for site %d, using direct connection", siteId)
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the peer monitor
|
||||
func (pm *PeerManager) Start() {
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.Start()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the peer monitor
|
||||
func (pm *PeerManager) Stop() {
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the peer monitor and cleans up resources
|
||||
func (pm *PeerManager) Close() {
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.Close()
|
||||
pm.peerMonitor = nil
|
||||
}
|
||||
}
|
||||
|
||||
// MarkPeerRelayed marks a peer as currently using relay
|
||||
func (pm *PeerManager) MarkPeerRelayed(siteID int, relayed bool) {
|
||||
pm.mu.Lock()
|
||||
if peer, exists := pm.peers[siteID]; exists {
|
||||
if relayed {
|
||||
// We're being relayed, store the current endpoint as the original
|
||||
// (RelayEndpoint is set by HandleFailover)
|
||||
} else {
|
||||
// Clear relay endpoint when switching back to direct
|
||||
peer.RelayEndpoint = ""
|
||||
pm.peers[siteID] = peer
|
||||
}
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.MarkPeerRelayed(siteID, relayed)
|
||||
}
|
||||
}
|
||||
|
||||
// UnRelayPeer switches a peer from relay back to direct connection
|
||||
func (pm *PeerManager) UnRelayPeer(siteId int, endpoint string) error {
|
||||
pm.mu.Lock()
|
||||
peer, exists := pm.peers[siteId]
|
||||
if exists {
|
||||
// Store the relay endpoint
|
||||
peer.Endpoint = endpoint
|
||||
pm.peers[siteId] = peer
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update WireGuard to use the direct endpoint
|
||||
wgConfig := fmt.Sprintf(`public_key=%s
|
||||
update_only=true
|
||||
endpoint=%s`, util.FixKey(peer.PublicKey), endpoint)
|
||||
|
||||
err := pm.device.IpcSet(wgConfig)
|
||||
if err != nil {
|
||||
logger.Error("Failed to switch peer %d to direct connection: %v", siteId, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark as not relayed in monitor
|
||||
if pm.peerMonitor != nil {
|
||||
pm.peerMonitor.MarkPeerRelayed(siteId, false)
|
||||
}
|
||||
|
||||
logger.Info("Switched peer %d back to direct connection at %s", siteId, endpoint)
|
||||
return nil
|
||||
}
|
||||
924
peers/monitor/monitor.go
Normal file
924
peers/monitor/monitor.go
Normal file
@@ -0,0 +1,924 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/api"
|
||||
middleDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers.
// All maps below are keyed by siteID and guarded by mutex unless noted.
type PeerMonitor struct {
	monitors    map[int]*Client // per-peer monitor clients
	mutex       sync.Mutex      // guards the maps and tunable settings
	running     bool            // true between Start() and Stop()
	interval    time.Duration   // how often each peer is probed
	timeout     time.Duration   // how long to wait for a probe response
	maxAttempts int             // attempts per connection test
	wsClient    *websocket.Client // control channel for relay/unrelay requests

	// Netstack fields: user-space network stack used to dial monitor probes.
	middleDev   *middleDevice.MiddleDevice
	localIP     string
	stack       *stack.Stack
	ep          *channel.Endpoint
	activePorts map[uint16]bool // ports currently bound in the netstack
	portsLock   sync.Mutex      // guards activePorts
	nsCtx       context.Context
	nsCancel    context.CancelFunc
	nsWg        sync.WaitGroup

	// Holepunch testing fields
	sharedBind         *bind.SharedBind
	holepunchTester    *holepunch.HolepunchTester // nil when sharedBind is not provided
	holepunchInterval  time.Duration
	holepunchTimeout   time.Duration
	holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
	holepunchStatus    map[int]bool   // siteID -> connected status
	holepunchStopChan  chan struct{}  // non-nil while the holepunch loop runs

	// Relay tracking fields
	relayedPeers         map[int]bool // siteID -> whether the peer is currently relayed
	holepunchMaxAttempts int          // max consecutive failures before triggering relay
	holepunchFailures    map[int]int  // siteID -> consecutive failure count

	// Rapid initial test fields
	rapidTestInterval    time.Duration // interval between rapid test attempts
	rapidTestTimeout     time.Duration // timeout for each rapid test attempt
	rapidTestMaxAttempts int           // max attempts during rapid test phase

	// API server for status updates
	apiServer *api.API

	// WG connection status tracking
	wgConnectionStatus map[int]bool // siteID -> WG connected status
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pm := &PeerMonitor{
|
||||
monitors: make(map[int]*Client),
|
||||
interval: 2 * time.Second, // Default check interval (faster)
|
||||
timeout: 3 * time.Second,
|
||||
maxAttempts: 3,
|
||||
wsClient: wsClient,
|
||||
middleDev: middleDev,
|
||||
localIP: localIP,
|
||||
activePorts: make(map[uint16]bool),
|
||||
nsCtx: ctx,
|
||||
nsCancel: cancel,
|
||||
sharedBind: sharedBind,
|
||||
holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
|
||||
holepunchTimeout: 2 * time.Second, // Faster timeout
|
||||
holepunchEndpoints: make(map[int]string),
|
||||
holepunchStatus: make(map[int]bool),
|
||||
relayedPeers: make(map[int]bool),
|
||||
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
|
||||
holepunchFailures: make(map[int]int),
|
||||
// Rapid initial test settings: complete within ~1.5 seconds
|
||||
rapidTestInterval: 200 * time.Millisecond, // 200ms between attempts
|
||||
rapidTestTimeout: 400 * time.Millisecond, // 400ms timeout per attempt
|
||||
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
|
||||
apiServer: apiServer,
|
||||
wgConnectionStatus: make(map[int]bool),
|
||||
}
|
||||
|
||||
if err := pm.initNetstack(); err != nil {
|
||||
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
|
||||
}
|
||||
|
||||
// Initialize holepunch tester if sharedBind is available
|
||||
if sharedBind != nil {
|
||||
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind)
|
||||
}
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
// SetInterval changes how frequently peers are checked
|
||||
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.interval = interval
|
||||
|
||||
// Update interval for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetPacketInterval(interval)
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeout changes the timeout for waiting for responses
|
||||
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.timeout = timeout
|
||||
|
||||
// Update timeout for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.maxAttempts = attempts
|
||||
|
||||
// Update max attempts for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetMaxAttempts(attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a new peer to monitor
|
||||
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if _, exists := pm.monitors[siteID]; exists {
|
||||
return nil // Already monitoring
|
||||
}
|
||||
|
||||
// Use our custom dialer that uses netstack
|
||||
client, err := NewClient(endpoint, pm.dial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.SetPacketInterval(pm.interval)
|
||||
client.SetTimeout(pm.timeout)
|
||||
client.SetMaxAttempts(pm.maxAttempts)
|
||||
|
||||
pm.monitors[siteID] = client
|
||||
|
||||
pm.holepunchEndpoints[siteID] = holepunchEndpoint
|
||||
pm.holepunchStatus[siteID] = false // Initially unknown/disconnected
|
||||
|
||||
if pm.running {
|
||||
if err := client.StartMonitor(func(status ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteID, status)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// update holepunch endpoint for a peer
|
||||
func (pm *PeerMonitor) UpdateHolepunchEndpoint(siteID int, endpoint string) {
|
||||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
pm.holepunchEndpoints[siteID] = endpoint
|
||||
}()
|
||||
}
|
||||
|
||||
// RapidTestPeer performs a rapid connectivity test for a newly added peer.
|
||||
// This is designed to quickly determine if holepunch is viable within ~1-2 seconds.
|
||||
// Returns true if the connection is viable (holepunch works), false if it should relay.
|
||||
func (pm *PeerMonitor) RapidTestPeer(siteID int, endpoint string) bool {
|
||||
if pm.holepunchTester == nil {
|
||||
logger.Warn("Cannot perform rapid test: holepunch tester not initialized")
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mutex.Lock()
|
||||
interval := pm.rapidTestInterval
|
||||
timeout := pm.rapidTestTimeout
|
||||
maxAttempts := pm.rapidTestMaxAttempts
|
||||
pm.mutex.Unlock()
|
||||
|
||||
logger.Info("Starting rapid holepunch test for site %d at %s (max %d attempts, %v timeout each)",
|
||||
siteID, endpoint, maxAttempts, timeout)
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
|
||||
|
||||
if result.Success {
|
||||
logger.Info("Rapid test: site %d holepunch SUCCEEDED on attempt %d (RTT: %v)",
|
||||
siteID, attempt, result.RTT)
|
||||
|
||||
// Update status
|
||||
pm.mutex.Lock()
|
||||
pm.holepunchStatus[siteID] = true
|
||||
pm.holepunchFailures[siteID] = 0
|
||||
pm.mutex.Unlock()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if attempt < maxAttempts {
|
||||
time.Sleep(interval)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Warn("Rapid test: site %d holepunch FAILED after %d attempts, will relay",
|
||||
siteID, maxAttempts)
|
||||
|
||||
// Update status to reflect failure
|
||||
pm.mutex.Lock()
|
||||
pm.holepunchStatus[siteID] = false
|
||||
pm.holepunchFailures[siteID] = maxAttempts
|
||||
pm.mutex.Unlock()
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdatePeerEndpoint updates the monitor endpoint for a peer
|
||||
func (pm *PeerMonitor) UpdatePeerEndpoint(siteID int, monitorPeer string) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
client, exists := pm.monitors[siteID]
|
||||
if !exists {
|
||||
logger.Warn("Cannot update endpoint: peer %d not found in monitor", siteID)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the client's server address
|
||||
client.UpdateServerAddr(monitorPeer)
|
||||
|
||||
logger.Info("Updated monitor endpoint for site %d to %s", siteID, monitorPeer)
|
||||
}
|
||||
|
||||
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
||||
// This function assumes the mutex is already held by the caller
|
||||
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
|
||||
client, exists := pm.monitors[siteID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
client.StopMonitor()
|
||||
client.Close()
|
||||
delete(pm.monitors, siteID)
|
||||
}
|
||||
|
||||
// RemovePeer stops monitoring a peer and removes it from the monitor
|
||||
func (pm *PeerMonitor) RemovePeer(siteID int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// remove the holepunch endpoint info
|
||||
delete(pm.holepunchEndpoints, siteID)
|
||||
delete(pm.holepunchStatus, siteID)
|
||||
delete(pm.relayedPeers, siteID)
|
||||
delete(pm.holepunchFailures, siteID)
|
||||
|
||||
pm.removePeerUnlocked(siteID)
|
||||
}
|
||||
|
||||
// Start begins monitoring all peers
|
||||
func (pm *PeerMonitor) Start() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if pm.running {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
pm.running = true
|
||||
|
||||
// Start monitoring all peers
|
||||
for siteID, client := range pm.monitors {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
err := client.StartMonitor(func(status ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
|
||||
continue
|
||||
}
|
||||
logger.Info("Started monitoring peer %d\n", siteID)
|
||||
}
|
||||
|
||||
pm.startHolepunchMonitor()
|
||||
}
|
||||
|
||||
// handleConnectionStatusChange is called when a peer's connection status changes
|
||||
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) {
|
||||
pm.mutex.Lock()
|
||||
previousStatus, exists := pm.wgConnectionStatus[siteID]
|
||||
pm.wgConnectionStatus[siteID] = status.Connected
|
||||
isRelayed := pm.relayedPeers[siteID]
|
||||
endpoint := pm.holepunchEndpoints[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
// Log status changes
|
||||
if !exists || previousStatus != status.Connected {
|
||||
if status.Connected {
|
||||
logger.Info("WireGuard connection to site %d is CONNECTED (RTT: %v)", siteID, status.RTT)
|
||||
} else {
|
||||
logger.Warn("WireGuard connection to site %d is DISCONNECTED", siteID)
|
||||
}
|
||||
}
|
||||
|
||||
// Update API with connection status
|
||||
if pm.apiServer != nil {
|
||||
pm.apiServer.UpdatePeerStatus(siteID, status.Connected, status.RTT, endpoint, isRelayed)
|
||||
}
|
||||
}
|
||||
|
||||
// sendRelay sends a relay message to the server
|
||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent relay message")
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestRelay is a public method to request relay for a peer.
|
||||
// This is used when rapid initial testing determines holepunch is not viable.
|
||||
func (pm *PeerMonitor) RequestRelay(siteID int) error {
|
||||
return pm.sendRelay(siteID)
|
||||
}
|
||||
|
||||
// sendUnRelay sends an unrelay message to the server
|
||||
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent unrelay message")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops monitoring all peers
|
||||
func (pm *PeerMonitor) Stop() {
|
||||
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||
pm.stopHolepunchMonitor()
|
||||
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
pm.running = false
|
||||
|
||||
// Stop all monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.StopMonitor()
|
||||
}
|
||||
}
|
||||
|
||||
// MarkPeerRelayed marks a peer as currently using relay
|
||||
func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
pm.relayedPeers[siteID] = relayed
|
||||
if relayed {
|
||||
// Reset failure count when marked as relayed
|
||||
pm.holepunchFailures[siteID] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// IsPeerRelayed returns whether a peer is currently using relay
|
||||
func (pm *PeerMonitor) IsPeerRelayed(siteID int) bool {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
return pm.relayedPeers[siteID]
|
||||
}
|
||||
|
||||
// startHolepunchMonitor starts the holepunch connection monitoring
|
||||
// Note: This function assumes the mutex is already held by the caller (called from Start())
|
||||
func (pm *PeerMonitor) startHolepunchMonitor() error {
|
||||
if pm.holepunchTester == nil {
|
||||
return fmt.Errorf("holepunch tester not initialized (sharedBind not provided)")
|
||||
}
|
||||
|
||||
if pm.holepunchStopChan != nil {
|
||||
return fmt.Errorf("holepunch monitor already running")
|
||||
}
|
||||
|
||||
if err := pm.holepunchTester.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start holepunch tester: %w", err)
|
||||
}
|
||||
|
||||
pm.holepunchStopChan = make(chan struct{})
|
||||
|
||||
go pm.runHolepunchMonitor()
|
||||
|
||||
logger.Info("Started holepunch connection monitor")
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopHolepunchMonitor stops the holepunch connection monitoring
|
||||
func (pm *PeerMonitor) stopHolepunchMonitor() {
|
||||
pm.mutex.Lock()
|
||||
stopChan := pm.holepunchStopChan
|
||||
pm.holepunchStopChan = nil
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if stopChan != nil {
|
||||
close(stopChan)
|
||||
}
|
||||
|
||||
if pm.holepunchTester != nil {
|
||||
pm.holepunchTester.Stop()
|
||||
}
|
||||
|
||||
logger.Info("Stopped holepunch connection monitor")
|
||||
}
|
||||
|
||||
// runHolepunchMonitor runs the holepunch monitoring loop
|
||||
func (pm *PeerMonitor) runHolepunchMonitor() {
|
||||
ticker := time.NewTicker(pm.holepunchInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Do initial check immediately
|
||||
pm.checkHolepunchEndpoints()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pm.holepunchStopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
pm.checkHolepunchEndpoints()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkHolepunchEndpoints tests every registered holepunch endpoint once and
// reacts to the results: it updates per-site status/failure counters, pushes
// status to the API server, and triggers relay (or un-relay) via the
// websocket client when a peer crosses the failure threshold (or recovers).
//
// Locking discipline: pm.mutex is taken and released several times per peer
// so that the slow network test (TestEndpoint) and the relay messages run
// without holding the lock. Do not merge these critical sections without
// re-checking for deadlocks against sendRelay/sendUnRelay.
func (pm *PeerMonitor) checkHolepunchEndpoints() {
	pm.mutex.Lock()
	// Check if we're still running before doing any work
	if !pm.running {
		pm.mutex.Unlock()
		return
	}
	// Copy the endpoint map and settings so the tests below run lock-free.
	endpoints := make(map[int]string, len(pm.holepunchEndpoints))
	for siteID, endpoint := range pm.holepunchEndpoints {
		endpoints[siteID] = endpoint
	}
	timeout := pm.holepunchTimeout
	maxAttempts := pm.holepunchMaxAttempts
	pm.mutex.Unlock()

	for siteID, endpoint := range endpoints {
		logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
		// Network round-trip; deliberately outside the mutex.
		result := pm.holepunchTester.TestEndpoint(endpoint, timeout)

		pm.mutex.Lock()
		// Check if peer was removed while we were testing
		if _, stillExists := pm.holepunchEndpoints[siteID]; !stillExists {
			pm.mutex.Unlock()
			continue // Peer was removed, skip processing
		}

		previousStatus, exists := pm.holepunchStatus[siteID]
		pm.holepunchStatus[siteID] = result.Success
		isRelayed := pm.relayedPeers[siteID]

		// Track consecutive failures for relay triggering; any success
		// resets the counter.
		if result.Success {
			pm.holepunchFailures[siteID] = 0
		} else {
			pm.holepunchFailures[siteID]++
		}
		failureCount := pm.holepunchFailures[siteID]
		pm.mutex.Unlock()

		// Log only transitions (first result for a site, or a state change).
		if !exists || previousStatus != result.Success {
			if result.Success {
				logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
			} else {
				if result.Error != nil {
					logger.Warn("Holepunch to site %d (%s) is DISCONNECTED: %v", siteID, endpoint, result.Error)
				} else {
					logger.Warn("Holepunch to site %d (%s) is DISCONNECTED", siteID, endpoint)
				}
			}
		}

		// Update API with holepunch status
		if pm.apiServer != nil {
			// Update holepunch connection status
			pm.apiServer.UpdatePeerHolepunchStatus(siteID, result.Success)

			// Get the current WG connection status for this peer
			pm.mutex.Lock()
			wgConnected := pm.wgConnectionStatus[siteID]
			pm.mutex.Unlock()

			// Update API - use holepunch endpoint and relay status
			pm.apiServer.UpdatePeerStatus(siteID, wgConnected, result.RTT, endpoint, isRelayed)
		}

		// Handle relay logic based on holepunch status.
		// Re-check running before sending relay messages so a shutdown that
		// began mid-loop does not trigger spurious relay traffic.
		pm.mutex.Lock()
		stillRunning := pm.running
		pm.mutex.Unlock()

		if !stillRunning {
			return // Stop processing if shutdown is in progress
		}

		if !result.Success && !isRelayed && failureCount >= maxAttempts {
			// Holepunch failed maxAttempts times in a row and we're not
			// relayed - trigger relay
			logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount)
			if pm.wsClient != nil {
				pm.sendRelay(siteID)
			}
		} else if result.Success && isRelayed {
			// Holepunch succeeded and we ARE relayed - switch back to direct
			logger.Info("Holepunch to site %d succeeded while relayed, switching to direct connection", siteID)
			if pm.wsClient != nil {
				pm.sendUnRelay(siteID)
			}
		}
	}
}
|
||||
|
||||
// GetHolepunchStatus returns the current holepunch status for all endpoints
|
||||
func (pm *PeerMonitor) GetHolepunchStatus() map[int]bool {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
status := make(map[int]bool, len(pm.holepunchStatus))
|
||||
for siteID, connected := range pm.holepunchStatus {
|
||||
status[siteID] = connected
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
// Close stops monitoring and cleans up resources. Teardown order matters:
// holepunch monitor first (it takes pm.mutex itself), then the per-site
// clients, then the netstack (cancel context -> close endpoint -> wait for
// goroutines with a timeout -> destroy the stack last).
func (pm *PeerMonitor) Close() {
	// Stop holepunch monitor first (outside of mutex to avoid deadlock)
	pm.stopHolepunchMonitor()

	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	logger.Debug("PeerMonitor: Starting cleanup")

	// Stop and close all clients first
	for siteID, client := range pm.monitors {
		logger.Debug("PeerMonitor: Stopping client for site %d", siteID)
		client.StopMonitor()
		client.Close()
		delete(pm.monitors, siteID)
	}

	pm.running = false

	// Clean up netstack resources
	logger.Debug("PeerMonitor: Cancelling netstack context")
	if pm.nsCancel != nil {
		pm.nsCancel() // Signal goroutines to stop
	}

	// Close the channel endpoint to unblock any pending reads
	logger.Debug("PeerMonitor: Closing endpoint")
	if pm.ep != nil {
		pm.ep.Close()
	}

	// Wait for packet sender goroutine to finish with timeout so a stuck
	// goroutine cannot hang shutdown indefinitely.
	logger.Debug("PeerMonitor: Waiting for goroutines to finish")
	done := make(chan struct{})
	go func() {
		pm.nsWg.Wait()
		close(done)
	}()

	select {
	case <-done:
		logger.Debug("PeerMonitor: Goroutines finished cleanly")
	case <-time.After(2 * time.Second):
		logger.Warn("PeerMonitor: Timeout waiting for goroutines to finish, proceeding anyway")
	}

	// Destroy the stack last, after all goroutines are done
	logger.Debug("PeerMonitor: Destroying stack")
	if pm.stack != nil {
		pm.stack.Destroy()
		pm.stack = nil
	}

	logger.Debug("PeerMonitor: Cleanup complete")
}
|
||||
|
||||
// TestPeer tests connectivity to a specific peer
|
||||
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||
pm.mutex.Lock()
|
||||
client, exists := pm.monitors[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
defer cancel()
|
||||
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
return connected, rtt, nil
|
||||
}
|
||||
|
||||
// TestAllPeers tests connectivity to all peers
|
||||
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
} {
|
||||
pm.mutex.Lock()
|
||||
peers := make(map[int]*Client, len(pm.monitors))
|
||||
for siteID, client := range pm.monitors {
|
||||
peers[siteID] = client
|
||||
}
|
||||
pm.mutex.Unlock()
|
||||
|
||||
results := make(map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
})
|
||||
for siteID, client := range peers {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
cancel()
|
||||
|
||||
results[siteID] = struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
}{
|
||||
Connected: connected,
|
||||
RTT: rtt,
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// initNetstack initializes the gvisor userspace network stack used for
// in-tunnel connectivity tests: a channel-endpoint NIC carrying the local
// WireGuard IP, a default route, a MiddleDevice intercept rule for inbound
// packets, and a background sender that forwards outbound packets to WG.
func (pm *PeerMonitor) initNetstack() error {
	if pm.localIP == "" {
		return fmt.Errorf("local IP not provided")
	}

	addr, err := netip.ParseAddr(pm.localIP)
	if err != nil {
		return fmt.Errorf("invalid local IP: %v", err)
	}

	// Create gvisor netstack with IPv4/IPv6 network protocols and UDP only.
	// HandleLocal lets locally-originated traffic loop back through the stack.
	stackOpts := stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
		HandleLocal:        true,
	}

	pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG)
	pm.stack = stack.New(stackOpts)

	// Create NIC (ID 1, referenced by the route and dial() below)
	if err := pm.stack.CreateNIC(1, pm.ep); err != nil {
		return fmt.Errorf("failed to create NIC: %v", err)
	}

	// Add IP address.
	// NOTE(review): As4 assumes the local IP is IPv4 — confirm callers never
	// pass an IPv6 address here.
	ipBytes := addr.As4()
	protoAddr := tcpip.ProtocolAddress{
		Protocol:          ipv4.ProtocolNumber,
		AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
	}

	if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
		return fmt.Errorf("failed to add protocol address: %v", err)
	}

	// Add default route so every destination egresses via NIC 1
	pm.stack.AddRoute(tcpip.Route{
		Destination: header.IPv4EmptySubnet,
		NIC:         1,
	})

	// Register filter rule on MiddleDevice.
	// We want to intercept packets destined to our local IP,
	// but ONLY if they are for ports we are listening on
	// (handlePacket performs the port check).
	pm.middleDev.AddRule(addr, pm.handlePacket)

	// Start packet sender (Stack -> WG); nsWg lets Close() wait for it
	pm.nsWg.Add(1)
	go pm.runPacketSender()

	return nil
}
|
||||
|
||||
// handlePacket is called by MiddleDevice when a packet arrives for our IP.
// It returns true when the packet was consumed (injected into the netstack)
// and false when it should continue down the normal processing path.
func (pm *PeerMonitor) handlePacket(packet []byte) bool {
	// Check if it's UDP (IP protocol number 17)
	proto, ok := util.GetProtocol(packet)
	if !ok || proto != 17 { // UDP
		return false
	}

	// Check destination port
	port, ok := util.GetDestPort(packet)
	if !ok {
		return false
	}

	// Only intercept traffic for ports our netstack sockets are bound to
	// (registered by dial(), removed by trackedConn.Close)
	pm.portsLock.Lock()
	active := pm.activePorts[uint16(port)]
	pm.portsLock.Unlock()

	if !active {
		return false
	}

	// Inject into netstack; the high nibble of the first byte is the IP version
	version := packet[0] >> 4
	pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
		Payload: buffer.MakeWithData(packet),
	})

	switch version {
	case 4:
		pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
	case 6:
		pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
	default:
		// Unknown IP version: release our buffer reference and let the
		// packet continue down the normal path
		pkb.DecRef()
		return false
	}

	// Release our reference after injection (the injected path holds its own)
	pkb.DecRef()
	return true // Handled
}
|
||||
|
||||
// runPacketSender reads outbound packets from the netstack channel endpoint
// and injects them into WireGuard via the MiddleDevice. It polls on a 10ms
// ticker in batches, and exits when the netstack context is cancelled,
// draining any queued packets first so their buffers are released.
func (pm *PeerMonitor) runPacketSender() {
	defer pm.nsWg.Done()
	logger.Debug("PeerMonitor: Packet sender goroutine started")

	// Use a ticker to periodically check for packets without blocking indefinitely
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-pm.nsCtx.Done():
			logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
			// Drain any remaining packets before exiting, DecRef'ing each
			// so its buffer is released rather than leaked
			for {
				pkt := pm.ep.Read()
				if pkt == nil {
					break
				}
				pkt.DecRef()
			}
			logger.Debug("PeerMonitor: Packet sender goroutine exiting")
			return
		case <-ticker.C:
			// Try to read packets in batches of up to 10 per tick
			for i := 0; i < 10; i++ {
				pkt := pm.ep.Read()
				if pkt == nil {
					break // queue empty; wait for next tick
				}

				// Flatten the packet's scattered slices into one
				// contiguous buffer before handing it off
				slices := pkt.AsSlices()
				if len(slices) > 0 {
					var totalSize int
					for _, slice := range slices {
						totalSize += len(slice)
					}

					buf := make([]byte, totalSize)
					pos := 0
					for _, slice := range slices {
						copy(buf[pos:], slice)
						pos += len(slice)
					}

					// Inject into MiddleDevice (outbound to WG)
					pm.middleDev.InjectOutbound(buf)
				}

				pkt.DecRef()
			}
		}
	}
}
|
||||
|
||||
// dial creates a UDP connection using the netstack
|
||||
func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) {
|
||||
if pm.stack == nil {
|
||||
return nil, fmt.Errorf("netstack not initialized")
|
||||
}
|
||||
|
||||
// Parse remote address
|
||||
raddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse local IP
|
||||
localIP, err := netip.ParseAddr(pm.localIP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipBytes := localIP.As4()
|
||||
|
||||
// Create UDP connection
|
||||
// We bind to port 0 (ephemeral)
|
||||
laddr := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4(ipBytes),
|
||||
Port: 0,
|
||||
}
|
||||
|
||||
raddrTcpip := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
|
||||
Port: uint16(raddr.Port),
|
||||
}
|
||||
|
||||
conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get local port
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
port := uint16(localAddr.Port)
|
||||
|
||||
// Register port
|
||||
pm.portsLock.Lock()
|
||||
pm.activePorts[port] = true
|
||||
pm.portsLock.Unlock()
|
||||
|
||||
// Wrap connection to cleanup port on close
|
||||
return &trackedConn{
|
||||
Conn: conn,
|
||||
pm: pm,
|
||||
port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (pm *PeerMonitor) removePort(port uint16) {
|
||||
pm.portsLock.Lock()
|
||||
delete(pm.activePorts, port)
|
||||
pm.portsLock.Unlock()
|
||||
}
|
||||
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
pm *PeerMonitor
|
||||
port uint16
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.pm.removePort(c.port)
|
||||
if c.Conn != nil {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package wgtester
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -26,7 +26,7 @@ const (
|
||||
|
||||
// Client handles checking connectivity to a server
|
||||
type Client struct {
|
||||
conn *net.UDPConn
|
||||
conn net.Conn
|
||||
serverAddr string
|
||||
monitorRunning bool
|
||||
monitorLock sync.Mutex
|
||||
@@ -35,8 +35,12 @@ type Client struct {
|
||||
packetInterval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
dialer Dialer
|
||||
}
|
||||
|
||||
// Dialer is a function that creates a connection
|
||||
type Dialer func(network, addr string) (net.Conn, error)
|
||||
|
||||
// ConnectionStatus represents the current connection state
|
||||
type ConnectionStatus struct {
|
||||
Connected bool
|
||||
@@ -44,13 +48,14 @@ type ConnectionStatus struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new connection test client
|
||||
func NewClient(serverAddr string) (*Client, error) {
|
||||
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
||||
return &Client{
|
||||
serverAddr: serverAddr,
|
||||
shutdownCh: make(chan struct{}),
|
||||
packetInterval: 2 * time.Second,
|
||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||
maxAttempts: 3, // Default max attempts
|
||||
dialer: dialer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -69,6 +74,20 @@ func (c *Client) SetMaxAttempts(attempts int) {
|
||||
c.maxAttempts = attempts
|
||||
}
|
||||
|
||||
// UpdateServerAddr updates the server address and resets the connection
|
||||
func (c *Client) UpdateServerAddr(serverAddr string) {
|
||||
c.connLock.Lock()
|
||||
defer c.connLock.Unlock()
|
||||
|
||||
// Close existing connection if any
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
c.serverAddr = serverAddr
|
||||
}
|
||||
|
||||
// Close cleans up client resources
|
||||
func (c *Client) Close() {
|
||||
c.StopMonitor()
|
||||
@@ -91,12 +110,14 @@ func (c *Client) ensureConnection() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
var err error
|
||||
if c.dialer != nil {
|
||||
c.conn, err = c.dialer("udp", c.serverAddr)
|
||||
} else {
|
||||
// Fallback to standard net.Dial
|
||||
c.conn, err = net.Dial("udp", c.serverAddr)
|
||||
}
|
||||
|
||||
c.conn, err = net.DialUDP("udp", nil, serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -136,14 +157,14 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
|
||||
// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
|
||||
_, err := c.conn.Write(packet)
|
||||
if err != nil {
|
||||
c.connLock.Unlock()
|
||||
logger.Info("Error sending packet: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Debug("Successfully sent monitor packet")
|
||||
// logger.Debug("Successfully sent monitor packet")
|
||||
|
||||
// Set read deadline
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||
142
peers/peer.go
Normal file
142
peers/peer.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package peers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
|
||||
var endpoint string
|
||||
if relay && siteConfig.RelayEndpoint != "" {
|
||||
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
|
||||
} else {
|
||||
endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||
}
|
||||
siteHost, err := util.ResolveDomain(endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||
}
|
||||
|
||||
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
|
||||
allowedIp := strings.Split(siteConfig.ServerIP, "/")
|
||||
if len(allowedIp) > 1 {
|
||||
allowedIp[1] = "32"
|
||||
} else {
|
||||
allowedIp = append(allowedIp, "32")
|
||||
}
|
||||
allowedIpStr := strings.Join(allowedIp, "/")
|
||||
|
||||
// Collect all allowed IPs in a slice
|
||||
var allowedIPs []string
|
||||
allowedIPs = append(allowedIPs, allowedIpStr)
|
||||
|
||||
// Use AllowedIps if available, otherwise fall back to RemoteSubnets for backwards compatibility
|
||||
subnetsToAdd := siteConfig.AllowedIps
|
||||
|
||||
// If we have anything to add, process them
|
||||
if len(subnetsToAdd) > 0 {
|
||||
// Add each subnet
|
||||
for _, subnet := range subnetsToAdd {
|
||||
subnet = strings.TrimSpace(subnet)
|
||||
if subnet != "" {
|
||||
allowedIPs = append(allowedIPs, subnet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct WireGuard config for this peer
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", util.FixKey(privateKey.String())))
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(siteConfig.PublicKey)))
|
||||
|
||||
// Add each allowed IP separately
|
||||
for _, allowedIP := range allowedIPs {
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
}
|
||||
|
||||
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||
configBuilder.WriteString("persistent_keepalive_interval=5\n")
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Configuring peer with config: %s", config)
|
||||
|
||||
err = dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeer removes a peer from the WireGuard device
|
||||
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
|
||||
// Construct WireGuard config to remove the peer
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("remove=true\n")
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Removing peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddAllowedIP adds a single allowed IP to an existing peer without reconfiguring the entire peer
|
||||
func AddAllowedIP(dev *device.Device, publicKey string, allowedIP string) error {
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("update_only=true\n")
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Adding allowed IP to peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add allowed IP to WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAllowedIP removes a single allowed IP from an existing peer by replacing the allowed IPs list
|
||||
// This requires providing all the allowed IPs that should remain after removal
|
||||
func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs []string) error {
|
||||
var configBuilder strings.Builder
|
||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||
configBuilder.WriteString("update_only=true\n")
|
||||
configBuilder.WriteString("replace_allowed_ips=true\n")
|
||||
|
||||
// Add each remaining allowed IP
|
||||
for _, allowedIP := range remainingAllowedIPs {
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||
}
|
||||
|
||||
config := configBuilder.String()
|
||||
logger.Debug("Removing allowed IP from peer with config: %s", config)
|
||||
|
||||
err := dev.IpcSet(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove allowed IP from WireGuard peer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatEndpoint appends the default WireGuard port (51820) to an endpoint
// that carries no port. Endpoints already containing a colon pass through
// unchanged.
func formatEndpoint(endpoint string) string {
	if !strings.Contains(endpoint, ":") {
		endpoint += ":51820"
	}
	return endpoint
}
|
||||
62
peers/types.go
Normal file
62
peers/types.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package peers
|
||||
|
||||
// PeerAction represents a request to add, update, or remove a peer.
type PeerAction struct {
	Action   string     `json:"action"`   // "add", "update", or "remove"
	SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
}

// SiteConfig describes a remote site peer: how to reach it and which
// networks should be routed to it. (The original comment mislabeled this
// type as "UpdatePeerData", which is a distinct type below.)
type SiteConfig struct {
	SiteId        int      `json:"siteId"`
	Endpoint      string   `json:"endpoint,omitempty"`      // direct endpoint; formatEndpoint appends :51820 when no port is given
	RelayEndpoint string   `json:"relayEndpoint,omitempty"` // endpoint used when the peer is relayed
	PublicKey     string   `json:"publicKey,omitempty"`
	ServerIP      string   `json:"serverIP,omitempty"`
	ServerPort    uint16   `json:"serverPort,omitempty"`
	RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access
	AllowedIps    []string `json:"allowedIps,omitempty"`    // optional, array of allowed IPs for the peer
	Aliases       []Alias  `json:"aliases,omitempty"`       // optional, array of alias configurations
}

// Alias pairs a human-readable name with its IP address.
type Alias struct {
	Alias        string `json:"alias"`        // the alias name
	AliasAddress string `json:"aliasAddress"` // the alias IP address
}

// PeerRemove represents the data needed to remove a peer.
// (The original comment named this "RemovePeer", which is the function in
// peer.go, not this type.)
type PeerRemove struct {
	SiteId int `json:"siteId"`
}

// RelayPeerData carries the information needed to switch a peer onto a relay.
type RelayPeerData struct {
	SiteId        int    `json:"siteId"`
	RelayEndpoint string `json:"relayEndpoint"`
}

// UnRelayPeerData carries the information needed to switch a peer back to a
// direct endpoint.
type UnRelayPeerData struct {
	SiteId   int    `json:"siteId"`
	Endpoint string `json:"endpoint"`
}

// PeerAdd represents the data needed to add remote subnets to a peer.
type PeerAdd struct {
	SiteId        int      `json:"siteId"`
	RemoteSubnets []string `json:"remoteSubnets"`     // subnets to add
	Aliases       []Alias  `json:"aliases,omitempty"` // aliases to add
}

// RemovePeerData represents the data needed to remove remote subnets from a peer.
type RemovePeerData struct {
	SiteId        int      `json:"siteId"`
	RemoteSubnets []string `json:"remoteSubnets"`     // subnets to remove
	Aliases       []Alias  `json:"aliases,omitempty"` // aliases to remove
}

// UpdatePeerData carries a before/after diff of a peer's remote subnets and
// aliases so the receiver can compute what to add and what to drop.
type UpdatePeerData struct {
	SiteId           int      `json:"siteId"`
	OldRemoteSubnets []string `json:"oldRemoteSubnets"`     // old list of remote subnets
	NewRemoteSubnets []string `json:"newRemoteSubnets"`     // new list of remote subnets
	OldAliases       []Alias  `json:"oldAliases,omitempty"` // old list of aliases
	NewAliases       []Alias  `json:"newAliases,omitempty"` // new list of aliases
}
|
||||
@@ -48,3 +48,7 @@ func setupWindowsEventLog() {
|
||||
func watchLogFile(end bool) error {
|
||||
return fmt.Errorf("watching log file is only available on Windows")
|
||||
}
|
||||
|
||||
func showServiceConfig() {
|
||||
fmt.Println("Service configuration is only available on Windows")
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -69,12 +70,6 @@ func loadServiceArgs() ([]string, error) {
|
||||
return nil, fmt.Errorf("failed to read service args: %v", err)
|
||||
}
|
||||
|
||||
// delete the file after reading
|
||||
err = os.Remove(argsPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete service args file: %v", err)
|
||||
}
|
||||
|
||||
var args []string
|
||||
err = json.Unmarshal(data, &args)
|
||||
if err != nil {
|
||||
@@ -95,7 +90,7 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
|
||||
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
||||
changes <- svc.Status{State: svc.StartPending}
|
||||
|
||||
s.elog.Info(1, "Service Execute called, starting main logic")
|
||||
s.elog.Info(1, fmt.Sprintf("Service Execute called with args: %v", args))
|
||||
|
||||
// Load saved service arguments
|
||||
savedArgs, err := loadServiceArgs()
|
||||
@@ -104,7 +99,24 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
|
||||
// Continue with empty args if loading fails
|
||||
savedArgs = []string{}
|
||||
}
|
||||
s.args = savedArgs
|
||||
|
||||
// Combine service start args with saved args, giving priority to service start args
|
||||
finalArgs := []string{}
|
||||
if len(args) > 0 {
|
||||
// Skip the first arg which is typically the service name
|
||||
if len(args) > 1 {
|
||||
finalArgs = append(finalArgs, args[1:]...)
|
||||
}
|
||||
s.elog.Info(1, fmt.Sprintf("Using service start parameters: %v", finalArgs))
|
||||
}
|
||||
|
||||
// If no service start parameters, use saved args
|
||||
if len(finalArgs) == 0 && len(savedArgs) > 0 {
|
||||
finalArgs = savedArgs
|
||||
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
|
||||
}
|
||||
|
||||
s.args = finalArgs
|
||||
|
||||
// Start the main olm functionality
|
||||
olmDone := make(chan struct{})
|
||||
@@ -151,6 +163,9 @@ func (s *olmService) runOlm() {
|
||||
// Create a context that can be cancelled when the service stops
|
||||
s.ctx, s.stop = context.WithCancel(context.Background())
|
||||
|
||||
// Create a separate context for programmatic shutdown (e.g., via API exit)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Setup logging for service mode
|
||||
s.elog.Info(1, "Starting Olm main logic")
|
||||
|
||||
@@ -165,7 +180,8 @@ func (s *olmService) runOlm() {
|
||||
}()
|
||||
|
||||
// Call the main olm function with stored arguments
|
||||
runOlmMainWithArgs(s.ctx, s.args)
|
||||
// Use s.ctx as the signal context since the service manages shutdown
|
||||
runOlmMainWithArgs(ctx, cancel, s.ctx, s.args)
|
||||
}()
|
||||
|
||||
// Wait for either context cancellation or main logic completion
|
||||
@@ -309,7 +325,7 @@ func removeService() error {
|
||||
}
|
||||
|
||||
func startService(args []string) error {
|
||||
// Save the service arguments before starting
|
||||
// Save the service arguments as backup
|
||||
if len(args) > 0 {
|
||||
err := saveServiceArgs(args)
|
||||
if err != nil {
|
||||
@@ -329,7 +345,8 @@ func startService(args []string) error {
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
err = s.Start()
|
||||
// Pass arguments directly to the service start call
|
||||
err = s.Start(args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %v", err)
|
||||
}
|
||||
@@ -379,17 +396,12 @@ func debugService(args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// fmt.Printf("Starting service in debug mode...\n")
|
||||
|
||||
// Start the service
|
||||
err := startService([]string{}) // Pass empty args since we already saved them
|
||||
// Start the service with the provided arguments
|
||||
err := startService(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start service: %v", err)
|
||||
}
|
||||
|
||||
// fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n")
|
||||
// fmt.Printf("================================================================================\n")
|
||||
|
||||
// Watch the log file
|
||||
return watchLogFile(true)
|
||||
}
|
||||
@@ -509,11 +521,89 @@ func getServiceStatus() (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// showServiceConfig displays current saved service configuration
|
||||
func showServiceConfig() {
|
||||
configPath := getServiceArgsPath()
|
||||
fmt.Printf("Service configuration file: %s\n", configPath)
|
||||
|
||||
args, err := loadServiceArgs()
|
||||
if err != nil {
|
||||
fmt.Printf("No saved configuration found or error loading: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
fmt.Println("No saved service arguments found")
|
||||
} else {
|
||||
fmt.Printf("Saved service arguments: %v\n", args)
|
||||
}
|
||||
}
|
||||
|
||||
func isWindowsService() bool {
|
||||
isWindowsService, err := svc.IsWindowsService()
|
||||
return err == nil && isWindowsService
|
||||
}
|
||||
|
||||
// rotateLogFile handles daily log rotation: if logFile's last modification
// was on a previous calendar day, the file is renamed to a dated file
// (olm-YYYY-MM-DD.log, dated by its last write) inside logDir, and rotated
// logs older than 30 days are pruned. A missing log file is not an error.
func rotateLogFile(logDir string, logFile string) error {
	// Get current log file info
	info, err := os.Stat(logFile)
	if err != nil {
		if os.IsNotExist(err) {
			return nil // No current log file to rotate
		}
		return fmt.Errorf("failed to stat log file: %v", err)
	}

	// Check if log file is from today
	now := time.Now()
	fileTime := info.ModTime()

	// If the log file is from today (same year and day-of-year), no rotation needed
	if now.Year() == fileTime.Year() && now.YearDay() == fileTime.YearDay() {
		return nil
	}

	// Create rotated filename with the date of the file's last write
	rotatedName := fmt.Sprintf("olm-%s.log", fileTime.Format("2006-01-02"))
	rotatedPath := filepath.Join(logDir, rotatedName)

	// Rename current log file to dated filename
	err = os.Rename(logFile, rotatedPath)
	if err != nil {
		return fmt.Errorf("failed to rotate log file: %v", err)
	}

	// Clean up old log files (keep last 30 days)
	cleanupOldLogFiles(logDir, 30)

	return nil
}
|
||||
|
||||
// cleanupOldLogFiles removes log files older than specified days
|
||||
func cleanupOldLogFiles(logDir string, daysToKeep int) {
|
||||
cutoff := time.Now().AddDate(0, 0, -daysToKeep)
|
||||
|
||||
files, err := os.ReadDir(logDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if !file.IsDir() && strings.HasPrefix(file.Name(), "olm-") && strings.HasSuffix(file.Name(), ".log") {
|
||||
filePath := filepath.Join(logDir, file.Name())
|
||||
info, err := file.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.ModTime().Before(cutoff) {
|
||||
os.Remove(filePath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupWindowsEventLog() {
|
||||
// Create log directory if it doesn't exist
|
||||
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
||||
@@ -524,6 +614,14 @@ func setupWindowsEventLog() {
|
||||
}
|
||||
|
||||
logFile := filepath.Join(logDir, "olm.log")
|
||||
|
||||
// Rotate log file if needed
|
||||
err = rotateLogFile(logDir, logFile)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to rotate log file: %v\n", err)
|
||||
// Continue anyway to create new log file
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to open log file: %v\n", err)
|
||||
|
||||
35
unix.go
35
unix.go
@@ -1,35 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
|
||||
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(int(fd), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
return tun.CreateTUNFromFile(file, mtuInt)
|
||||
}
|
||||
// uapiOpen opens the UAPI socket file for the named WireGuard interface.
// Non-Windows implementation; delegates to wireguard's ipc package.
func uapiOpen(interfaceName string) (*os.File, error) {
	return ipc.UAPIOpen(interfaceName)
}
|
||||
|
||||
// uapiListen starts listening on the UAPI socket previously opened with
// uapiOpen. Non-Windows implementation; delegates to wireguard's ipc package.
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
	return ipc.UAPIListen(interfaceName, fileUAPI)
}
|
||||
715
websocket/client.go
Normal file
715
websocket/client.go
Normal file
@@ -0,0 +1,715 @@
|
||||
package websocket
|
||||
|
||||
import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"software.sslmate.com/src/go-pkcs12"

	"github.com/fosrl/newt/logger"
	"github.com/gorilla/websocket"
)
|
||||
|
||||
// AuthError represents an authentication/authorization error (401/403)
|
||||
type AuthError struct {
|
||||
StatusCode int
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *AuthError) Error() string {
|
||||
return fmt.Sprintf("authentication error (status %d): %s", e.StatusCode, e.Message)
|
||||
}
|
||||
|
||||
// IsAuthError checks if an error is an authentication error
|
||||
func IsAuthError(err error) bool {
|
||||
_, ok := err.(*AuthError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// TokenResponse is the JSON envelope returned by the get-token endpoint.
type TokenResponse struct {
	Data struct {
		Token     string     `json:"token"`     // short-lived websocket token
		ExitNodes []ExitNode `json:"exitNodes"` // exit nodes advertised for this client
	} `json:"data"`
	Success bool   `json:"success"`
	Message string `json:"message"` // server-side error description when Success is false
}

// ExitNode describes a WireGuard exit node advertised by the server.
type ExitNode struct {
	Endpoint  string `json:"endpoint"`
	PublicKey string `json:"publicKey"`
}

// WSMessage is the framing used for every websocket message: a type tag
// plus an arbitrary JSON payload dispatched by RegisterHandler handlers.
type WSMessage struct {
	Type string      `json:"type"`
	Data interface{} `json:"data"`
}

// Config holds the client credentials and connection settings.
// (this is not json anymore)
type Config struct {
	ID            string
	Secret        string
	Endpoint      string
	TlsClientCert string // legacy PKCS12 file path
	UserToken     string // optional user token for websocket authentication
	OrgID         string // optional organization ID for websocket authentication
}

// Client is a self-reconnecting websocket client for the management API.
type Client struct {
	config            *Config
	conn              *websocket.Conn
	baseURL           string
	handlers          map[string]MessageHandler // message type -> handler
	done              chan struct{}             // closed once on shutdown
	handlersMux       sync.RWMutex              // guards handlers
	reconnectInterval time.Duration
	isConnected       bool
	reconnectMux      sync.RWMutex // guards isConnected
	pingInterval      time.Duration
	pingTimeout       time.Duration
	onConnect         func() error
	onTokenUpdate     func(token string, exitNodes []ExitNode)
	onAuthError       func(statusCode int, message string) // Callback for auth errors
	writeMux          sync.Mutex                           // serializes all writes to conn
	clientType        string                               // Type of client (e.g., "newt", "olm")
	tlsConfig         TLSConfig
	configNeedsSave   bool // Flag to track if config needs to be saved
}

// ClientOption mutates a Client during construction (functional options).
type ClientOption func(*Client)

// MessageHandler processes one incoming websocket message.
type MessageHandler func(message WSMessage)

// TLSConfig holds TLS configuration options
type TLSConfig struct {
	// New separate certificate support
	ClientCertFile string
	ClientKeyFile  string
	CAFiles        []string

	// Existing PKCS12 support (deprecated)
	PKCS12File string
}
|
||||
|
||||
// WithBaseURL sets the base URL for the client
|
||||
func WithBaseURL(url string) ClientOption {
|
||||
return func(c *Client) {
|
||||
c.baseURL = url
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig sets the TLS configuration for the client
|
||||
func WithTLSConfig(config TLSConfig) ClientOption {
|
||||
return func(c *Client) {
|
||||
c.tlsConfig = config
|
||||
// For backward compatibility, also set the legacy field
|
||||
if config.PKCS12File != "" {
|
||||
c.config.TlsClientCert = config.PKCS12File
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnConnect registers a callback invoked after each successful
// (re)connection; a returned error is logged but does not tear down the
// connection.
func (c *Client) OnConnect(callback func() error) {
	c.onConnect = callback
}
|
||||
|
||||
// OnTokenUpdate registers a callback invoked whenever a fresh session
// token (and its exit-node list) is obtained during connection setup.
func (c *Client) OnTokenUpdate(callback func(token string, exitNodes []ExitNode)) {
	c.onTokenUpdate = callback
}
|
||||
|
||||
// OnAuthError registers a callback invoked when the token request fails
// with 401/403; the reconnect loop still keeps retrying afterwards.
func (c *Client) OnAuthError(callback func(statusCode int, message string)) {
	c.onAuthError = callback
}
|
||||
|
||||
// NewClient creates a new websocket client
|
||||
func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||
config := &Config{
|
||||
ID: ID,
|
||||
Secret: secret,
|
||||
Endpoint: endpoint,
|
||||
UserToken: userToken,
|
||||
OrgID: orgId,
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
config: config,
|
||||
baseURL: endpoint, // default value
|
||||
handlers: make(map[string]MessageHandler),
|
||||
done: make(chan struct{}),
|
||||
reconnectInterval: 3 * time.Second,
|
||||
isConnected: false,
|
||||
pingInterval: pingInterval,
|
||||
pingTimeout: pingTimeout,
|
||||
clientType: "olm",
|
||||
}
|
||||
|
||||
// Apply options before loading config
|
||||
for _, opt := range opts {
|
||||
if opt == nil {
|
||||
continue
|
||||
}
|
||||
opt(client)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetConfig returns the client's configuration. Callers receive the live
// *Config, not a copy, so mutations are visible to the client.
func (c *Client) GetConfig() *Config {
	return c.config
}
|
||||
|
||||
// Connect starts establishing the WebSocket connection in the background.
// It returns immediately and always nil; connection failures are retried
// internally by connectWithRetry and never surface through this call.
func (c *Client) Connect() error {
	go c.connectWithRetry()
	return nil
}
|
||||
|
||||
// Close closes the WebSocket connection gracefully: it signals every
// background goroutine to stop via the done channel, marks the client
// disconnected, sends a best-effort close frame, and closes the socket.
// A second sequential call returns nil immediately.
// NOTE(review): two *concurrent* Close calls could both reach
// close(c.done) and panic — confirm callers only close once.
func (c *Client) Close() error {
	// Signal shutdown to all goroutines first; the non-blocking receive
	// distinguishes "already closed" from "first close".
	select {
	case <-c.done:
		// Already closed
		return nil
	default:
		close(c.done)
	}

	// Set connection status to false
	c.setConnected(false)

	// Close the WebSocket connection gracefully
	if c.conn != nil {
		// Send close message under the write lock; the error is
		// intentionally ignored (best effort).
		c.writeMux.Lock()
		c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
		c.writeMux.Unlock()

		// Close the connection
		return c.conn.Close()
	}

	return nil
}
|
||||
|
||||
// SendMessage sends a message through the WebSocket connection
|
||||
func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
msg := WSMessage{
|
||||
Type: messageType,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
logger.Debug("Sending message: %s, data: %+v", messageType, data)
|
||||
|
||||
c.writeMux.Lock()
|
||||
defer c.writeMux.Unlock()
|
||||
return c.conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
// SendMessageInterval sends a message immediately and then repeats it on
// the given interval, giving up after 10 sends (initial send included).
// It returns two functions: stop() cancels the loop, and update(newData)
// changes the payload for subsequent sends — map payloads are merged
// key-by-key, any other payload type is replaced wholesale.
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
	stopChan := make(chan struct{})
	updateChan := make(chan interface{})
	var dataMux sync.Mutex
	currentData := data

	go func() {
		count := 0
		maxAttempts := 10

		err := c.SendMessage(messageType, currentData) // Send immediately
		if err != nil {
			logger.Error("Failed to send initial message: %v", err)
		}
		count++

		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// Stop resending once the attempt budget is exhausted.
				if count >= maxAttempts {
					logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
					return
				}
				dataMux.Lock()
				err = c.SendMessage(messageType, currentData)
				dataMux.Unlock()
				if err != nil {
					logger.Error("Failed to send message: %v", err)
				}
				count++
			case newData := <-updateChan:
				dataMux.Lock()
				// Merge newData into currentData if both are maps
				if currentMap, ok := currentData.(map[string]interface{}); ok {
					if newMap, ok := newData.(map[string]interface{}); ok {
						// Update or add keys from newData
						for key, value := range newMap {
							currentMap[key] = value
						}
						currentData = currentMap
					} else {
						// If newData is not a map, replace entirely
						currentData = newData
					}
				} else {
					// If currentData is not a map, replace entirely
					currentData = newData
				}
				dataMux.Unlock()
			case <-stopChan:
				return
			}
		}
	}()
	// stop closes stopChan; update hands the new payload to the goroutine
	// unless the loop has already been stopped.
	return func() {
			close(stopChan)
		}, func(newData interface{}) {
			select {
			case updateChan <- newData:
			case <-stopChan:
				// Channel is closed, ignore update
			}
		}
}
|
||||
|
||||
// RegisterHandler registers a handler for a specific message type
|
||||
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
||||
c.handlersMux.Lock()
|
||||
defer c.handlersMux.Unlock()
|
||||
c.handlers[messageType] = handler
|
||||
}
|
||||
|
||||
func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
// Parse the base URL to ensure we have the correct hostname
|
||||
baseURL, err := url.Parse(c.baseURL)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to parse base URL: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we have the base URL without trailing slashes
|
||||
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
||||
|
||||
var tlsConfig *tls.Config = nil
|
||||
|
||||
// Use new TLS configuration method
|
||||
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||
tlsConfig, err = c.setupTLS()
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for environment variable to skip TLS verification
|
||||
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = &tls.Config{}
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
tokenData := map[string]interface{}{
|
||||
"olmId": c.config.ID,
|
||||
"secret": c.config.Secret,
|
||||
"orgId": c.config.OrgID,
|
||||
}
|
||||
jsonData, err := json.Marshal(tokenData)
|
||||
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to marshal token request data: %w", err)
|
||||
}
|
||||
|
||||
// Create a new request
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
|
||||
bytes.NewBuffer(jsonData),
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
||||
|
||||
// Make the request
|
||||
client := &http.Client{}
|
||||
if tlsConfig != nil {
|
||||
client.Transport = &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to request new token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
|
||||
// Return AuthError for 401/403 status codes
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return "", nil, &AuthError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Message: string(body),
|
||||
}
|
||||
}
|
||||
|
||||
// For other errors (5xx, network issues, etc.), return regular error
|
||||
return "", nil, fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
logger.Error("Failed to decode token response.")
|
||||
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if !tokenResp.Success {
|
||||
return "", nil, fmt.Errorf("failed to get token: %s", tokenResp.Message)
|
||||
}
|
||||
|
||||
if tokenResp.Data.Token == "" {
|
||||
return "", nil, fmt.Errorf("received empty token from server")
|
||||
}
|
||||
|
||||
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
||||
|
||||
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
|
||||
}
|
||||
|
||||
// connectWithRetry attempts to establish the connection in a loop,
// sleeping reconnectInterval between failures, until either a connection
// succeeds or the client is shut down (c.done closed). Auth failures
// (401/403) additionally fire the onAuthError callback before the retry
// sleep; retrying continues in both cases.
func (c *Client) connectWithRetry() {
	for {
		select {
		case <-c.done:
			return
		default:
			err := c.establishConnection()
			if err != nil {
				// Check if this is an auth error (401/403)
				if authErr, ok := err.(*AuthError); ok {
					logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
					// Trigger auth error callback if set (this should terminate the tunnel)
					if c.onAuthError != nil {
						c.onAuthError(authErr.StatusCode, authErr.Message)
					}
					// Continue retrying after auth error
					time.Sleep(c.reconnectInterval)
					continue
				}
				// For other errors (5xx, network issues), continue retrying
				logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
				time.Sleep(c.reconnectInterval)
				continue
			}
			return
		}
	}
}
|
||||
|
||||
func (c *Client) establishConnection() error {
|
||||
// Get token for authentication
|
||||
token, exitNodes, err := c.getToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
if c.onTokenUpdate != nil {
|
||||
c.onTokenUpdate(token, exitNodes)
|
||||
}
|
||||
|
||||
// Parse the base URL to determine protocol and hostname
|
||||
baseURL, err := url.Parse(c.baseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse base URL: %w", err)
|
||||
}
|
||||
|
||||
// Determine WebSocket protocol based on HTTP protocol
|
||||
wsProtocol := "wss"
|
||||
if baseURL.Scheme == "http" {
|
||||
wsProtocol = "ws"
|
||||
}
|
||||
|
||||
// Create WebSocket URL
|
||||
wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
|
||||
u, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse WebSocket URL: %w", err)
|
||||
}
|
||||
|
||||
// Add token to query parameters
|
||||
q := u.Query()
|
||||
q.Set("token", token)
|
||||
q.Set("clientType", c.clientType)
|
||||
if c.config.UserToken != "" {
|
||||
q.Set("userToken", c.config.UserToken)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
// Connect to WebSocket
|
||||
dialer := websocket.DefaultDialer
|
||||
|
||||
// Use new TLS configuration method
|
||||
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||
logger.Info("Setting up TLS configuration for WebSocket connection")
|
||||
tlsConfig, err := c.setupTLS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||
}
|
||||
dialer.TLSClientConfig = tlsConfig
|
||||
}
|
||||
|
||||
// Check for environment variable to skip TLS verification for WebSocket connection
|
||||
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||
if dialer.TLSClientConfig == nil {
|
||||
dialer.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
dialer.TLSClientConfig.InsecureSkipVerify = true
|
||||
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
conn, _, err := dialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
c.setConnected(true)
|
||||
|
||||
// Start the ping monitor
|
||||
go c.pingMonitor()
|
||||
// Start the read pump with disconnect detection
|
||||
go c.readPumpWithDisconnectDetection()
|
||||
|
||||
if c.onConnect != nil {
|
||||
if err := c.onConnect(); err != nil {
|
||||
logger.Error("OnConnect callback failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupTLS builds a *tls.Config from the client's TLS settings.
// Precedence: separate cert/key files first, then the PKCS12 file from
// tlsConfig, then the legacy Config.TlsClientCert path. Returns (nil, nil)
// when nothing is configured.
// NOTE(review): CAFiles are only honored inside the separate-cert branch —
// a configuration with CAFiles but no client cert/key pair silently
// ignores them; confirm that is intended.
func (c *Client) setupTLS() (*tls.Config, error) {
	tlsConfig := &tls.Config{}

	// Handle new separate certificate configuration
	if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
		logger.Info("Loading separate certificate files for mTLS")
		logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
		logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)

		// Load client certificate and key
		cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
		if err != nil {
			return nil, fmt.Errorf("failed to load client certificate pair: %w", err)
		}
		tlsConfig.Certificates = []tls.Certificate{cert}

		// Load CA certificates for remote validation if specified
		if len(c.tlsConfig.CAFiles) > 0 {
			logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
			caCertPool := x509.NewCertPool()
			for _, caFile := range c.tlsConfig.CAFiles {
				caCert, err := os.ReadFile(caFile)
				if err != nil {
					return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err)
				}

				// Try to parse as PEM first, then DER
				if !caCertPool.AppendCertsFromPEM(caCert) {
					// If PEM parsing failed, try DER
					cert, err := x509.ParseCertificate(caCert)
					if err != nil {
						return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err)
					}
					caCertPool.AddCert(cert)
				}
			}
			tlsConfig.RootCAs = caCertPool
		}

		return tlsConfig, nil
	}

	// Fallback to existing PKCS12 implementation for backward compatibility
	if c.tlsConfig.PKCS12File != "" {
		logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
		return c.setupPKCS12TLS()
	}

	// Legacy fallback using config.TlsClientCert
	if c.config.TlsClientCert != "" {
		logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
		return loadClientCertificate(c.config.TlsClientCert)
	}

	return nil, nil
}
|
||||
|
||||
// setupPKCS12TLS loads TLS configuration from the PKCS12 file named in
// tlsConfig.PKCS12File (deprecated path; see setupTLS for precedence).
func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
	return loadClientCertificate(c.tlsConfig.PKCS12File)
}
|
||||
|
||||
// pingMonitor sends a websocket ping every pingInterval and forces a
// reconnect when a ping cannot be written within pingTimeout. It exits on
// shutdown (c.done), when the connection is gone, or after triggering the
// reconnect.
func (c *Client) pingMonitor() {
	ticker := time.NewTicker(c.pingInterval)
	defer ticker.Stop()

	for {
		select {
		case <-c.done:
			return
		case <-ticker.C:
			// NOTE(review): c.conn is read without a lock here; a
			// concurrent reconnect may nil it between check and use.
			if c.conn == nil {
				return
			}
			c.writeMux.Lock()
			err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
			c.writeMux.Unlock()
			if err != nil {
				// Check if we're shutting down before logging error and reconnecting
				select {
				case <-c.done:
					// Expected during shutdown
					return
				default:
					logger.Error("Ping failed: %v", err)
					c.reconnect()
					return
				}
			}
		}
	}
}
|
||||
|
||||
// readPumpWithDisconnectDetection reads messages in a loop, dispatching
// each to the handler registered for its type. On any read error the loop
// exits and the deferred cleanup closes the connection and — unless the
// client is shutting down — schedules a reconnect.
func (c *Client) readPumpWithDisconnectDetection() {
	defer func() {
		if c.conn != nil {
			c.conn.Close()
		}
		// Only attempt reconnect if we're not shutting down
		select {
		case <-c.done:
			// Shutting down, don't reconnect
			return
		default:
			c.reconnect()
		}
	}()

	for {
		select {
		case <-c.done:
			return
		default:
			var msg WSMessage
			err := c.conn.ReadJSON(&msg)
			if err != nil {
				// Check if we're shutting down before logging error
				select {
				case <-c.done:
					// Expected during shutdown, don't log as error
					logger.Debug("WebSocket connection closed during shutdown")
					return
				default:
					// Unexpected error during normal operation: only close
					// codes outside the expected set are logged as errors.
					if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
						logger.Error("WebSocket read error: %v", err)
					} else {
						logger.Debug("WebSocket connection closed: %v", err)
					}
					return // triggers reconnect via defer
				}
			}

			// Dispatch under the read lock; unknown types are dropped.
			c.handlersMux.RLock()
			if handler, ok := c.handlers[msg.Type]; ok {
				handler(msg)
			}
			c.handlersMux.RUnlock()
		}
	}
}
|
||||
|
||||
// reconnect tears down the current connection and, unless the client is
// shutting down, restarts the retry loop in a new goroutine.
// NOTE(review): c.conn is written without a lock while other goroutines
// may still read it — confirm this race is acceptable in practice.
func (c *Client) reconnect() {
	c.setConnected(false)
	if c.conn != nil {
		c.conn.Close()
		c.conn = nil
	}

	// Only reconnect if we're not shutting down
	select {
	case <-c.done:
		return
	default:
		go c.connectWithRetry()
	}
}
|
||||
|
||||
func (c *Client) setConnected(status bool) {
|
||||
c.reconnectMux.Lock()
|
||||
defer c.reconnectMux.Unlock()
|
||||
c.isConnected = status
|
||||
}
|
||||
|
||||
// loadClientCertificate builds a *tls.Config from a PKCS12 (.p12) bundle
// at p12Path. The bundle is decoded with an empty password, so only
// non-encrypted (or empty-password) files are supported. CA certificates
// found in the bundle are appended to the system root pool.
func loadClientCertificate(p12Path string) (*tls.Config, error) {
	logger.Info("Loading tls-client-cert %s", p12Path)
	// Read the PKCS12 file
	p12Data, err := os.ReadFile(p12Path)
	if err != nil {
		return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
	}

	// Parse PKCS12 with empty password for non-encrypted files
	privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
	if err != nil {
		return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
	}

	// Create certificate
	cert := tls.Certificate{
		Certificate: [][]byte{certificate.Raw},
		PrivateKey:  privateKey,
	}

	// Optional: Add CA certificates if present
	rootCAs, err := x509.SystemCertPool()
	if err != nil {
		return nil, fmt.Errorf("failed to load system cert pool: %w", err)
	}
	if len(caCerts) > 0 {
		for _, caCert := range caCerts {
			rootCAs.AddCert(caCert)
		}
	}

	// Create TLS configuration
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		RootCAs:      rootCAs,
	}, nil
}
|
||||
Reference in New Issue
Block a user