Compare commits

...

88 Commits
v1.3.0 ... dev

Author SHA1 Message Date
Owen
af973b2440 Support prt records 2026-02-06 15:17:01 -08:00
Owen
dd9bff9a4b Fix peer names clearing 2026-02-02 18:03:29 -08:00
Owen
1be5e454ba Default override dns to true
Ref #59
2026-02-02 10:03:22 -08:00
Owen
4a25a0d413 Dont go unregistered when low power mode
Former-commit-id: 0938564038
2026-01-31 16:58:05 -08:00
Owen
7fc3c7088e Lowercase all domains before matching
Former-commit-id: 8f8872aa47
2026-01-30 14:53:25 -08:00
Owen
1869e70894 Merge branch 'dev'
Former-commit-id: 43cc56a961
2026-01-30 10:58:00 -08:00
Owen
79783cc3dc Merge branch 'main' of github.com:fosrl/olm
Former-commit-id: 0b31f4e5d1
2026-01-30 10:57:40 -08:00
Owen
584298e3bd Fix terminate due to inactivity 2026-01-27 20:19:41 -08:00
miloschwartz
f683afa647 improve override-dns and tunnel-dns descriptions 2026-01-27 17:53:34 -08:00
Owen
ba2631d388 Prevent crashing on close before connect
Former-commit-id: ea461e0bfb
2026-01-23 14:47:54 -08:00
Owen Schwartz
6ae4e2b691 Merge pull request #87 from fosrl/dev
1.4.0

Former-commit-id: 1212217421
2026-01-23 10:25:03 -08:00
Owen
51eee9dcf5 Bump newt
Former-commit-id: f4885e9c4d
2026-01-23 10:23:42 -08:00
Owen
660e9e0e35 Merge branch 'main' into dev
Former-commit-id: b5580036d3
2026-01-23 10:22:21 -08:00
Owen
4ef6089053 Comment out local newt
Former-commit-id: c4ef1e724e
2026-01-23 10:19:38 -08:00
Owen
c4e297cc96 Handle properly stopping and starting the ping
Former-commit-id: 34c7717767
2026-01-20 11:30:06 -08:00
Owen
e3f5497176 Add stale bot
Former-commit-id: 313dee9ba8
2026-01-19 17:12:15 -08:00
dependabot[bot]
6a5dcc01a6 Bump actions/checkout from 5.0.0 to 6.0.1
Bumps [actions/checkout](https://github.com/actions/checkout) from 5.0.0 to 6.0.1.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](08c6903cd8...8e8c483db8)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: 6.0.1
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: e19b33e2fa
2026-01-19 17:08:10 -08:00
dependabot[bot]
18b6d3bb0f Bump the patch-updates group across 1 directory with 3 updates
Bumps the patch-updates group with 3 updates in the / directory: [github.com/fosrl/newt](https://github.com/fosrl/newt), [github.com/godbus/dbus/v5](https://github.com/godbus/dbus) and [github.com/miekg/dns](https://github.com/miekg/dns).


Updates `github.com/fosrl/newt` from 1.8.0 to 1.8.1
- [Release notes](https://github.com/fosrl/newt/releases)
- [Commits](https://github.com/fosrl/newt/compare/1.8.0...1.8.1)

Updates `github.com/godbus/dbus/v5` from 5.2.0 to 5.2.2
- [Release notes](https://github.com/godbus/dbus/releases)
- [Commits](https://github.com/godbus/dbus/compare/v5.2.0...v5.2.2)

Updates `github.com/miekg/dns` from 1.1.68 to 1.1.70
- [Commits](https://github.com/miekg/dns/compare/v1.1.68...v1.1.70)

---
updated-dependencies:
- dependency-name: github.com/fosrl/newt
  dependency-version: 1.8.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: patch-updates
- dependency-name: github.com/godbus/dbus/v5
  dependency-version: 5.2.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: patch-updates
- dependency-name: github.com/miekg/dns
  dependency-version: 1.1.70
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: patch-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: 69f25032cb
2026-01-19 17:08:00 -08:00
dependabot[bot]
ccbfdc5265 Bump docker/metadata-action from 5.9.0 to 5.10.0
Bumps [docker/metadata-action](https://github.com/docker/metadata-action) from 5.9.0 to 5.10.0.
- [Release notes](https://github.com/docker/metadata-action/releases)
- [Commits](318604b99e...c299e40c65)

---
updated-dependencies:
- dependency-name: docker/metadata-action
  dependency-version: 5.10.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: 225779c665
2026-01-19 17:06:44 -08:00
dependabot[bot]
ab04537278 Bump softprops/action-gh-release from 2.4.2 to 2.5.0
Bumps [softprops/action-gh-release](https://github.com/softprops/action-gh-release) from 2.4.2 to 2.5.0.
- [Release notes](https://github.com/softprops/action-gh-release/releases)
- [Changelog](https://github.com/softprops/action-gh-release/blob/master/CHANGELOG.md)
- [Commits](5be0e66d93...a06a81a03e)

---
updated-dependencies:
- dependency-name: softprops/action-gh-release
  dependency-version: 2.5.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: a7f029e232
2026-01-19 17:06:23 -08:00
dependabot[bot]
29c36c9837 Bump docker/setup-buildx-action from 3.11.1 to 3.12.0
Bumps [docker/setup-buildx-action](https://github.com/docker/setup-buildx-action) from 3.11.1 to 3.12.0.
- [Release notes](https://github.com/docker/setup-buildx-action/releases)
- [Commits](e468171a9d...8d2750c68a)

---
updated-dependencies:
- dependency-name: docker/setup-buildx-action
  dependency-version: 3.12.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: af4e74de81
2026-01-19 17:05:50 -08:00
dependabot[bot]
c47e9bf547 Bump actions/cache from 4.3.0 to 5.0.2
Bumps [actions/cache](https://github.com/actions/cache) from 4.3.0 to 5.0.2.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](0057852bfa...8b402f58fb)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-version: 5.0.2
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: f87d043d59
2026-01-19 17:05:42 -08:00
dependabot[bot]
abb682c935 Bump the minor-updates group across 1 directory with 2 updates
Bumps the minor-updates group with 2 updates in the / directory: [golang.org/x/sys](https://github.com/golang/sys) and software.sslmate.com/src/go-pkcs12.


Updates `golang.org/x/sys` from 0.38.0 to 0.40.0
- [Commits](https://github.com/golang/sys/compare/v0.38.0...v0.40.0)

Updates `software.sslmate.com/src/go-pkcs12` from 0.6.0 to 0.7.0

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-version: 0.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-updates
- dependency-name: software.sslmate.com/src/go-pkcs12
  dependency-version: 0.7.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: ae1436c5d1
2026-01-19 17:05:19 -08:00
Owen
79e8a4a8bb Dont start holepunching if we rebind while in low power mode
Former-commit-id: 4a5ebd41f3
2026-01-19 15:57:20 -08:00
Owen
f2e81c024a Set fingerprint earlier
Former-commit-id: ef36f7ca82
2026-01-19 15:05:29 -08:00
Owen
6d10650e70 Send an initial ping so we get online faster in the dashboard
Former-commit-id: 41e4eb24a2
2026-01-18 15:14:11 -08:00
Owen
a81c683c66 Reorder websocket disconnect message
Former-commit-id: 592a0d60c6
2026-01-18 14:49:42 -08:00
Owen
25cb50901e Quiet up logs again
Former-commit-id: 112283191c
2026-01-18 12:18:48 -08:00
Owen
a8e0844758 Send disconnecting message when stopping
Former-commit-id: 1fb6e2a00d
2026-01-18 11:55:09 -08:00
Owen
8b9ee6f26a Move power mode to the api from signal
Former-commit-id: 5d8ea92ef0
2026-01-18 11:46:18 -08:00
Owen
82e8fcc3a7 Merge branch 'bubble-errors-up' into dev
Former-commit-id: 61846f9ec4
2026-01-18 11:38:20 -08:00
Owen
e2b7777ba7 Merge branch 'rebind' into dev
Former-commit-id: 2139aeaa85
2026-01-18 11:37:43 -08:00
Owen
4e4d1a39f6 Try to close the socket first
Former-commit-id: ed4775bd26
2026-01-17 17:35:00 -08:00
Owen
17dc1b0be1 Dont start the ping until we are connected
Former-commit-id: 43c8a14fda
2026-01-17 17:32:01 -08:00
Owen
a06436eeab Add rebind endpoints for the shared socket
Former-commit-id: 6fd0984b13
2026-01-17 17:05:29 -08:00
Lokowitz
a83cc2a3a3 clean up dependabot
Former-commit-id: a37f0514c4
2026-01-17 15:43:06 -08:00
Lokowitz
d56537d0fd add docker build dev
Former-commit-id: b983216808
2026-01-17 15:43:06 -08:00
Lokowitz
31bb483e40 add qemu
Former-commit-id: 172eb97aa1
2026-01-17 15:43:06 -08:00
Lokowitz
cd91ae6e3a update test
Former-commit-id: b034f81ed9
2026-01-17 15:43:06 -08:00
Lokowitz
a9ec1e61d3 fix test
Former-commit-id: 076d01b48c
2026-01-17 15:43:06 -08:00
Owen
a13010c4af Update docs for metadata
Former-commit-id: 9d77a1daf7
2026-01-16 17:33:40 -08:00
Owen
cfac3cdd53 Use the right duration
Former-commit-id: c921f08bd5
2026-01-16 15:17:41 -08:00
Owen
5ecba61718 Use the right duration
Former-commit-id: 352b122166
2026-01-16 15:17:20 -08:00
Owen
2ea12ce258 Set the error on terminate as well
Former-commit-id: 8ff58e6efc
2026-01-16 14:59:13 -08:00
Owen
0b46289136 Add error can be sent from cloud to display in api
Former-commit-id: 2167f22713
2026-01-16 14:19:02 -08:00
Owen
71044165d0 Include fingerprint and posture info in ping
Former-commit-id: f061596e5b
2026-01-16 12:16:51 -08:00
Owen
eafd816159 Clean up log messages
Former-commit-id: 0231591f36
2026-01-16 12:02:02 -08:00
Owen
e1a687407e Set the ping inteval to 30 seconds
Former-commit-id: 737ffca15d
2026-01-15 21:59:18 -08:00
Owen
bd8031651e Message syncing works
Former-commit-id: 1650624a55
2026-01-15 21:25:53 -08:00
Owen
a63439543d Merge branch 'dev' into msg-delivery
Former-commit-id: d6b9170e79
2026-01-15 16:41:00 -08:00
Owen
90cd6e7f6e Merge branch 'power-state' into dev
Former-commit-id: e2a071e6dc
2026-01-15 16:39:41 -08:00
Owen
ea4a63c9b3 Merge branch 'dev' of github.com:fosrl/olm into dev
Former-commit-id: 1c21071ee1
2026-01-15 16:37:09 -08:00
Owen
e047330ffd Handle and test config version bugs
Former-commit-id: 285f8ce530
2026-01-15 16:36:11 -08:00
Owen
9dcc0796a6 Small clean up and move ping to client.go
Former-commit-id: af33218792
2026-01-15 14:20:12 -08:00
Varun Narravula
4b6999e06a feat(ping): send fingerprint and posture checks as part of ping/register
Former-commit-id: 70a7e83291
2026-01-15 12:13:36 -08:00
Varun Narravula
69952ee5c5 feat(api): add fingerprint + posture fields to client state
Former-commit-id: 566084683a
2026-01-15 12:13:36 -08:00
Owen
3710880ce0 Merge branch 'power-state' into msg-delivery
Former-commit-id: bda6606098
2026-01-14 17:51:42 -08:00
Owen
17b75bf58f Dont get token each time
Former-commit-id: 07dfc651f1
2026-01-14 16:51:04 -08:00
Owen
3ba1714524 Power state getting set correctly
Former-commit-id: 0895156efd
2026-01-14 16:38:40 -08:00
Owen
3470da76fc Update resetting intervals
Former-commit-id: 303c2dc0b7
2026-01-14 12:32:29 -08:00
Owen
c86df2c041 Refactor operation
Former-commit-id: 4f09d122bb
2026-01-14 11:58:12 -08:00
Owen
0e8315b149 Merge branch 'dev' into power-state
Former-commit-id: e9728efee3
2026-01-14 11:19:46 -08:00
Owen
2ab9790588 Reduce the pings
Former-commit-id: 5c6ad1ea75
2026-01-14 11:12:10 -08:00
Owen
1ecb97306f Add back AddDevice function
Former-commit-id: cae0ffa2e1
2026-01-13 21:38:37 -08:00
Varun Narravula
15e96a779c refactor(olm): convert global state into an olm instance
Former-commit-id: b755f77d95
2026-01-13 20:52:10 -08:00
miloschwartz
dada0cc124 add low power state for testing
Former-commit-id: 996fe59999
2026-01-13 14:30:02 -08:00
Owen
9c0b4fcd5f Fix error checking
Former-commit-id: 231808476b
2026-01-13 11:51:51 -08:00
Owen
8a788ef238 Merge branch 'dev' of github.com:fosrl/olm into dev
Former-commit-id: 8c5c8d3966
2026-01-12 17:12:45 -08:00
Owen
20e0c18845 Try to reduce cpu when idle
Former-commit-id: ba91478b89
2026-01-12 12:29:42 -08:00
Owen
5b637bb4ca Add expo backoff
Former-commit-id: faae551aca
2026-01-12 12:20:59 -08:00
Varun Narravula
c565a46a6f feat(logger): configure log file path thorugh global options
Former-commit-id: 577d89f4fb
2026-01-11 13:49:39 -08:00
Varun Narravula
7b7eae617a chore: format files using gofmt
Former-commit-id: 5cfa0dfb97
2026-01-11 13:49:39 -08:00
miloschwartz
1ed27fec1a set mtu to 0 on darwin
Former-commit-id: fbe686961e
2026-01-01 17:38:01 -05:00
Owen
83edde3449 Fix build on darwin
Former-commit-id: fbeb5be88d
2025-12-31 18:01:25 -05:00
Owen
1b43f029a9 Dont pass in dns proxy to override
Former-commit-id: 51dd927f9b
2025-12-31 15:42:51 -05:00
Owen
aeb908b68c Exiting the middle device works now?
Former-commit-id: d76b3c366f
2025-12-31 11:33:00 -05:00
Owen
f08b17c7bd Middle device working but not closing
Former-commit-id: c85fcc434b
2025-12-31 11:22:09 -05:00
Owen
cce8742490 Try to make the tun replacable
Former-commit-id: 6be0958887
2025-12-30 21:38:07 -05:00
Owen
c56696bab1 Use a different method on android
Former-commit-id: adf4c21f7b
2025-12-30 16:59:36 -05:00
Owen
7bb004cf50 Update docs
Former-commit-id: 543ca05eb9
2025-12-29 22:15:01 -05:00
Owen
28910ce188 Add stub
Former-commit-id: ece4239aaa
2025-12-29 17:50:15 -05:00
miloschwartz
f8dc134210 add content-length header to status payload
Former-commit-id: 8152d4133f
2025-12-29 17:28:12 -05:00
Varun Narravula
148f5fde23 fix(ci): add back missing docker build local image rule
Former-commit-id: 6d2afb4c72
2025-12-24 10:08:40 -05:00
Owen
b76259bc31 Add sync message
Former-commit-id: d01f180941
2025-12-24 10:06:25 -05:00
Owen
88cc57bcef Update mod
Former-commit-id: 1b474ebc1c
2025-12-23 18:00:15 -05:00
Owen
0b05497c25 Merge branch 'dev' into msg-delivery
Former-commit-id: 4deb3e07b0
2025-12-23 15:44:02 -05:00
Owen
dde79bb2dc Fix go mod
Former-commit-id: e355d8db5f
2025-12-21 20:57:20 -05:00
Owen
3822b1a065 Add version and send it down
Former-commit-id: 52273a81c8
2025-12-19 16:45:11 -05:00
34 changed files with 4027 additions and 1565 deletions

View File

@@ -5,20 +5,10 @@ updates:
schedule:
interval: "daily"
groups:
dev-patch-updates:
dependency-type: "development"
patch-updates:
update-types:
- "patch"
dev-minor-updates:
dependency-type: "development"
update-types:
- "minor"
prod-patch-updates:
dependency-type: "production"
update-types:
- "patch"
prod-minor-updates:
dependency-type: "production"
minor-updates:
update-types:
- "minor"

View File

@@ -48,7 +48,7 @@ jobs:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 0
@@ -92,7 +92,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 0
@@ -104,7 +104,7 @@ jobs:
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
- name: Set up 1.2.0 Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
- name: Log in to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
@@ -234,7 +234,7 @@ jobs:
- name: Cache Go modules
if: ${{ hashFiles('**/go.sum') != '' }}
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
with:
path: |
~/.cache/go-build
@@ -269,7 +269,7 @@ jobs:
} >> "$GITHUB_ENV"
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # v5.9.0
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5.10.0
with:
images: ${{ env.IMAGE_LIST }}
tags: |
@@ -599,7 +599,7 @@ jobs:
shell: bash
- name: Create GitHub Release
uses: softprops/action-gh-release@5be0e66d93ac7ed76da52eca8bb058f665c3a5fe # v2.4.2
uses: softprops/action-gh-release@a06a81a03ee405af7f2048a818ed3f03bbf83c7b # v2.5.0
with:
tag_name: ${{ env.TAG }}
generate_release_notes: true

37
.github/workflows/stale-bot.yml vendored Normal file
View File

@@ -0,0 +1,37 @@
name: Mark and Close Stale Issues
on:
schedule:
- cron: '0 0 * * *'
workflow_dispatch: # Allow manual trigger
permissions:
contents: write # only for delete-branch option
issues: write
pull-requests: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
days-before-stale: 14
days-before-close: 14
stale-issue-message: 'This issue has been automatically marked as stale due to 14 days of inactivity. It will be closed in 14 days if no further activity occurs.'
close-issue-message: 'This issue has been automatically closed due to inactivity. If you believe this is still relevant, please open a new issue with up-to-date information.'
stale-issue-label: 'stale'
exempt-issue-labels: 'needs investigating, networking, new feature, reverse proxy, bug, api, authentication, documentation, enhancement, help wanted, good first issue, question'
exempt-all-issue-assignees: true
only-labels: ''
exempt-pr-labels: ''
days-before-pr-stale: -1
days-before-pr-close: -1
operations-per-run: 100
remove-stale-when-updated: true
delete-branch: false
enable-statistics: true

View File

@@ -1,5 +1,8 @@
name: Run Tests
permissions:
contents: read
on:
pull_request:
branches:
@@ -7,11 +10,12 @@ on:
- dev
jobs:
test:
build-go:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Checkout repository
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- name: Set up Go
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
@@ -21,5 +25,18 @@ jobs:
- name: Build binaries
run: make go-build-release
build-docker:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- name: Set up QEMU
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
- name: Set up 1.2.0 Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
- name: Build Docker image
run: make docker-build-release
run: make docker-build-dev

95
API.md
View File

@@ -1,10 +1,10 @@
## HTTP API
## API
Olm can be controlled with an embedded HTTP server when using `--enable-http`. This allows you to start it as a daemon and trigger it with the following endpoints. The API can listen on either a TCP address or a Unix socket/Windows named pipe.
Olm can be controlled with an embedded API server when using `--enable-api`. This allows you to start it as a daemon and trigger it with the following endpoints. The API can listen on either a TCP address or a Unix socket/Windows named pipe.
### Socket vs TCP
By default, when `--enable-http` is used, Olm listens on a TCP address (configured via `--http-addr`, default `:9452`). Alternatively, Olm can listen on a Unix socket (Linux/macOS) or Windows named pipe for local-only communication with better security.
When `--enable-api` is used, Olm can listen on a TCP address when configured via `--http-addr` (like `:9452`). Alternatively, Olm can listen on a Unix socket (Linux/macOS) or Windows named pipe for local-only communication with better security when using `--socket-path` (like `/var/run/olm.sock`).
**Unix Socket (Linux/macOS):**
- Socket path example: `/var/run/olm/olm.sock`
@@ -46,7 +46,18 @@ Initiates a new connection request to a Pangolin server.
"tlsClientCert": "string",
"pingInterval": "3s",
"pingTimeout": "5s",
"orgId": "string"
"orgId": "string",
"fingerprint": {
"username": "string",
"hostname": "string",
"platform": "string",
"osVersion": "string",
"kernelVersion": "string",
"arch": "string",
"deviceModel": "string",
"serialNumber": "string"
},
"postures": {}
}
```
@@ -67,6 +78,16 @@ Initiates a new connection request to a Pangolin server.
- `pingInterval`: Interval for pinging the server (default: 3s)
- `pingTimeout`: Timeout for each ping (default: 5s)
- `orgId`: Organization ID to connect to
- `fingerprint`: Device fingerprinting information (should be set before connecting)
- `username`: Current username on the device
- `hostname`: Device hostname
- `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
- `osVersion`: Operating system version
- `kernelVersion`: Kernel version
- `arch`: System architecture (e.g., amd64, arm64)
- `deviceModel`: Device model identifier
- `serialNumber`: Device serial number
- `postures`: Device posture/security information
**Response:**
- **Status Code:** `202 Accepted`
@@ -205,6 +226,56 @@ Switches to a different organization while maintaining the connection.
---
### PUT /metadata
Updates device fingerprinting and posture information. This endpoint can be called at any time to update metadata, but it's recommended to provide this information in the initial `/connect` request or immediately before connecting.
**Request Body:**
```json
{
"fingerprint": {
"username": "string",
"hostname": "string",
"platform": "string",
"osVersion": "string",
"kernelVersion": "string",
"arch": "string",
"deviceModel": "string",
"serialNumber": "string"
},
"postures": {}
}
```
**Optional Fields:**
- `fingerprint`: Device fingerprinting information
- `username`: Current username on the device
- `hostname`: Device hostname
- `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
- `osVersion`: Operating system version
- `kernelVersion`: Kernel version
- `arch`: System architecture (e.g., amd64, arm64)
- `deviceModel`: Device model identifier
- `serialNumber`: Device serial number
- `postures`: Device posture/security information (object with arbitrary key-value pairs)
**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`
```json
{
"status": "metadata updated"
}
```
**Error Responses:**
- `405 Method Not Allowed` - Non-PUT requests
- `400 Bad Request` - Invalid JSON
**Note:** It's recommended to call this endpoint BEFORE `/connect` to ensure fingerprinting information is available during the initial connection handshake.
---
### POST /exit
Initiates a graceful shutdown of the Olm process.
@@ -247,6 +318,22 @@ Simple health check endpoint to verify the API server is running.
## Usage Examples
### Update metadata before connecting (recommended)
```bash
curl -X PUT http://localhost:9452/metadata \
-H "Content-Type: application/json" \
-d '{
"fingerprint": {
"username": "john",
"hostname": "johns-laptop",
"platform": "macos",
"osVersion": "14.2.1",
"arch": "arm64",
"deviceModel": "MacBookPro18,3"
}
}'
```
### Connect to a peer
```bash
curl -X POST http://localhost:9452/connect \

View File

@@ -5,6 +5,9 @@ all: local
local:
CGO_ENABLED=0 go build -o ./bin/olm
docker-build:
docker build -t fosrl/olm:latest .
docker-build-release:
@if [ -z "$(tag)" ]; then \
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
@@ -17,6 +20,12 @@ docker-build-release:
-f Dockerfile \
--push
docker-build-dev:
docker buildx build . \
--platform linux/arm/v7,linux/arm64,linux/amd64 \
-t fosrl/olm:latest \
-f Dockerfile
.PHONY: go-build-release \
go-build-release-linux-arm64 go-build-release-linux-arm32-v7 \
go-build-release-linux-arm32-v6 go-build-release-linux-amd64 \

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"sync"
"time"
@@ -32,7 +33,12 @@ type ConnectionRequest struct {
// SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct {
OrgID string `json:"orgId"`
OrgID string `json:"org_id"`
}
// PowerModeRequest represents a request to change power mode
type PowerModeRequest struct {
Mode string `json:"mode"` // "normal" or "low"
}
// PeerStatus represents the status of a peer connection
@@ -48,11 +54,18 @@ type PeerStatus struct {
HolepunchConnected bool `json:"holepunchConnected"`
}
// OlmError holds error information from registration failures
type OlmError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Connected bool `json:"connected"`
Registered bool `json:"registered"`
Terminated bool `json:"terminated"`
OlmError *OlmError `json:"error,omitempty"`
Version string `json:"version,omitempty"`
Agent string `json:"agent,omitempty"`
OrgID string `json:"orgId,omitempty"`
@@ -60,25 +73,37 @@ type StatusResponse struct {
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
}
type MetadataChangeRequest struct {
Fingerprint map[string]any `json:"fingerprint"`
Postures map[string]any `json:"postures"`
}
// API represents the HTTP server and its state
type API struct {
addr string
socketPath string
listener net.Listener
server *http.Server
onConnect func(ConnectionRequest) error
onSwitchOrg func(SwitchOrgRequest) error
onDisconnect func() error
onExit func() error
addr string
socketPath string
listener net.Listener
server *http.Server
onConnect func(ConnectionRequest) error
onSwitchOrg func(SwitchOrgRequest) error
onMetadataChange func(MetadataChangeRequest) error
onDisconnect func() error
onExit func() error
onRebind func() error
onPowerMode func(PowerModeRequest) error
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
isRegistered bool
isTerminated bool
version string
agent string
orgID string
olmError *OlmError
version string
agent string
orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
@@ -101,28 +126,49 @@ func NewAPISocket(socketPath string) *API {
return s
}
func NewAPIStub() *API {
s := &API{
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// SetHandlers sets the callback functions for handling API requests
func (s *API) SetHandlers(
onConnect func(ConnectionRequest) error,
onSwitchOrg func(SwitchOrgRequest) error,
onMetadataChange func(MetadataChangeRequest) error,
onDisconnect func() error,
onExit func() error,
onRebind func() error,
onPowerMode func(PowerModeRequest) error,
) {
s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg
s.onMetadataChange = onMetadataChange
s.onDisconnect = onDisconnect
s.onExit = onExit
s.onRebind = onRebind
s.onPowerMode = onPowerMode
}
// Start starts the HTTP server
func (s *API) Start() error {
if s.socketPath == "" && s.addr == "" {
return fmt.Errorf("either socketPath or addr must be provided to start the API server")
}
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/metadata", s.handleMetadataChange)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit)
mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode)
s.server = &http.Server{
Handler: mux,
@@ -160,7 +206,7 @@ func (s *API) Stop() error {
// Close the server first, which will also close the listener gracefully
if s.server != nil {
s.server.Close()
_ = s.server.Close()
}
// Clean up socket file if using Unix socket
@@ -226,9 +272,6 @@ func (s *API) SetConnectionStatus(isConnected bool) {
if isConnected {
s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
}
}
@@ -236,6 +279,27 @@ func (s *API) SetRegistered(registered bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isRegistered = registered
// Clear any registration error when successfully registered
if registered {
s.olmError = nil
}
}
// SetOlmError sets the registration error
func (s *API) SetOlmError(code string, message string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.olmError = &OlmError{
Code: code,
Message: message,
}
}
// ClearOlmError clears any registration error
func (s *API) ClearOlmError() {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.olmError = nil
}
func (s *API) SetTerminated(terminated bool) {
@@ -345,7 +409,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
@@ -358,12 +422,12 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
}
s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
Terminated: s.isTerminated,
OlmError: s.olmError,
Version: s.version,
Agent: s.agent,
OrgID: s.orgID,
@@ -371,8 +435,18 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
NetworkSettings: network.GetSettings(),
}
s.statusMu.RUnlock()
data, err := json.Marshal(resp)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}
// handleHealth handles the /health endpoint
@@ -384,7 +458,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "ok",
})
}
@@ -401,7 +475,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
// Return a success response first
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "shutdown initiated",
})
@@ -450,7 +524,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "org switch request accepted",
})
}
@@ -484,16 +558,43 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}
// handleMetadataChange handles the /metadata endpoint
func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req MetadataChangeRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
logger.Info("Received metadata change request via API: %v", req)
_ = s.onMetadataChange(req)
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "metadata updated",
})
}
func (s *API) GetStatus() StatusResponse {
return StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
Terminated: s.isTerminated,
OlmError: s.olmError,
Version: s.version,
Agent: s.agent,
OrgID: s.orgID,
@@ -501,3 +602,74 @@ func (s *API) GetStatus() StatusResponse {
NetworkSettings: network.GetSettings(),
}
}
// handleRebind handles the /rebind endpoint
// This triggers a socket rebind, which is necessary when network connectivity changes
// (e.g., WiFi to cellular transition on macOS/iOS) and the old socket becomes stale.
func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received rebind request via API")
// Call the rebind handler if set
if s.onRebind != nil {
if err := s.onRebind(); err != nil {
http.Error(w, fmt.Sprintf("Rebind failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "Rebind handler not configured", http.StatusNotImplemented)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "socket rebound successfully",
})
}
// handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req PowerModeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate power mode
if req.Mode != "normal" && req.Mode != "low" {
http.Error(w, "Invalid power mode: must be 'normal' or 'low'", http.StatusBadRequest)
return
}
logger.Info("Received power mode change request via API: mode=%s", req.Mode)
// Call the power mode handler if set
if s.onPowerMode != nil {
if err := s.onPowerMode(req); err != nil {
http.Error(w, fmt.Sprintf("Power mode change failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "Power mode handler not configured", http.StatusNotImplemented)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": fmt.Sprintf("power mode changed to %s successfully", req.Mode),
})
}

View File

@@ -89,6 +89,7 @@ func DefaultConfig() *OlmConfig {
PingInterval: "3s",
PingTimeout: "5s",
DisableHolepunch: false,
OverrideDNS: true,
TunnelDNS: false,
// DoNotCreateNewClient: false,
sources: make(map[string]string),
@@ -324,9 +325,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "When enabled, the client uses custom DNS servers to resolve internal resources and aliases. This overrides your system's default DNS settings. Queries that cannot be resolved as a Pangolin resource will be forwarded to your configured Upstream DNS Server. (default false)")
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "Use tunnel for DNS traffic")
serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "When enabled, DNS queries are routed through the tunnel for remote resolution. To ensure queries are tunneled correctly, you must define the DNS server as a Pangolin resource and enter its address as an Upstream DNS Server. (default false)")
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
version := serviceFlags.Bool("version", false, "Print the version")

View File

@@ -1,9 +1,12 @@
package device
import (
"io"
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun"
@@ -18,14 +21,68 @@ type FilterRule struct {
Handler PacketHandler
}
// MiddleDevice wraps a TUN device with packet filtering capabilities
type MiddleDevice struct {
// closeAwareDevice wraps a tun.Device along with a flag
// indicating whether its Close method was called.
type closeAwareDevice struct {
isClosed atomic.Bool
tun.Device
rules []FilterRule
mutex sync.RWMutex
readCh chan readResult
injectCh chan []byte
closed chan struct{}
closeEventCh chan struct{}
wg sync.WaitGroup
closeOnce sync.Once
}
func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice {
return &closeAwareDevice{
Device: tunDevice,
isClosed: atomic.Bool{},
closeEventCh: make(chan struct{}),
}
}
// redirectEvents redirects the Events() method of the underlying tun.Device
// to the given channel.
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case ev, ok := <-c.Device.Events():
if !ok {
return
}
if ev == tun.EventDown {
continue
}
select {
case out <- ev:
case <-c.closeEventCh:
return
}
case <-c.closeEventCh:
return
}
}
}()
}
// Close calls the underlying Device's Close method
// after setting isClosed to true.
func (c *closeAwareDevice) Close() (err error) {
c.closeOnce.Do(func() {
c.isClosed.Store(true)
close(c.closeEventCh)
err = c.Device.Close()
c.wg.Wait()
})
return err
}
func (c *closeAwareDevice) IsClosed() bool {
return c.isClosed.Load()
}
type readResult struct {
@@ -36,58 +93,136 @@ type readResult struct {
err error
}
// MiddleDevice wraps a TUN device with packet filtering capabilities
// and supports swapping the underlying device.
type MiddleDevice struct {
devices []*closeAwareDevice
mu sync.Mutex
cond *sync.Cond
rules []FilterRule
rulesMutex sync.RWMutex
readCh chan readResult
injectCh chan []byte
closed atomic.Bool
events chan tun.Event
}
// NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice {
d := &MiddleDevice{
Device: device,
devices: make([]*closeAwareDevice, 0),
rules: make([]FilterRule, 0),
readCh: make(chan readResult),
readCh: make(chan readResult, 16),
injectCh: make(chan []byte, 100),
closed: make(chan struct{}),
events: make(chan tun.Event, 16),
}
go d.pump()
d.cond = sync.NewCond(&d.mu)
if device != nil {
d.AddDevice(device)
}
return d
}
func (d *MiddleDevice) pump() {
// AddDevice adds a new underlying TUN device, closing any previous one
func (d *MiddleDevice) AddDevice(device tun.Device) {
d.mu.Lock()
if d.closed.Load() {
d.mu.Unlock()
_ = device.Close()
return
}
var toClose *closeAwareDevice
if len(d.devices) > 0 {
toClose = d.devices[len(d.devices)-1]
}
cad := newCloseAwareDevice(device)
cad.redirectEvents(d.events)
d.devices = []*closeAwareDevice{cad}
// Start pump for the new device
go d.pump(cad)
d.cond.Broadcast()
d.mu.Unlock()
if toClose != nil {
logger.Debug("MiddleDevice: Closing previous device")
if err := toClose.Close(); err != nil {
logger.Debug("MiddleDevice: Error closing previous device: %v", err)
}
}
}
func (d *MiddleDevice) pump(dev *closeAwareDevice) {
const defaultOffset = 16
batchSize := d.Device.BatchSize()
logger.Debug("MiddleDevice: pump started")
batchSize := dev.BatchSize()
logger.Debug("MiddleDevice: pump started for device")
// Recover from panic if readCh is closed while we're trying to send
defer func() {
if r := recover(); r != nil {
logger.Debug("MiddleDevice: pump recovered from panic (channel closed)")
}
}()
for {
// Check closed first with priority
select {
case <-d.closed:
logger.Debug("MiddleDevice: pump exiting due to closed channel")
// Check if this device is closed
if dev.IsClosed() {
logger.Debug("MiddleDevice: pump exiting, device is closed")
return
}
// Check if MiddleDevice itself is closed
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed")
return
default:
}
// Allocate buffers for reading
// We allocate new buffers for each read to avoid race conditions
// since we pass them to the channel
bufs := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for i := range bufs {
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
}
n, err := d.Device.Read(bufs, sizes, defaultOffset)
n, err := dev.Read(bufs, sizes, defaultOffset)
// Check closed again after read returns
select {
case <-d.closed:
logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
// Check if device was closed during read
if dev.IsClosed() {
logger.Debug("MiddleDevice: pump exiting, device closed during read")
return
}
// Check if MiddleDevice was closed during read
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read")
return
}
// Try to send the result - check closed state first to avoid sending on closed channel
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, device closed before send")
return
default:
}
// Now try to send the result
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
case <-d.closed:
logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
return
default:
// Channel full, check if we should exit
if dev.IsClosed() || d.closed.Load() {
return
}
// Try again with blocking
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
case <-dev.closeEventCh:
return
}
}
if err != nil {
@@ -99,16 +234,28 @@ func (d *MiddleDevice) pump() {
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
func (d *MiddleDevice) InjectOutbound(packet []byte) {
if d.closed.Load() {
return
}
// Use defer/recover to handle panic from sending on closed channel
// This can happen during shutdown race conditions
defer func() {
if r := recover(); r != nil {
logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)")
}
}()
select {
case d.injectCh <- packet:
case <-d.closed:
default:
// Channel full, drop packet
logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full")
}
}
// AddRule adds a packet filtering rule
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.rulesMutex.Lock()
defer d.rulesMutex.Unlock()
d.rules = append(d.rules, FilterRule{
DestIP: destIP,
Handler: handler,
@@ -117,8 +264,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
// RemoveRule removes all rules for a given destination IP
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.rulesMutex.Lock()
defer d.rulesMutex.Unlock()
newRules := make([]FilterRule, 0, len(d.rules))
for _, rule := range d.rules {
if rule.DestIP != destIP {
@@ -130,18 +277,120 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
// Close stops the device
func (d *MiddleDevice) Close() error {
select {
case <-d.closed:
// Already closed
return nil
default:
logger.Debug("MiddleDevice: Closing, signaling closed channel")
close(d.closed)
if !d.closed.CompareAndSwap(false, true) {
return nil // already closed
}
logger.Debug("MiddleDevice: Closing underlying TUN device")
err := d.Device.Close()
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
return err
d.mu.Lock()
devices := d.devices
d.devices = nil
d.cond.Broadcast()
d.mu.Unlock()
// Close underlying devices first - this causes the pump goroutines to exit
// when their read operations return errors
var lastErr error
logger.Debug("MiddleDevice: Closing %d devices", len(devices))
for _, device := range devices {
if err := device.Close(); err != nil {
logger.Debug("MiddleDevice: Error closing device: %v", err)
lastErr = err
}
}
// Now close channels to unblock any remaining readers
// The pump should have exited by now, but close channels to be safe
close(d.readCh)
close(d.injectCh)
close(d.events)
return lastErr
}
// Events returns the events channel
func (d *MiddleDevice) Events() <-chan tun.Event {
return d.events
}
// File returns the underlying file descriptor
func (d *MiddleDevice) File() *os.File {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return nil
}
continue
}
file := dev.File()
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return file
}
}
// MTU returns the MTU of the underlying device
func (d *MiddleDevice) MTU() (int, error) {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
mtu, err := dev.MTU()
if err == nil {
return mtu, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return 0, err
}
}
// Name returns the name of the underlying device
func (d *MiddleDevice) Name() (string, error) {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return "", io.EOF
}
continue
}
name, err := dev.Name()
if err == nil {
return name, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return "", err
}
}
// BatchSize returns the batch size
func (d *MiddleDevice) BatchSize() int {
dev := d.peekLast()
if dev == nil {
return 1
}
return dev.BatchSize()
}
// extractDestIP extracts destination IP from packet (fast path)
@@ -176,156 +425,239 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
// Check if already closed first (non-blocking)
select {
case <-d.closed:
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
return 0, os.ErrClosed
default:
}
// Now block waiting for data
select {
case res := <-d.readCh:
if res.err != nil {
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
return 0, res.err
for {
if d.closed.Load() {
logger.Debug("MiddleDevice: Read returning io.EOF, device closed")
return 0, io.EOF
}
// Copy packets from result to provided buffers
count := 0
for i := 0; i < res.n && i < len(bufs); i++ {
// Handle offset mismatch if necessary
// We assume the pump used defaultOffset (16)
// If caller asks for different offset, we need to shift
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
// Calculate where the packet data starts and ends in src
pktData := src[srcOffset : srcOffset+srcSize]
// Ensure dest buffer is large enough
if len(bufs[i]) < offset+len(pktData) {
continue // Skip if buffer too small
// Wait for a device to be available
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
copy(bufs[i][offset:], pktData)
sizes[i] = len(pktData)
count++
}
n = count
case pkt := <-d.injectCh:
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil // Buffer too small
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
case <-d.closed:
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
return 0, os.ErrClosed // Signal that device is closed
}
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
if len(rules) == 0 {
return n, nil
}
// Process packets and filter out handled ones
writeIdx := 0
for readIdx := 0; readIdx < n; readIdx++ {
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
continue
}
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
// Now block waiting for data from readCh or injectCh
select {
case res, ok := <-d.readCh:
if !ok {
// Channel closed, device is shutting down
return 0, io.EOF
}
if res.err != nil {
// Check if device was swapped
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
return 0, res.err
}
// Copy packets from result to provided buffers
count := 0
for i := 0; i < res.n && i < len(bufs); i++ {
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
pktData := src[srcOffset : srcOffset+srcSize]
if len(bufs[i]) < offset+len(pktData) {
continue
}
copy(bufs[i][offset:], pktData)
sizes[i] = len(pktData)
count++
}
n = count
case pkt, ok := <-d.injectCh:
if !ok {
// Channel closed, device is shutting down
return 0, io.EOF
}
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
}
// Apply filtering rules
d.rulesMutex.RLock()
rules := d.rules
d.rulesMutex.RUnlock()
if len(rules) == 0 {
return n, nil
}
// Process packets and filter out handled ones
writeIdx := 0
for readIdx := 0; readIdx < n; readIdx++ {
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
destIP, ok := extractDestIP(packet)
if !ok {
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
continue
}
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
handled = true
break
}
}
}
}
if !handled {
// Keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
if !handled {
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
}
writeIdx++
}
}
return writeIdx, err
return writeIdx, nil
}
}
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
for {
if d.closed.Load() {
return 0, io.EOF
}
if len(rules) == 0 {
return d.Device.Write(bufs, offset)
}
// Filter packets going down
filteredBufs := make([][]byte, 0, len(bufs))
for _, buf := range bufs {
if len(buf) <= offset {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
packet := buf[offset:]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
filteredBufs = append(filteredBufs, buf)
continue
}
d.rulesMutex.RLock()
rules := d.rules
d.rulesMutex.RUnlock()
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
var filteredBufs [][]byte
if len(rules) == 0 {
filteredBufs = bufs
} else {
filteredBufs = make([][]byte, 0, len(bufs))
for _, buf := range bufs {
if len(buf) <= offset {
continue
}
packet := buf[offset:]
destIP, ok := extractDestIP(packet)
if !ok {
filteredBufs = append(filteredBufs, buf)
continue
}
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
handled = true
break
}
}
}
if !handled {
filteredBufs = append(filteredBufs, buf)
}
}
}
if !handled {
filteredBufs = append(filteredBufs, buf)
if len(filteredBufs) == 0 {
return len(bufs), nil
}
}
if len(filteredBufs) == 0 {
return len(bufs), nil // All packets were handled
}
n, err := dev.Write(filteredBufs, offset)
if err == nil {
return n, nil
}
return d.Device.Write(filteredBufs, offset)
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}
func (d *MiddleDevice) waitForDevice() bool {
d.mu.Lock()
defer d.mu.Unlock()
for len(d.devices) == 0 && !d.closed.Load() {
d.cond.Wait()
}
return !d.closed.Load()
}
func (d *MiddleDevice) peekLast() *closeAwareDevice {
d.mu.Lock()
defer d.mu.Unlock()
if len(d.devices) == 0 {
return nil
}
return d.devices[len(d.devices)-1]
}
// WriteToTun writes packets directly to the underlying TUN device,
// bypassing WireGuard. This is useful for sending packets that should
// appear to come from the TUN interface (e.g., DNS responses from a proxy).
// Unlike Write(), this does not go through packet filtering rules.
func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) {
for {
if d.closed.Load() {
return 0, io.EOF
}
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
n, err := dev.Write(bufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}

View File

@@ -1,4 +1,4 @@
//go:build !windows
//go:build darwin
package device
@@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
}
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
device, err := tun.CreateTUNFromFile(file, mtuInt)
device, err := tun.CreateTUNFromFile(file, 0)
if err != nil {
file.Close()
return nil, err

50
device/tun_linux.go Normal file
View File

@@ -0,0 +1,50 @@
//go:build linux
package device
import (
"net"
"os"
"runtime"
"github.com/fosrl/newt/logger"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
if runtime.GOOS == "android" { // otherwise we get a permission denied
theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
return theTun, err
}
dupTunFd, err := unix.Dup(int(tunFd))
if err != nil {
logger.Error("Unable to dup tun fd: %v", err)
return nil, err
}
err = unix.SetNonblock(dupTunFd, true)
if err != nil {
unix.Close(dupTunFd)
return nil, err
}
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
device, err := tun.CreateTUNFromFile(file, mtuInt)
if err != nil {
file.Close()
return nil, err
}
return device, nil
}
func UapiOpen(interfaceName string) (*os.File, error) {
return ipc.UAPIOpen(interfaceName)
}
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, fileUAPI)
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/device"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@@ -34,18 +33,17 @@ type DNSProxy struct {
ep *channel.Endpoint
proxyIP netip.Addr
upstreamDNS []string
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
recordStore *DNSRecordStore // Local DNS records
// Tunnel DNS fields - for sending queries over WireGuard
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
tunnelEp *channel.Endpoint
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
tunnelEp *channel.Endpoint
tunnelActivePorts map[uint16]bool
tunnelPortsLock sync.Mutex
tunnelPortsLock sync.Mutex
ctx context.Context
cancel context.CancelFunc
@@ -53,7 +51,7 @@ type DNSProxy struct {
}
// NewDNSProxy creates a new DNS proxy
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
proxyIP, err := PickIPFromSubnet(utilitySubnet)
if err != nil {
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
@@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
proxy := &DNSProxy{
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
middleDevice: middleDevice,
upstreamDNS: upstreamDns,
tunnelDNS: tunnelDns,
@@ -383,7 +380,7 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
// Check if we have local records for this query
var response *dns.Msg
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypePTR {
response = p.checkLocalRecords(msg, question)
}
@@ -413,6 +410,34 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
// checkLocalRecords checks if we have local records for the query
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
// Handle PTR queries
if question.Qtype == dns.TypePTR {
if ptrDomain, ok := p.recordStore.GetPTRRecord(question.Name); ok {
logger.Debug("Found local PTR record for %s -> %s", question.Name, ptrDomain)
// Create response message
response := new(dns.Msg)
response.SetReply(query)
response.Authoritative = true
// Add PTR answer record
rr := &dns.PTR{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
Ptr: ptrDomain,
}
response.Answer = append(response.Answer, rr)
return response
}
return nil
}
// Handle A and AAAA queries
var recordType RecordType
if question.Qtype == dns.TypeA {
recordType = RecordTypeA
@@ -602,12 +627,12 @@ func (p *DNSProxy) runTunnelPacketSender() {
defer p.wg.Done()
logger.Debug("DNS tunnel packet sender goroutine started")
ticker := time.NewTicker(1 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-p.ctx.Done():
// Use blocking ReadContext instead of polling - much more CPU efficient
// This will block until a packet is available or context is cancelled
pkt := p.tunnelEp.ReadContext(p.ctx)
if pkt == nil {
// Context was cancelled or endpoint closed
logger.Debug("DNS tunnel packet sender exiting")
// Drain any remaining packets
for {
@@ -618,36 +643,28 @@ func (p *DNSProxy) runTunnelPacketSender() {
pkt.DecRef()
}
return
case <-ticker.C:
// Try to read packets
for i := 0; i < 10; i++ {
pkt := p.tunnelEp.Read()
if pkt == nil {
break
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
p.middleDevice.InjectOutbound(buf)
}
pkt.DecRef()
}
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
p.middleDevice.InjectOutbound(buf)
}
pkt.DecRef()
}
}
@@ -660,18 +677,12 @@ func (p *DNSProxy) runPacketSender() {
const offset = 16
for {
select {
case <-p.ctx.Done():
return
default:
}
// Read packets from netstack endpoint
pkt := p.ep.Read()
// Use blocking ReadContext instead of polling - much more CPU efficient
// This will block until a packet is available or context is cancelled
pkt := p.ep.ReadContext(p.ctx)
if pkt == nil {
// No packet available, small sleep to avoid busy loop
time.Sleep(1 * time.Millisecond)
continue
// Context was cancelled or endpoint closed
return
}
// Extract packet data as slices
@@ -694,9 +705,9 @@ func (p *DNSProxy) runPacketSender() {
pos += len(slice)
}
// Write packet to TUN device
// Write packet to TUN device via MiddleDevice
// offset=16 indicates packet data starts at position 16 in the buffer
_, err := p.tunDevice.Write([][]byte{buf}, offset)
_, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
if err != nil {
logger.Error("Failed to write DNS response to TUN: %v", err)
}

View File

@@ -1,6 +1,7 @@
package dns
import (
"fmt"
"net"
"strings"
"sync"
@@ -14,15 +15,17 @@ type RecordType uint16
const (
RecordTypeA RecordType = RecordType(dns.TypeA)
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
RecordTypePTR RecordType = RecordType(dns.TypePTR)
)
// DNSRecordStore manages local DNS records for A and AAAA queries
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
type DNSRecordStore struct {
mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
ptrRecords map[string]string // IP address string -> domain name
}
// NewDNSRecordStore creates a new DNS record store
@@ -32,6 +35,7 @@ func NewDNSRecordStore() *DNSRecordStore {
aaaaRecords: make(map[string][]net.IP),
aWildcards: make(map[string][]net.IP),
aaaaWildcards: make(map[string][]net.IP),
ptrRecords: make(map[string]string),
}
}
@@ -39,6 +43,7 @@ func NewDNSRecordStore() *DNSRecordStore {
// domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address
// Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -48,8 +53,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
domain = domain + "."
}
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
@@ -60,6 +65,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
} else {
s.aRecords[domain] = append(s.aRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else if ip.To16() != nil {
// IPv6 address
@@ -67,6 +74,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
} else {
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else {
return &net.ParseError{Type: "IP address", Text: ip.String()}
@@ -75,8 +84,30 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
return nil
}
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
// ip should be a valid IPv4 or IPv6 address
// domain should be in FQDN format (e.g., "example.com.")
func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Store PTR record using IP string as key
s.ptrRecords[ip.String()] = domain
return nil
}
// RemoveRecord removes a specific DNS record mapping
// If ip is nil, removes all records for the domain (including wildcards)
// Automatically removes corresponding PTR records for non-wildcard domains
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -86,8 +117,8 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
domain = domain + "."
}
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
@@ -98,6 +129,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
delete(s.aWildcards, domain)
delete(s.aaaaWildcards, domain)
} else {
// For non-wildcard domains, remove PTR records for all IPs
if ips, ok := s.aRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
if ips, ok := s.aaaaRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
}
@@ -119,6 +167,10 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
}
}
} else if ip.To16() != nil {
@@ -136,11 +188,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
}
}
}
}
// RemovePTRRecord removes a PTR record for an IP address
func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.ptrRecords, ip.String())
}
// GetRecords returns all IP addresses for a domain and record type
// First checks for exact matches, then checks wildcard patterns
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
@@ -148,7 +212,7 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain)
domain = strings.ToLower(dns.Fqdn(domain))
var records []net.IP
switch recordType {
@@ -198,6 +262,26 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
return records
}
// GetPTRRecord returns the domain name for a PTR record query
// domain should be in reverse DNS format (e.g., "1.0.0.127.in-addr.arpa.")
func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
// Convert reverse DNS format to IP address
ip := reverseDNSToIP(domain)
if ip == nil {
return "", false
}
// Look up the PTR record
if ptrDomain, ok := s.ptrRecords[ip.String()]; ok {
return ptrDomain, true
}
return "", false
}
// HasRecord checks if a domain has any records of the specified type
// Checks both exact matches and wildcard patterns
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
@@ -205,7 +289,7 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain)
domain = strings.ToLower(dns.Fqdn(domain))
switch recordType {
case RecordTypeA:
@@ -235,6 +319,21 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
return false
}
// HasPTRRecord checks if a PTR record exists for the given reverse DNS domain
func (s *DNSRecordStore) HasPTRRecord(domain string) bool {
s.mu.RLock()
defer s.mu.RUnlock()
// Convert reverse DNS format to IP address
ip := reverseDNSToIP(domain)
if ip == nil {
return false
}
_, ok := s.ptrRecords[ip.String()]
return ok
}
// Clear removes all records from the store
func (s *DNSRecordStore) Clear() {
s.mu.Lock()
@@ -244,6 +343,7 @@ func (s *DNSRecordStore) Clear() {
s.aaaaRecords = make(map[string][]net.IP)
s.aWildcards = make(map[string][]net.IP)
s.aaaaWildcards = make(map[string][]net.IP)
s.ptrRecords = make(map[string]string)
}
// removeIP is a helper function to remove a specific IP from a slice
@@ -322,4 +422,76 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool {
}
return matchWildcardInternal(pattern, domain, pi+1, di+1)
}
// reverseDNSToIP converts a reverse DNS query name to an IP address
// Supports both IPv4 (in-addr.arpa) and IPv6 (ip6.arpa) formats
func reverseDNSToIP(domain string) net.IP {
// Normalize to lowercase and ensure FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Check for IPv4 reverse DNS (in-addr.arpa)
if strings.HasSuffix(domain, ".in-addr.arpa.") {
// Remove the suffix
ipPart := strings.TrimSuffix(domain, ".in-addr.arpa.")
// Split by dots and reverse
parts := strings.Split(ipPart, ".")
if len(parts) != 4 {
return nil
}
// Reverse the octets
reversed := make([]string, 4)
for i := 0; i < 4; i++ {
reversed[i] = parts[3-i]
}
// Parse as IP
return net.ParseIP(strings.Join(reversed, "."))
}
// Check for IPv6 reverse DNS (ip6.arpa)
if strings.HasSuffix(domain, ".ip6.arpa.") {
// Remove the suffix
ipPart := strings.TrimSuffix(domain, ".ip6.arpa.")
// Split by dots and reverse
parts := strings.Split(ipPart, ".")
if len(parts) != 32 {
return nil
}
// Reverse the nibbles and group into 16-bit hex values
reversed := make([]string, 32)
for i := 0; i < 32; i++ {
reversed[i] = parts[31-i]
}
// Join into IPv6 format (groups of 4 nibbles separated by colons)
var ipv6Parts []string
for i := 0; i < 32; i += 4 {
ipv6Parts = append(ipv6Parts, reversed[i]+reversed[i+1]+reversed[i+2]+reversed[i+3])
}
// Parse as IP
return net.ParseIP(strings.Join(ipv6Parts, ":"))
}
return nil
}
// IPToReverseDNS converts an IP address to reverse DNS format
// Returns the domain name for PTR queries (e.g., "1.0.0.127.in-addr.arpa.")
func IPToReverseDNS(ip net.IP) string {
if ip4 := ip.To4(); ip4 != nil {
// IPv4: reverse octets and append .in-addr.arpa.
return dns.Fqdn(fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa",
ip4[3], ip4[2], ip4[1], ip4[0]))
}
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
// IPv6: expand to 32 nibbles, reverse, and append .ip6.arpa.
var nibbles []string
for i := 15; i >= 0; i-- {
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]&0x0f))
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]>>4))
}
return dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
}
return ""
}

View File

@@ -37,7 +37,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "autoco.internal.",
expected: false,
},
// Question mark wildcard tests
{
name: "host-0?.autoco.internal matches host-01.autoco.internal",
@@ -63,7 +63,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-012.autoco.internal.",
expected: false,
},
// Combined wildcard tests
{
name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal",
@@ -83,7 +83,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-01.autoco.internal.",
expected: false,
},
// Multiple asterisks
{
name: "*.*. autoco.internal matches any.thing.autoco.internal",
@@ -97,7 +97,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "single.autoco.internal.",
expected: false,
},
// Asterisk in middle
{
name: "host-*.autoco.internal matches host-anything.autoco.internal",
@@ -111,7 +111,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-.autoco.internal.",
expected: true,
},
// Multiple question marks
{
name: "host-??.autoco.internal matches host-01.autoco.internal",
@@ -125,7 +125,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-1.autoco.internal.",
expected: false,
},
// Exact match (no wildcards)
{
name: "exact.autoco.internal matches exact.autoco.internal",
@@ -139,7 +139,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "other.autoco.internal.",
expected: false,
},
// Edge cases
{
name: "* matches anything",
@@ -154,7 +154,7 @@ func TestWildcardMatching(t *testing.T) {
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchWildcard(tt.pattern, tt.domain)
@@ -167,21 +167,21 @@ func TestWildcardMatching(t *testing.T) {
func TestDNSRecordStoreWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Add exact record
exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP)
if err != nil {
t.Fatalf("Failed to add exact record: %v", err)
}
// Test exact match takes precedence
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
@@ -190,7 +190,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
if !ips[0].Equal(exactIP) {
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
}
// Test wildcard match
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
@@ -199,7 +199,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
if !ips[0].Equal(wildcardIP) {
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
}
// Test non-match (base domain)
ips = store.GetRecords("autoco.internal.", RecordTypeA)
if len(ips) != 0 {
@@ -209,14 +209,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
func TestDNSRecordStoreComplexWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Test matching domain
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
@@ -225,13 +225,13 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
if len(ips) > 0 && !ips[0].Equal(ip1) {
t.Errorf("Expected IP %v, got %v", ip1, ips[0])
}
// Test non-matching domain (missing prefix)
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
}
// Test non-matching domain (wrong ? position)
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
@@ -241,23 +241,23 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Verify it exists
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
}
// Remove wildcard record
store.RemoveRecord("*.autoco.internal", nil)
// Verify it's gone
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
@@ -267,40 +267,40 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
store := NewDNSRecordStore()
// Add multiple wildcard patterns that don't overlap
ip1 := net.ParseIP("10.0.0.1")
ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1)
if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err)
}
err = store.AddRecord("*.dev.autoco.internal", ip2)
if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err)
}
// Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3)
if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err)
}
// Test domain matching only the prod pattern and the broad pattern
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
}
// Test domain matching only the dev pattern and the broad pattern
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
}
// Test domain matching only the broad pattern
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
@@ -310,14 +310,14 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
}
// Test wildcard match for IPv6
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
if len(ips) != 1 {
@@ -330,21 +330,535 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
func TestHasRecordWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Test HasRecord with wildcard match
if !store.HasRecord("host.autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return true for wildcard match")
}
// Test HasRecord with non-match
if store.HasRecord("autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return false for base domain")
}
}
}
func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
store := NewDNSRecordStore()
// Add record with mixed case
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip)
if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err)
}
// Test lookup with different cases
testCases := []string{
"myhost.autoco.internal.",
"MYHOST.AUTOCO.INTERNAL.",
"MyHost.AutoCo.Internal.",
"mYhOsT.aUtOcO.iNtErNaL.",
}
for _, domain := range testCases {
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
}
if len(ips) > 0 && !ips[0].Equal(ip) {
t.Errorf("Expected IP %v for domain %q, got %v", ip, domain, ips[0])
}
}
// Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP)
if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err)
}
wildcardTestCases := []string{
"host.example.com.",
"HOST.EXAMPLE.COM.",
"Host.Example.Com.",
"HoSt.ExAmPlE.CoM.",
}
for _, domain := range wildcardTestCases {
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
}
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
t.Errorf("Expected IP %v for wildcard domain %q, got %v", wildcardIP, domain, ips[0])
}
}
// Test removal with different case
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
}
// Test HasRecord with different case
if !store.HasRecord("HOST.EXAMPLE.COM.", RecordTypeA) {
t.Error("Expected HasRecord to return true for mixed case wildcard match")
}
}
func TestPTRRecordIPv4(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record for IPv4
ip := net.ParseIP("192.168.1.1")
domain := "host.example.com."
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Test reverse DNS lookup
reverseDomain := "1.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to be found")
}
if result != domain {
t.Errorf("Expected domain %q, got %q", domain, result)
}
// Test HasPTRRecord
if !store.HasPTRRecord(reverseDomain) {
t.Error("Expected HasPTRRecord to return true")
}
// Test non-existent PTR record
_, ok = store.GetPTRRecord("2.1.168.192.in-addr.arpa.")
if ok {
t.Error("Expected PTR record not to be found for different IP")
}
}
func TestPTRRecordIPv6(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record for IPv6
ip := net.ParseIP("2001:db8::1")
domain := "ipv6host.example.com."
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Test reverse DNS lookup
// 2001:db8::1 = 2001:0db8:0000:0000:0000:0000:0000:0001
// Reverse: 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
reverseDomain := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected IPv6 PTR record to be found")
}
if result != domain {
t.Errorf("Expected domain %q, got %q", domain, result)
}
// Test HasPTRRecord
if !store.HasPTRRecord(reverseDomain) {
t.Error("Expected HasPTRRecord to return true for IPv6")
}
}
func TestRemovePTRRecord(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record
ip := net.ParseIP("10.0.0.1")
domain := "test.example.com."
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Verify it exists
reverseDomain := "1.0.0.10.in-addr.arpa."
_, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to exist before removal")
}
// Remove PTR record
store.RemovePTRRecord(ip)
// Verify it's gone
_, ok = store.GetPTRRecord(reverseDomain)
if ok {
t.Error("Expected PTR record to be removed")
}
// Test HasPTRRecord after removal
if store.HasPTRRecord(reverseDomain) {
t.Error("Expected HasPTRRecord to return false after removal")
}
}
func TestIPToReverseDNS(t *testing.T) {
tests := []struct {
name string
ip string
expected string
}{
{
name: "IPv4 simple",
ip: "192.168.1.1",
expected: "1.1.168.192.in-addr.arpa.",
},
{
name: "IPv4 localhost",
ip: "127.0.0.1",
expected: "1.0.0.127.in-addr.arpa.",
},
{
name: "IPv4 with zeros",
ip: "10.0.0.1",
expected: "1.0.0.10.in-addr.arpa.",
},
{
name: "IPv6 simple",
ip: "2001:db8::1",
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
},
{
name: "IPv6 localhost",
ip: "::1",
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("Failed to parse IP: %s", tt.ip)
}
result := IPToReverseDNS(ip)
if result != tt.expected {
t.Errorf("IPToReverseDNS(%s) = %q, want %q", tt.ip, result, tt.expected)
}
})
}
}
func TestReverseDNSToIP(t *testing.T) {
tests := []struct {
name string
reverseDNS string
expectedIP string
shouldMatch bool
}{
{
name: "IPv4 simple",
reverseDNS: "1.1.168.192.in-addr.arpa.",
expectedIP: "192.168.1.1",
shouldMatch: true,
},
{
name: "IPv4 localhost",
reverseDNS: "1.0.0.127.in-addr.arpa.",
expectedIP: "127.0.0.1",
shouldMatch: true,
},
{
name: "IPv6 simple",
reverseDNS: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
expectedIP: "2001:db8::1",
shouldMatch: true,
},
{
name: "Invalid IPv4 format",
reverseDNS: "1.1.168.in-addr.arpa.",
expectedIP: "",
shouldMatch: false,
},
{
name: "Invalid IPv6 format",
reverseDNS: "1.0.0.0.ip6.arpa.",
expectedIP: "",
shouldMatch: false,
},
{
name: "Not a reverse DNS domain",
reverseDNS: "example.com.",
expectedIP: "",
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reverseDNSToIP(tt.reverseDNS)
if tt.shouldMatch {
if result == nil {
t.Errorf("reverseDNSToIP(%q) returned nil, expected IP", tt.reverseDNS)
return
}
expectedIP := net.ParseIP(tt.expectedIP)
if !result.Equal(expectedIP) {
t.Errorf("reverseDNSToIP(%q) = %v, want %v", tt.reverseDNS, result, expectedIP)
}
} else {
if result != nil {
t.Errorf("reverseDNSToIP(%q) = %v, expected nil", tt.reverseDNS, result)
}
}
})
}
}
func TestPTRRecordCaseInsensitive(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record with mixed case domain
ip := net.ParseIP("192.168.1.1")
domain := "MyHost.Example.Com"
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Test lookup with different cases in reverse DNS format
reverseDomain := "1.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to be found")
}
// Domain should be normalized to lowercase
if result != "myhost.example.com." {
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
}
// Test with uppercase reverse DNS
reverseDomainUpper := "1.1.168.192.IN-ADDR.ARPA."
result, ok = store.GetPTRRecord(reverseDomainUpper)
if !ok {
t.Error("Expected PTR record to be found with uppercase reverse DNS")
}
if result != "myhost.example.com." {
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
}
}
func TestClearPTRRecords(t *testing.T) {
store := NewDNSRecordStore()
// Add some PTR records
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.2")
store.AddPTRRecord(ip1, "host1.example.com.")
store.AddPTRRecord(ip2, "host2.example.com.")
// Add some A records too
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))
// Verify PTR records exist
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
t.Error("Expected PTR record to exist before clear")
}
// Clear all records
store.Clear()
// Verify PTR records are gone
if store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
t.Error("Expected PTR record to be cleared")
}
if store.HasPTRRecord("2.1.168.192.in-addr.arpa.") {
t.Error("Expected PTR record to be cleared")
}
// Verify A records are also gone
if store.HasRecord("test.example.com.", RecordTypeA) {
t.Error("Expected A record to be cleared")
}
}
func TestAutomaticPTRRecordOnAdd(t *testing.T) {
store := NewDNSRecordStore()
// Add an A record - should automatically add PTR record
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
if err != nil {
t.Fatalf("Failed to add A record: %v", err)
}
// Verify PTR record was automatically created
reverseDomain := "100.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to be automatically created")
}
if result != domain {
t.Errorf("Expected PTR to point to %q, got %q", domain, result)
}
// Add AAAA record - should also automatically add PTR record
domain6 := "ipv6host.example.com."
ip6 := net.ParseIP("2001:db8::1")
err = store.AddRecord(domain6, ip6)
if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err)
}
// Verify IPv6 PTR record was automatically created
reverseDomain6 := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
result6, ok := store.GetPTRRecord(reverseDomain6)
if !ok {
t.Error("Expected IPv6 PTR record to be automatically created")
}
if result6 != domain6 {
t.Errorf("Expected PTR to point to %q, got %q", domain6, result6)
}
}
func TestAutomaticPTRRecordOnRemove(t *testing.T) {
store := NewDNSRecordStore()
// Add an A record (with automatic PTR)
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain, ip)
// Verify PTR exists
reverseDomain := "100.1.168.192.in-addr.arpa."
if !store.HasPTRRecord(reverseDomain) {
t.Error("Expected PTR record to exist after adding A record")
}
// Remove the A record
store.RemoveRecord(domain, ip)
// Verify PTR was automatically removed
if store.HasPTRRecord(reverseDomain) {
t.Error("Expected PTR record to be automatically removed")
}
// Verify A record is also gone
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected A record to be removed, got %d records", len(ips))
}
}
func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
store := NewDNSRecordStore()
// Add multiple IPs for the same domain
domain := "host.example.com."
ip1 := net.ParseIP("192.168.1.100")
ip2 := net.ParseIP("192.168.1.101")
store.AddRecord(domain, ip1)
store.AddRecord(domain, ip2)
// Verify both PTR records exist
reverseDomain1 := "100.1.168.192.in-addr.arpa."
reverseDomain2 := "101.1.168.192.in-addr.arpa."
if !store.HasPTRRecord(reverseDomain1) {
t.Error("Expected first PTR record to exist")
}
if !store.HasPTRRecord(reverseDomain2) {
t.Error("Expected second PTR record to exist")
}
// Remove all records for the domain
store.RemoveRecord(domain, nil)
// Verify both PTR records were removed
if store.HasPTRRecord(reverseDomain1) {
t.Error("Expected first PTR record to be removed")
}
if store.HasPTRRecord(reverseDomain2) {
t.Error("Expected second PTR record to be removed")
}
}
func TestNoPTRForWildcardRecords(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record - should NOT create PTR record
domain := "*.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Verify no PTR record was created
reverseDomain := "100.1.168.192.in-addr.arpa."
_, ok := store.GetPTRRecord(reverseDomain)
if ok {
t.Error("Expected no PTR record for wildcard domain")
}
// Verify wildcard A record exists
if !store.HasRecord("host.example.com.", RecordTypeA) {
t.Error("Expected wildcard A record to exist")
}
}
func TestPTRRecordOverwrite(t *testing.T) {
store := NewDNSRecordStore()
// Add first domain with IP
domain1 := "host1.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain1, ip)
// Verify PTR points to first domain
reverseDomain := "100.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Fatal("Expected PTR record to exist")
}
if result != domain1 {
t.Errorf("Expected PTR to point to %q, got %q", domain1, result)
}
// Add second domain with same IP - should overwrite PTR
domain2 := "host2.example.com."
store.AddRecord(domain2, ip)
// Verify PTR now points to second domain (last one added)
result, ok = store.GetPTRRecord(reverseDomain)
if !ok {
t.Fatal("Expected PTR record to still exist")
}
if result != domain2 {
t.Errorf("Expected PTR to point to %q (overwritten), got %q", domain2, result)
}
// Remove first domain - PTR should remain pointing to second domain
store.RemoveRecord(domain1, ip)
result, ok = store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to still exist after removing first domain")
}
if result != domain2 {
t.Errorf("Expected PTR to still point to %q, got %q", domain2, result)
}
// Remove second domain - PTR should now be gone
store.RemoveRecord(domain2, ip)
_, ok = store.GetPTRRecord(reverseDomain)
if ok {
t.Error("Expected PTR record to be removed after removing second domain")
}
}

View File

@@ -0,0 +1,16 @@
//go:build android
package olm
import "net/netip"
// SetupDNSOverride is a no-op on Android
// Android handles DNS through the VpnService API at the Java/Kotlin layer
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
return nil
}
// RestoreDNSOverride is a no-op on Android
func RestoreDNSOverride() error {
return nil
}

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
// Uses scutil for DNS configuration
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
if dnsProxy == nil {
return fmt.Errorf("DNS proxy is nil")
}
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
var err error
configurator, err = platform.NewDarwinDNSConfigurator()
if err != nil {
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
dnsProxy.GetProxyIP(),
proxyIp,
}
logger.Info("Setting DNS servers to: %v", newDNS)

View File

@@ -0,0 +1,15 @@
//go:build ios
package olm
import "net/netip"
// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
return nil
}
// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func RestoreDNSOverride() error {
return nil
}

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
if dnsProxy == nil {
return fmt.Errorf("DNS proxy is nil")
}
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
var err error
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
@@ -32,7 +27,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using systemd-resolved DNS configurator")
return setDNS(dnsProxy, configurator)
return setDNS(proxyIp, configurator)
}
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
@@ -40,7 +35,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using NetworkManager DNS configurator")
return setDNS(dnsProxy, configurator)
return setDNS(proxyIp, configurator)
}
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
@@ -48,7 +43,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using resolvconf DNS configurator")
return setDNS(dnsProxy, configurator)
return setDNS(proxyIp, configurator)
}
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
}
@@ -60,11 +55,11 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
}
logger.Info("Using file-based DNS configurator")
return setDNS(dnsProxy, configurator)
return setDNS(proxyIp, configurator)
}
// setDNS is a helper function to set DNS and log the results
func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error {
// Get current DNS servers before changing
currentDNS, err := conf.GetCurrentDNS()
if err != nil {
@@ -75,7 +70,7 @@ func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
dnsProxy.GetProxyIP(),
proxyIp,
}
logger.Info("Setting DNS servers to: %v", newDNS)

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
// Uses registry-based configuration (automatically extracts interface GUID)
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
if dnsProxy == nil {
return fmt.Errorf("DNS proxy is nil")
}
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
var err error
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
if err != nil {
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
dnsProxy.GetProxyIP(),
proxyIp,
}
logger.Info("Setting DNS servers to: %v", newDNS)

View File

@@ -416,4 +416,4 @@ func (d *DarwinDNSConfigurator) clearState() error {
logger.Debug("Cleared DNS state file")
return nil
}
}

68
go.mod
View File

@@ -4,74 +4,32 @@ go 1.25
require (
github.com/Microsoft/go-winio v0.6.2
github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06
github.com/godbus/dbus/v5 v5.2.0
github.com/fosrl/newt v1.9.0
github.com/godbus/dbus/v5 v5.2.2
github.com/gorilla/websocket v1.5.3
github.com/miekg/dns v1.1.68
golang.org/x/sys v0.38.0
github.com/miekg/dns v1.1.70
golang.org/x/sys v0.40.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
software.sslmate.com/src/go-pkcs12 v0.6.0
software.sslmate.com/src/go-pkcs12 v0.7.0
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v0.3.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.2+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_golang v1.23.2 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/otlptranslator v0.0.2 // indirect
github.com/prometheus/procfs v0.17.0 // indirect
github.com/vishvananda/netlink v1.3.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 // indirect
go.opentelemetry.io/otel v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/prometheus v0.60.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/sdk v1.38.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/crypto v0.46.0 // indirect
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.org/x/tools v0.40.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.76.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
// To be used ONLY for local development
// replace github.com/fosrl/newt => ../newt

136
go.sum
View File

@@ -1,123 +1,39 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/containerd/errdefs v0.3.0 h1:FSZgGOeK4yuT/+DnF07/Olde/q4KBoMsaamhXxIMDp4=
github.com/containerd/errdefs v0.3.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 h1:VpuI42l4enih//6IFFQDln/B7WukfMePxIRIpXsNe/0=
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06 h1:xWuCn+gzX0W7bHs/cV/ykNBliisNzNomPR76E4M0dtI=
github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc h1:GN2Lv3MGO7AS6PrRoT6yV5+wkrOpcszoIsO4+4ds248=
github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/otlptranslator v0.0.2 h1:+1CdeLVrRQ6Psmhnobldo0kTp96Rj80DRXRd5OSnMEQ=
github.com/prometheus/otlptranslator v0.0.2/go.mod h1:P8AwMgdD7XEr6QRUJ2QWLpiAZTgTE2UYgjlu3svompI=
github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0=
github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw=
github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA=
github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg=
go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 h1:PeBoRj6af6xMI7qCupwFvTbbnd49V7n5YpG6pg8iDYQ=
go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0/go.mod h1:ingqBCtMCe8I4vpz/UVzCW6sxoqgZB37nao91mLQ3Bw=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 h1:vl9obrcoWVKp/lwl8tRE33853I8Xru9HFbw/skNeLs8=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0/go.mod h1:GAXRxmLJcVM3u22IjTg74zWBrRCKq8BnOqUVLodpcpw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
go.opentelemetry.io/otel/exporters/prometheus v0.60.0 h1:cGtQxGvZbnrWdC2GyjZi0PDKVSLWP/Jocix3QWfXtbo=
go.opentelemetry.io/otel/exporters/prometheus v0.60.0/go.mod h1:hkd1EekxNo69PTV4OWFGZcKQiIqg0RfuWExcPKFvepk=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4=
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
@@ -126,19 +42,7 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdI
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 h1:vlzZttNJGVqTsRFU9AmdnrcO1Znh8Ew9kCD//yjigk0=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A=
google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
software.sslmate.com/src/go-pkcs12 v0.7.0 h1:Db8W44cB54TWD7stUFFSWxdfpdn6fZVcDl0w3R4RVM0=
software.sslmate.com/src/go-pkcs12 v0.7.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

13
main.go
View File

@@ -10,7 +10,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/olm"
olmpkg "github.com/fosrl/olm/olm"
)
func main() {
@@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
}
// Create a new olm.Config struct and copy values from the main config
olmConfig := olm.GlobalConfig{
olmConfig := olmpkg.OlmConfig{
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
@@ -219,15 +219,20 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
Agent: "Olm CLI",
OnExit: cancel, // Pass cancel function directly to trigger shutdown
OnTerminated: cancel,
PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
}
olm, err := olmpkg.Init(ctx, olmConfig)
if err != nil {
logger.Fatal("Failed to initialize olm: %v", err)
}
olm.Init(ctx, olmConfig)
if err := olm.StartApi(); err != nil {
logger.Fatal("Failed to start API server: %v", err)
}
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
tunnelConfig := olm.TunnelConfig{
tunnelConfig := olmpkg.TunnelConfig{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,

299
olm/connect.go Normal file
View File

@@ -0,0 +1,299 @@
package olm
import (
"encoding/json"
"fmt"
"os"
"runtime"
"strconv"
"strings"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
olmDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
dnsOverride "github.com/fosrl/olm/dns/override"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
)
// OlmErrorData represents the error data sent from the server
type OlmErrorData struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring connect message")
return
}
var wgData WgData
if o.registered {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if o.stopRegister != nil {
o.stopRegister()
o.stopRegister = nil
}
if o.updateRegister != nil {
o.updateRegister = nil
}
// if there is an existing tunnel then close it
if o.dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
o.dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
o.tdev, err = func() (tun.Device, error) {
if o.tunnelConfig.FileDescriptorTun != 0 {
return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU)
}
ifName := o.tunnelConfig.InterfaceName
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
ifName, err = network.FindUnusedUTUN()
if err != nil {
return nil, err
}
}
return tun.CreateTUN(ifName, o.tunnelConfig.MTU)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
// if config.FileDescriptorTun == 0 {
if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
o.tunnelConfig.InterfaceName = realInterfaceName
}
// }
// Wrap TUN device with packet filter for DNS proxy
o.middleDev = olmDevice.NewMiddleDevice(o.tdev)
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
// Use filtered device instead of raw TUN device
o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger))
if o.tunnelConfig.EnableUAPI {
fileUAPI, err := func() (*os.File, error) {
if o.tunnelConfig.FileDescriptorUAPI != 0 {
fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
}
return os.NewFile(uintptr(fd), ""), nil
}
return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := o.uapiListener.Accept()
if err != nil {
return
}
go o.dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
}
if err = o.dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
// Extract interface IP (strip CIDR notation if present)
interfaceIP := wgData.TunnelIP
if strings.Contains(interfaceIP, "/") {
interfaceIP = strings.Split(interfaceIP, "/")[0]
}
// Create and start DNS proxy
o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
}
if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil {
logger.Error("Failed to o.tunnelConfigure interface: %v", err)
}
if network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet
logger.Error("Failed to add route for utility subnet: %v", err)
}
// Create peer manager with integrated peer monitoring
o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
Device: o.dev,
DNSProxy: o.dnsProxy,
InterfaceName: o.tunnelConfig.InterfaceName,
PrivateKey: o.privateKey,
MiddleDev: o.middleDev,
LocalIP: interfaceIP,
SharedBind: o.sharedBind,
WSClient: o.websocket,
APIServer: o.apiServer,
})
for i := range wgData.Sites {
site := wgData.Sites[i]
var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
if err := o.peerManager.AddPeer(site); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
o.peerManager.Start()
if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime
logger.Error("Failed to start DNS proxy: %v", err)
}
if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
logger.Error("Failed to setup DNS override: %v", err)
return
}
network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()})
}
o.apiServer.SetRegistered(true)
o.registered = true
// Start ping monitor now that we are registered and connected
o.websocket.StartPingMonitor()
// Invoke onConnected callback if configured
if o.olmConfig.OnConnected != nil {
go o.olmConfig.OnConnected()
}
logger.Info("WireGuard device created.")
}
func (o *Olm) handleOlmError(msg websocket.WSMessage) {
logger.Debug("Received olm error message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring olm error message")
return
}
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling olm error data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling olm error data: %v", err)
return
}
logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message)
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
// Invoke onOlmError callback if configured
if o.olmConfig.OnOlmError != nil {
go o.olmConfig.OnOlmError(errorData.Code, errorData.Message)
}
}
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
logger.Info("Received terminate message")
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring terminate message")
return
}
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling terminate error data: %v", err)
} else {
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling terminate error data: %v", err)
} else {
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
if errorData.Code == "TERMINATED_INACTIVITY" {
logger.Info("Ignoring...")
return
}
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
}
}
o.apiServer.SetTerminated(true)
o.apiServer.SetConnectionStatus(false)
o.apiServer.SetRegistered(false)
o.apiServer.ClearPeerStatuses()
network.ClearNetworkSettings()
o.Close()
if o.olmConfig.OnTerminated != nil {
go o.olmConfig.OnTerminated()
}
}

365
olm/data.go Normal file
View File

@@ -0,0 +1,365 @@
package olm
import (
"encoding/json"
"time"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
)
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring add-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addSubnetsData peers.PeerAdd
if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil {
logger.Error("Error unmarshaling add-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId)
return
}
// Add new subnets
for _, subnet := range addSubnetsData.RemoteSubnets {
if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
}
}
// Add new aliases
for _, alias := range addSubnetsData.Aliases {
if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil {
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
}
}
}
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring remove-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeSubnetsData peers.RemovePeerData
if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil {
logger.Error("Error unmarshaling remove-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId)
return
}
// Remove subnets
for _, subnet := range removeSubnetsData.RemoteSubnets {
if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
}
}
// Remove aliases
for _, alias := range removeSubnetsData.Aliases {
if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil {
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
}
}
}
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring update-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateSubnetsData peers.UpdatePeerData
if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil {
logger.Error("Error unmarshaling update-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId)
return
}
// Add new subnets BEFORE removing old ones to preserve shared subnets
// This ensures that if an old and new subnet are the same on different peers,
// the route won't be temporarily removed
for _, subnet := range updateSubnetsData.NewRemoteSubnets {
if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
}
}
// Remove old subnets after new ones are added
for _, subnet := range updateSubnetsData.OldRemoteSubnets {
if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
}
}
// Add new aliases BEFORE removing old ones to preserve shared IP addresses
// This ensures that if an old and new alias share the same IP, the IP won't be
// temporarily removed from the allowed IPs list
for _, alias := range updateSubnetsData.NewAliases {
if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil {
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
}
}
// Remove old aliases after new ones are added
for _, alias := range updateSubnetsData.OldAliases {
if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil {
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
}
}
logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
}
// Handler for syncing peer configuration - reconciles expected state with actual state
func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Debug("Received sync message: %v", msg.Data)
if !o.registered {
logger.Warn("Not connected, ignoring sync request")
return
}
if o.peerManager == nil {
logger.Warn("Peer manager not initialized, ignoring sync request")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
var syncData SyncData
if err := json.Unmarshal(jsonData, &syncData); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
// Sync exit nodes for hole punching
o.syncExitNodes(syncData.ExitNodes)
// Build a map of expected peers from the incoming data
expectedPeers := make(map[int]peers.SiteConfig)
for _, site := range syncData.Sites {
expectedPeers[site.SiteId] = site
}
// Get all current peers
currentPeers := o.peerManager.GetAllPeers()
currentPeerMap := make(map[int]peers.SiteConfig)
for _, peer := range currentPeers {
currentPeerMap[peer.SiteId] = peer
}
// Find peers to remove (in current but not in expected)
for siteId := range currentPeerMap {
if _, exists := expectedPeers[siteId]; !exists {
logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId)
if err := o.peerManager.RemovePeer(siteId); err != nil {
logger.Error("Sync: Failed to remove peer %d: %v", siteId, err)
} else {
// Remove any exit nodes associated with this peer from hole punching
if o.holePunchManager != nil {
removed := o.holePunchManager.RemoveExitNodesByPeer(siteId)
if removed > 0 {
logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId)
}
}
}
}
}
// Find peers to add (in expected but not in current) and peers to update
for siteId, expectedSite := range expectedPeers {
if _, exists := currentPeerMap[siteId]; !exists {
// New peer - add it using the add flow (with holepunch)
logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch()
// // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
// logger.Error("Sync: Failed to add peer %d: %v", siteId, err)
// } else {
// logger.Info("Sync: Successfully added peer for site %d", siteId)
// }
// add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
}, 1*time.Second, 10)
} else {
// Existing peer - check if update is needed
currentSite := currentPeerMap[siteId]
needsUpdate := false
// Check if any fields have changed
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
needsUpdate = true
}
if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint {
needsUpdate = true
}
if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey {
needsUpdate = true
}
if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP {
needsUpdate = true
}
if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort {
needsUpdate = true
}
// Check remote subnets
if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) {
needsUpdate = true
}
// Check aliases
if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) {
needsUpdate = true
}
if needsUpdate {
logger.Info("Sync: Updating peer for site %d", siteId)
// Merge expected data with current data
siteConfig := currentSite
if expectedSite.Endpoint != "" {
siteConfig.Endpoint = expectedSite.Endpoint
}
if expectedSite.RelayEndpoint != "" {
siteConfig.RelayEndpoint = expectedSite.RelayEndpoint
}
if expectedSite.PublicKey != "" {
siteConfig.PublicKey = expectedSite.PublicKey
}
if expectedSite.ServerIP != "" {
siteConfig.ServerIP = expectedSite.ServerIP
}
if expectedSite.ServerPort != 0 {
siteConfig.ServerPort = expectedSite.ServerPort
}
if expectedSite.RemoteSubnets != nil {
siteConfig.RemoteSubnets = expectedSite.RemoteSubnets
}
if expectedSite.Aliases != nil {
siteConfig.Aliases = expectedSite.Aliases
}
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
logger.Error("Sync: Failed to update peer %d: %v", siteId, err)
} else {
// If the endpoint changed, trigger holepunch to refresh NAT mappings
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId)
o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval()
}
logger.Info("Sync: Successfully updated peer for site %d", siteId)
}
}
}
}
logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers))
}
// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager
func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) {
if o.holePunchManager == nil {
logger.Warn("Hole punch manager not initialized, skipping exit node sync")
return
}
// Build a map of expected exit nodes by endpoint
expectedExitNodeMap := make(map[string]SyncExitNode)
for _, exitNode := range expectedExitNodes {
expectedExitNodeMap[exitNode.Endpoint] = exitNode
}
// Get current exit nodes from hole punch manager
currentExitNodes := o.holePunchManager.GetExitNodes()
currentExitNodeMap := make(map[string]holepunch.ExitNode)
for _, exitNode := range currentExitNodes {
currentExitNodeMap[exitNode.Endpoint] = exitNode
}
// Find exit nodes to remove (in current but not in expected)
for endpoint := range currentExitNodeMap {
if _, exists := expectedExitNodeMap[endpoint]; !exists {
logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint)
o.holePunchManager.RemoveExitNode(endpoint)
}
}
// Find exit nodes to add (in expected but not in current)
for endpoint, expectedExitNode := range expectedExitNodeMap {
if _, exists := currentExitNodeMap[endpoint]; !exists {
logger.Info("Sync: Adding new exit node %s", endpoint)
relayPort := expectedExitNode.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
hpExitNode := holepunch.ExitNode{
Endpoint: expectedExitNode.Endpoint,
RelayPort: relayPort,
PublicKey: expectedExitNode.PublicKey,
SiteIds: expectedExitNode.SiteIds,
}
if o.holePunchManager.AddExitNode(hpExitNode) {
logger.Info("Sync: Successfully added exit node %s", endpoint)
}
o.holePunchManager.TriggerHolePunch()
}
}
logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap))
}

1394
olm/olm.go

File diff suppressed because it is too large Load Diff

282
olm/peer.go Normal file
View File

@@ -0,0 +1,282 @@
package olm
import (
"encoding/json"
"time"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
)
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring add-peer message")
return
}
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var siteConfig peers.SiteConfig
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
}
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring remove-peer message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData peers.PeerRemove
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil {
logger.Error("Failed to remove peer: %v", err)
return
}
// Remove any exit nodes associated with this peer from hole punching
if o.holePunchManager != nil {
removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId)
if removed > 0 {
logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId)
}
}
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
}
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring update-peer message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData peers.SiteConfig
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Get existing peer from PeerManager
existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId)
if !exists {
logger.Warn("Peer with site ID %d not found", updateData.SiteId)
return
}
// Create updated site config by merging with existing data
siteConfig := existingPeer
if updateData.Endpoint != "" {
siteConfig.Endpoint = updateData.Endpoint
}
if updateData.RelayEndpoint != "" {
siteConfig.RelayEndpoint = updateData.RelayEndpoint
}
if updateData.PublicKey != "" {
siteConfig.PublicKey = updateData.PublicKey
}
if updateData.ServerIP != "" {
siteConfig.ServerIP = updateData.ServerIP
}
if updateData.ServerPort != 0 {
siteConfig.ServerPort = updateData.ServerPort
}
if updateData.RemoteSubnets != nil {
siteConfig.RemoteSubnets = updateData.RemoteSubnets
}
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// If the endpoint changed, trigger holepunch to refresh NAT mappings
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
_ = o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval()
}
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
}
func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if o.peerManager == nil {
logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData peers.RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err)
return
}
// Update HTTP server to mark this peer as using relay
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
}
func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
logger.Debug("Received unrelay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if o.peerManager == nil {
logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData peers.UnRelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
}
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
logger.Debug("Received peer-handshake message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring peer-handshake message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling handshake data: %v", err)
return
}
var handshakeData struct {
SiteId int `json:"siteId"`
ExitNode struct {
PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
} `json:"exitNode"`
}
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
logger.Error("Error unmarshaling handshake data: %v", err)
return
}
// Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists {
logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
return
}
relayPort := handshakeData.ExitNode.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
siteId := handshakeData.SiteId
exitNode := holepunch.ExitNode{
Endpoint: handshakeData.ExitNode.Endpoint,
RelayPort: relayPort,
PublicKey: handshakeData.ExitNode.PublicKey,
SiteIds: []int{siteId},
}
added := o.holePunchManager.AddExitNode(exitNode)
if added {
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
} else {
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
}
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
}, 1*time.Second, 10)
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}

View File

@@ -12,9 +12,22 @@ type WgData struct {
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
}
type GlobalConfig struct {
type SyncData struct {
Sites []peers.SiteConfig `json:"sites"`
ExitNodes []SyncExitNode `json:"exitNodes"`
}
type SyncExitNode struct {
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
PublicKey string `json:"publicKey"`
SiteIds []int `json:"siteIds"`
}
type OlmConfig struct {
// Logging
LogLevel string
LogLevel string
LogFilePath string
// HTTP server
EnableAPI bool
@@ -23,11 +36,17 @@ type GlobalConfig struct {
Version string
Agent string
WakeUpDebounce time.Duration
// Debugging
PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")
// Callbacks
OnRegistered func()
OnConnected func()
OnTerminated func()
OnAuthError func(statusCode int, message string) // Called when auth fails (401/403)
OnOlmError func(code string, message string) // Called when registration fails
OnExit func() // Called when exit is requested via API
}
@@ -63,5 +82,8 @@ type TunnelConfig struct {
OverrideDNS bool
TunnelDNS bool
InitialFingerprint map[string]any
InitialPostures map[string]any
DisableRelay bool
}

View File

@@ -1,55 +1,47 @@
package olm
import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/olm/websocket"
"github.com/fosrl/olm/peers"
)
func sendPing(olm *websocket.Client) error {
err := olm.SendMessage("olm/ping", map[string]interface{}{
"timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
return err
// slicesEqual compares two string slices for equality (order-independent)
func slicesEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
logger.Debug("Sent ping message")
return nil
}
func keepSendingPing(olm *websocket.Client) {
// Send ping immediately on startup
if err := sendPing(olm); err != nil {
logger.Error("Failed to send initial ping: %v", err)
} else {
logger.Info("Sent initial ping message")
// Create a map to count occurrences in slice a
counts := make(map[string]int)
for _, v := range a {
counts[v]++
}
// Set up ticker for one minute intervals
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-stopPing:
logger.Info("Stopping ping messages")
return
case <-ticker.C:
if err := sendPing(olm); err != nil {
logger.Error("Failed to send periodic ping: %v", err)
}
// Check if slice b has the same elements
for _, v := range b {
counts[v]--
if counts[v] < 0 {
return false
}
}
return true
}
func GetNetworkSettingsJSON() (string, error) {
return network.GetJSON()
}
func GetNetworkSettingsIncrementor() int {
return network.GetIncrementor()
// aliasesEqual compares two Alias slices for equality (order-independent)
func aliasesEqual(a, b []peers.Alias) bool {
if len(a) != len(b) {
return false
}
// Create a map to count occurrences in slice a (using alias+address as key)
counts := make(map[string]int)
for _, v := range a {
key := v.Alias + "|" + v.AliasAddress
counts[key]++
}
// Check if slice b has the same elements
for _, v := range b {
key := v.Alias + "|" + v.AliasAddress
counts[key]--
if counts[key] < 0 {
return false
}
}
return true
}

View File

@@ -50,6 +50,8 @@ type PeerManager struct {
// key is the CIDR string, value is a set of siteIds that want this IP
allowedIPClaims map[string]map[int]bool
APIServer *api.API
PersistentKeepalive int
}
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
@@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
return peer, ok
}
// GetPeerMonitor returns the internal peer monitor instance
func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor {
pm.mu.RLock()
defer pm.mu.RUnlock()
return pm.peerMonitor
}
func (pm *PeerManager) GetAllPeers() []SiteConfig {
pm.mu.RLock()
defer pm.mu.RUnlock()
@@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
return err
}
@@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
return nil
}
// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once
// without recreating them. Returns a map of siteId to error for any peers that failed to update.
func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
pm.mu.RLock()
defer pm.mu.RUnlock()
pm.PersistentKeepalive = interval
errors := make(map[int]error)
for siteId, peer := range pm.peers {
err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
if err != nil {
errors[siteId] = err
}
}
if len(errors) == 0 {
return nil
}
return errors
}
func (pm *PeerManager) RemovePeer(siteId int) error {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
wgConfig := promotedPeer
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
@@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
return err
}
@@ -324,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
promotedWgConfig := promotedPeer
promotedWgConfig.AllowedIps = promotedOwnedIPs
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}

View File

@@ -31,8 +31,7 @@ type PeerMonitor struct {
monitors map[int]*Client
mutex sync.Mutex
running bool
interval time.Duration
timeout time.Duration
timeout time.Duration
maxAttempts int
wsClient *websocket.Client
@@ -42,7 +41,7 @@ type PeerMonitor struct {
stack *stack.Stack
ep *channel.Endpoint
activePorts map[uint16]bool
portsLock sync.Mutex
portsLock sync.RWMutex
nsCtx context.Context
nsCancel context.CancelFunc
nsWg sync.WaitGroup
@@ -50,17 +49,26 @@ type PeerMonitor struct {
// Holepunch testing fields
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchInterval time.Duration
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
// Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
holepunchMaxAttempts int // max consecutive failures before triggering relay
holepunchFailures map[int]int // siteID -> consecutive failure count
// Exponential backoff fields for holepunch monitor
defaultHolepunchMinInterval time.Duration // Minimum interval (initial)
defaultHolepunchMaxInterval time.Duration
holepunchMinInterval time.Duration // Minimum interval (initial)
holepunchMaxInterval time.Duration // Maximum interval (cap for backoff)
holepunchBackoffMultiplier float64 // Multiplier for each stable check
holepunchStableCount map[int]int // siteID -> consecutive stable status count
holepunchCurrentInterval time.Duration // Current interval with backoff applied
// Rapid initial test fields
rapidTestInterval time.Duration // interval between rapid test attempts
rapidTestTimeout time.Duration // timeout for each rapid test attempt
@@ -78,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*Client),
interval: 2 * time.Second, // Default check interval (faster)
timeout: 3 * time.Second,
maxAttempts: 3,
wsClient: wsClient,
@@ -88,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
nsCtx: ctx,
nsCancel: cancel,
sharedBind: sharedBind,
holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
holepunchTimeout: 2 * time.Second, // Faster timeout
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
@@ -101,6 +107,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
apiServer: apiServer,
wgConnectionStatus: make(map[int]bool),
// Exponential backoff settings for holepunch monitor
defaultHolepunchMinInterval: 2 * time.Second,
defaultHolepunchMaxInterval: 30 * time.Second,
holepunchMinInterval: 2 * time.Second,
holepunchMaxInterval: 30 * time.Second,
holepunchBackoffMultiplier: 1.5,
holepunchStableCount: make(map[int]int),
holepunchCurrentInterval: 2 * time.Second,
holepunchUpdateChan: make(chan struct{}, 1),
}
if err := pm.initNetstack(); err != nil {
@@ -116,41 +131,75 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
}
// SetInterval changes how frequently peers are checked
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.interval = interval
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetPacketInterval(interval)
client.SetPacketInterval(minInterval, maxInterval)
}
logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
}
// SetTimeout changes the timeout for waiting for responses
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
func (pm *PeerMonitor) ResetPeerInterval() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.timeout = timeout
// Update timeout for all existing monitors
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetTimeout(timeout)
client.ResetPacketInterval()
}
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
pm.holepunchMinInterval = minInterval
pm.holepunchMaxInterval = maxInterval
// Reset current interval to the new minimum
pm.holepunchCurrentInterval = minInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.maxAttempts = attempts
return pm.holepunchMinInterval, pm.holepunchMaxInterval
}
// Update max attempts for all existing monitors
for _, client := range pm.monitors {
client.SetMaxAttempts(attempts)
func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
pm.mutex.Lock()
pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
@@ -169,10 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
return err
}
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
pm.monitors[siteID] = client
pm.holepunchEndpoints[siteID] = holepunchEndpoint
@@ -470,31 +515,59 @@ func (pm *PeerMonitor) stopHolepunchMonitor() {
logger.Info("Stopped holepunch connection monitor")
}
// runHolepunchMonitor runs the holepunch monitoring loop
// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff
func (pm *PeerMonitor) runHolepunchMonitor() {
ticker := time.NewTicker(pm.holepunchInterval)
defer ticker.Stop()
pm.mutex.Lock()
pm.holepunchCurrentInterval = pm.holepunchMinInterval
pm.mutex.Unlock()
// Do initial check immediately
pm.checkHolepunchEndpoints()
timer := time.NewTimer(0) // Fire immediately for initial check
defer timer.Stop()
for {
select {
case <-pm.holepunchStopChan:
return
case <-ticker.C:
pm.checkHolepunchEndpoints()
case <-pm.holepunchUpdateChan:
// Interval settings changed, reset to minimum
pm.mutex.Lock()
pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C:
anyStatusChanged := pm.checkHolepunchEndpoints()
pm.mutex.Lock()
if anyStatusChanged {
// Reset to minimum interval on any status change
pm.holepunchCurrentInterval = pm.holepunchMinInterval
} else {
// Apply exponential backoff when stable
newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier)
if newInterval > pm.holepunchMaxInterval {
newInterval = pm.holepunchMaxInterval
}
pm.holepunchCurrentInterval = newInterval
}
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
}
}
}
// checkHolepunchEndpoints tests all holepunch endpoints
func (pm *PeerMonitor) checkHolepunchEndpoints() {
// Returns true if any endpoint's status changed
func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
pm.mutex.Lock()
// Check if we're still running before doing any work
if !pm.running {
pm.mutex.Unlock()
return
return false
}
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
for siteID, endpoint := range pm.holepunchEndpoints {
@@ -504,8 +577,10 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
maxAttempts := pm.holepunchMaxAttempts
pm.mutex.Unlock()
anyStatusChanged := false
for siteID, endpoint := range endpoints {
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
// logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
pm.mutex.Lock()
@@ -529,7 +604,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Unlock()
// Log status changes
if !exists || previousStatus != result.Success {
statusChanged := !exists || previousStatus != result.Success
if statusChanged {
anyStatusChanged = true
if result.Success {
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
} else {
@@ -562,7 +639,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Unlock()
if !stillRunning {
return // Stop processing if shutdown is in progress
return anyStatusChanged // Stop processing if shutdown is in progress
}
if !result.Success && !isRelayed && failureCount >= maxAttempts {
@@ -579,6 +656,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
}
}
}
return anyStatusChanged
}
// GetHolepunchStatus returns the current holepunch status for all endpoints
@@ -650,55 +729,55 @@ func (pm *PeerMonitor) Close() {
logger.Debug("PeerMonitor: Cleanup complete")
}
// TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
pm.mutex.Lock()
client, exists := pm.monitors[siteID]
pm.mutex.Unlock()
// // TestPeer tests connectivity to a specific peer
// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
// pm.mutex.Lock()
// client, exists := pm.monitors[siteID]
// pm.mutex.Unlock()
if !exists {
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
}
// if !exists {
// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
// }
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
defer cancel()
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
// defer cancel()
connected, rtt := client.TestConnection(ctx)
return connected, rtt, nil
}
// connected, rtt := client.TestPeerConnection(ctx)
// return connected, rtt, nil
// }
// TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
Connected bool
RTT time.Duration
} {
pm.mutex.Lock()
peers := make(map[int]*Client, len(pm.monitors))
for siteID, client := range pm.monitors {
peers[siteID] = client
}
pm.mutex.Unlock()
// // TestAllPeers tests connectivity to all peers
// func (pm *PeerMonitor) TestAllPeers() map[int]struct {
// Connected bool
// RTT time.Duration
// } {
// pm.mutex.Lock()
// peers := make(map[int]*Client, len(pm.monitors))
// for siteID, client := range pm.monitors {
// peers[siteID] = client
// }
// pm.mutex.Unlock()
results := make(map[int]struct {
Connected bool
RTT time.Duration
})
for siteID, client := range peers {
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
connected, rtt := client.TestConnection(ctx)
cancel()
// results := make(map[int]struct {
// Connected bool
// RTT time.Duration
// })
// for siteID, client := range peers {
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
// connected, rtt := client.TestPeerConnection(ctx)
// cancel()
results[siteID] = struct {
Connected bool
RTT time.Duration
}{
Connected: connected,
RTT: rtt,
}
}
// results[siteID] = struct {
// Connected bool
// RTT time.Duration
// }{
// Connected: connected,
// RTT: rtt,
// }
// }
return results
}
// return results
// }
// initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error {
@@ -770,9 +849,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool {
}
// Check if we are listening on this port
pm.portsLock.Lock()
pm.portsLock.RLock()
active := pm.activePorts[uint16(port)]
pm.portsLock.Unlock()
pm.portsLock.RUnlock()
if !active {
return false
@@ -803,13 +882,12 @@ func (pm *PeerMonitor) runPacketSender() {
defer pm.nsWg.Done()
logger.Debug("PeerMonitor: Packet sender goroutine started")
// Use a ticker to periodically check for packets without blocking indefinitely
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-pm.nsCtx.Done():
// Use blocking ReadContext instead of polling - much more CPU efficient
// This will block until a packet is available or context is cancelled
pkt := pm.ep.ReadContext(pm.nsCtx)
if pkt == nil {
// Context was cancelled or endpoint closed
logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
// Drain any remaining packets before exiting
for {
@@ -821,36 +899,28 @@ func (pm *PeerMonitor) runPacketSender() {
}
logger.Debug("PeerMonitor: Packet sender goroutine exiting")
return
case <-ticker.C:
// Try to read packets in batches
for i := 0; i < 10; i++ {
pkt := pm.ep.Read()
if pkt == nil {
break
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
}

View File

@@ -32,10 +32,19 @@ type Client struct {
monitorLock sync.Mutex
connLock sync.Mutex // Protects connection operations
shutdownCh chan struct{}
updateCh chan struct{}
packetInterval time.Duration
timeout time.Duration
maxAttempts int
dialer Dialer
// Exponential backoff fields
defaultMinInterval time.Duration // Default minimum interval (initial)
defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
minInterval time.Duration // Minimum interval (initial)
maxInterval time.Duration // Maximum interval (cap for backoff)
backoffMultiplier float64 // Multiplier for each stable check
stableCountToBackoff int // Number of stable checks before backing off
}
// Dialer is a function that creates a connection
@@ -50,28 +59,59 @@ type ConnectionStatus struct {
// NewClient creates a new connection test client
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
return &Client{
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
packetInterval: 2 * time.Second,
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
updateCh: make(chan struct{}, 1),
packetInterval: 2 * time.Second,
defaultMinInterval: 2 * time.Second,
defaultMaxInterval: 30 * time.Second,
minInterval: 2 * time.Second,
maxInterval: 30 * time.Second,
backoffMultiplier: 1.5,
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
}, nil
}
// SetPacketInterval changes how frequently packets are sent in monitor mode
func (c *Client) SetPacketInterval(interval time.Duration) {
c.packetInterval = interval
func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
c.monitorLock.Lock()
c.packetInterval = minInterval
c.minInterval = minInterval
c.maxInterval = maxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// SetTimeout changes the timeout for waiting for responses
func (c *Client) SetTimeout(timeout time.Duration) {
c.timeout = timeout
}
func (c *Client) ResetPacketInterval() {
c.monitorLock.Lock()
c.packetInterval = c.defaultMinInterval
c.minInterval = c.defaultMinInterval
c.maxInterval = c.defaultMaxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (c *Client) SetMaxAttempts(attempts int) {
c.maxAttempts = attempts
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// UpdateServerAddr updates the server address and resets the connection
@@ -125,9 +165,10 @@ func (c *Client) ensureConnection() error {
return nil
}
// TestConnection checks if the connection to the server is working
// TestPeerConnection checks if the connection to the server is working
// Returns true if connected, false otherwise
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
// logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
if err := c.ensureConnection(); err != nil {
logger.Warn("Failed to ensure connection: %v", err)
return false, 0
@@ -138,6 +179,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
packet[4] = packetTypeRequest
// Reusable response buffer
responseBuffer := make([]byte, packetSize)
// Send multiple attempts as specified
for attempt := 0; attempt < c.maxAttempts; attempt++ {
select {
@@ -157,20 +201,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
return false, 0
}
// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
_, err := c.conn.Write(packet)
if err != nil {
c.connLock.Unlock()
logger.Info("Error sending packet: %v", err)
continue
}
// logger.Debug("Successfully sent monitor packet")
// Set read deadline
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
// Wait for response
responseBuffer := make([]byte, packetSize)
n, err := c.conn.Read(responseBuffer)
c.connLock.Unlock()
@@ -211,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return c.TestConnection(ctx)
return c.TestPeerConnection(ctx)
}
// MonitorCallback is the function type for connection status change callbacks
@@ -238,28 +279,61 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
go func() {
var lastConnected bool
firstRun := true
stableCount := 0
currentInterval := c.minInterval
ticker := time.NewTicker(c.packetInterval)
defer ticker.Stop()
timer := time.NewTimer(currentInterval)
defer timer.Stop()
for {
select {
case <-c.shutdownCh:
return
case <-ticker.C:
case <-c.updateCh:
// Interval settings changed, reset to minimum
c.monitorLock.Lock()
currentInterval = c.minInterval
c.monitorLock.Unlock()
// Reset backoff state
stableCount = 0
timer.Reset(currentInterval)
logger.Debug("Packet interval updated, reset to %v", currentInterval)
case <-timer.C:
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
connected, rtt := c.TestConnection(ctx)
connected, rtt := c.TestPeerConnection(ctx)
cancel()
statusChanged := connected != lastConnected
// Callback if status changed or it's the first check
if connected != lastConnected || firstRun {
if statusChanged || firstRun {
callback(ConnectionStatus{
Connected: connected,
RTT: rtt,
})
lastConnected = connected
firstRun = false
// Reset backoff on status change
stableCount = 0
currentInterval = c.minInterval
} else {
// Status is stable, increment counter
stableCount++
// Apply exponential backoff after stable threshold
if stableCount >= c.stableCountToBackoff {
newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier)
if newInterval > c.maxInterval {
newInterval = c.maxInterval
}
currentInterval = newInterval
}
}
// Reset timer with current interval
timer.Reset(currentInterval)
}
}
}()

View File

@@ -11,7 +11,7 @@ import (
)
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
var endpoint string
if relay && siteConfig.RelayEndpoint != "" {
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
@@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
}
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
configBuilder.WriteString("persistent_keepalive_interval=5\n")
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive))
config := configBuilder.String()
logger.Debug("Configuring peer with config: %s", config)
@@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [
return nil
}
// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it
func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error {
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
configBuilder.WriteString("update_only=true\n")
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval))
config := configBuilder.String()
logger.Debug("Updating persistent keepalive for peer with config: %s", config)
err := dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err)
}
return nil
}
func formatEndpoint(endpoint string) string {
if strings.Contains(endpoint, ":") {
return endpoint

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -54,8 +55,9 @@ type ExitNode struct {
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
Type string `json:"type"`
Data interface{} `json:"data"`
ConfigVersion int `json:"configVersion,omitempty"`
}
// this is not json anymore
@@ -77,6 +79,7 @@ type Client struct {
handlersMux sync.RWMutex
reconnectInterval time.Duration
isConnected bool
isDisconnected bool // Flag to track if client is intentionally disconnected
reconnectMux sync.RWMutex
pingInterval time.Duration
pingTimeout time.Duration
@@ -87,6 +90,19 @@ type Client struct {
clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig
configNeedsSave bool // Flag to track if config needs to be saved
configVersion int // Latest config version received from server
configVersionMux sync.RWMutex
token string // Cached authentication token
exitNodes []ExitNode // Cached exit nodes from token response
tokenMux sync.RWMutex // Protects token and exitNodes
forceNewToken bool // Flag to force fetching a new token on next connection
processingMessage bool // Flag to track if a message is currently being processed
processingMux sync.RWMutex // Protects processingMessage
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
getPingData func() map[string]any // Callback to get additional ping data
pingStarted bool // Flag to track if ping monitor has been started
pingStartedMux sync.Mutex // Protects pingStarted
pingDone chan struct{} // Channel to stop the ping monitor independently
}
type ClientOption func(*Client)
@@ -122,6 +138,13 @@ func WithTLSConfig(config TLSConfig) ClientOption {
}
}
// WithPingDataProvider sets a callback to provide additional data for ping messages
func WithPingDataProvider(fn func() map[string]any) ClientOption {
return func(c *Client) {
c.getPingData = fn
}
}
func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
@@ -154,6 +177,7 @@ func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.
pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: "olm",
pingDone: make(chan struct{}),
}
// Apply options before loading config
@@ -173,6 +197,9 @@ func (c *Client) GetConfig() *Config {
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
if c.isDisconnected {
c.isDisconnected = false
}
go c.connectWithRetry()
return nil
}
@@ -205,9 +232,31 @@ func (c *Client) Close() error {
return nil
}
// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
func (c *Client) Disconnect() error {
c.isDisconnected = true
c.setConnected(false)
// Stop the ping monitor
c.stopPingMonitor()
// Wait for any message currently being processed to complete
c.processingWg.Wait()
if c.conn != nil {
c.writeMux.Lock()
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.writeMux.Unlock()
err := c.conn.Close()
c.conn = nil
return err
}
return nil
}
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessage(messageType string, data interface{}) error {
if c.conn == nil {
if c.isDisconnected || c.conn == nil {
return fmt.Errorf("not connected")
}
@@ -216,14 +265,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
Data: data,
}
logger.Debug("Sending message: %s, data: %+v", messageType, data)
logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
c.writeMux.Lock()
defer c.writeMux.Unlock()
return c.conn.WriteJSON(msg)
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
stopChan := make(chan struct{})
updateChan := make(chan interface{})
var dataMux sync.Mutex
@@ -231,30 +280,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
go func() {
count := 0
maxAttempts := 10
err := c.SendMessage(messageType, currentData) // Send immediately
if err != nil {
logger.Error("Failed to send initial message: %v", err)
send := func() {
if c.isDisconnected || c.conn == nil {
return
}
err := c.SendMessage(messageType, currentData)
if err != nil {
logger.Error("websocket: Failed to send message: %v", err)
}
count++
}
count++
send() // Send immediately
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if count >= maxAttempts {
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
if maxAttempts != -1 && count >= maxAttempts {
logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
dataMux.Lock()
err = c.SendMessage(messageType, currentData)
send()
dataMux.Unlock()
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case newData := <-updateChan:
dataMux.Lock()
// Merge newData into currentData if both are maps
@@ -277,6 +328,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
case <-stopChan:
return
}
// Suspend sending if disconnected
for c.isDisconnected {
select {
case <-stopChan:
return
case <-time.After(500 * time.Millisecond):
}
}
}
}()
return func() {
@@ -323,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
tlsConfig = &tls.Config{}
}
tlsConfig.InsecureSkipVerify = true
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
tokenData := map[string]interface{}{
@@ -352,7 +411,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// print out the request for debugging
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
// Make the request
client := &http.Client{}
@@ -369,7 +428,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
// Return AuthError for 401/403 status codes
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
@@ -385,7 +444,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.")
logger.Error("websocket: Failed to decode token response.")
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
}
@@ -397,7 +456,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
return "", nil, fmt.Errorf("received empty token from server")
}
logger.Debug("Received token: %s", tokenResp.Data.Token)
logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
}
@@ -411,7 +470,8 @@ func (c *Client) connectWithRetry() {
err := c.establishConnection()
if err != nil {
// Check if this is an auth error (401/403)
if authErr, ok := err.(*AuthError); ok {
var authErr *AuthError
if errors.As(err, &authErr) {
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
// Trigger auth error callback if set (this should terminate the tunnel)
if c.onAuthError != nil {
@@ -422,7 +482,7 @@ func (c *Client) connectWithRetry() {
continue
}
// For other errors (5xx, network issues), continue retrying
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval)
continue
}
@@ -432,15 +492,25 @@ func (c *Client) connectWithRetry() {
}
func (c *Client) establishConnection() error {
// Get token for authentication
token, exitNodes, err := c.getToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
// Get token for authentication - reuse cached token unless forced to get new one
c.tokenMux.Lock()
needNewToken := c.token == "" || c.forceNewToken
if needNewToken {
token, exitNodes, err := c.getToken()
if err != nil {
c.tokenMux.Unlock()
return fmt.Errorf("failed to get token: %w", err)
}
c.token = token
c.exitNodes = exitNodes
c.forceNewToken = false
if c.onTokenUpdate != nil {
c.onTokenUpdate(token, exitNodes)
if c.onTokenUpdate != nil {
c.onTokenUpdate(token, exitNodes)
}
}
token := c.token
c.tokenMux.Unlock()
// Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL)
@@ -475,7 +545,7 @@ func (c *Client) establishConnection() error {
// Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
logger.Info("Setting up TLS configuration for WebSocket connection")
logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
tlsConfig, err := c.setupTLS()
if err != nil {
return fmt.Errorf("failed to setup TLS configuration: %w", err)
@@ -489,25 +559,38 @@ func (c *Client) establishConnection() error {
dialer.TLSClientConfig = &tls.Config{}
}
dialer.TLSClientConfig.InsecureSkipVerify = true
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
conn, _, err := dialer.Dial(u.String(), nil)
conn, resp, err := dialer.Dial(u.String(), nil)
if err != nil {
// Check if this is an unauthorized error (401)
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
// Force getting a new token on next reconnect attempt
c.tokenMux.Lock()
c.forceNewToken = true
c.tokenMux.Unlock()
return &AuthError{
StatusCode: http.StatusUnauthorized,
Message: "WebSocket connection unauthorized",
}
}
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
c.conn = conn
c.setConnected(true)
// Start the ping monitor
go c.pingMonitor()
// Note: ping monitor is NOT started here - it will be started when
// StartPingMonitor() is called after registration completes
// Start the read pump with disconnect detection
go c.readPumpWithDisconnectDetection()
if c.onConnect != nil {
if err := c.onConnect(); err != nil {
logger.Error("OnConnect callback failed: %v", err)
logger.Error("websocket: OnConnect callback failed: %v", err)
}
}
@@ -520,9 +603,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Handle new separate certificate configuration
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
logger.Info("Loading separate certificate files for mTLS")
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
logger.Info("websocket: Loading separate certificate files for mTLS")
logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
// Load client certificate and key
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
@@ -533,7 +616,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Load CA certificates for remote validation if specified
if len(c.tlsConfig.CAFiles) > 0 {
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
caCertPool := x509.NewCertPool()
for _, caFile := range c.tlsConfig.CAFiles {
caCert, err := os.ReadFile(caFile)
@@ -559,13 +642,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Fallback to existing PKCS12 implementation for backward compatibility
if c.tlsConfig.PKCS12File != "" {
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
return c.setupPKCS12TLS()
}
// Legacy fallback using config.TlsClientCert
if c.config.TlsClientCert != "" {
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
return loadClientCertificate(c.config.TlsClientCert)
}
@@ -577,6 +660,59 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
return loadClientCertificate(c.tlsConfig.PKCS12File)
}
// sendPing sends a single ping message
func (c *Client) sendPing() {
if c.isDisconnected || c.conn == nil {
return
}
// Skip ping if a message is currently being processed
c.processingMux.RLock()
isProcessing := c.processingMessage
c.processingMux.RUnlock()
if isProcessing {
logger.Debug("websocket: Skipping ping, message is being processed")
return
}
// Send application-level ping with config version
c.configVersionMux.RLock()
configVersion := c.configVersion
c.configVersionMux.RUnlock()
pingData := map[string]any{
"timestamp": time.Now().Unix(),
"userToken": c.config.UserToken,
}
if c.getPingData != nil {
for k, v := range c.getPingData() {
pingData[k] = v
}
}
pingMsg := WSMessage{
Type: "olm/ping",
Data: pingData,
ConfigVersion: configVersion,
}
logger.Debug("websocket: Sending ping: %+v", pingMsg)
c.writeMux.Lock()
err := c.conn.WriteJSON(pingMsg)
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("websocket: Ping failed: %v", err)
c.reconnect()
return
}
}
}
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) pingMonitor() {
ticker := time.NewTicker(c.pingInterval)
@@ -586,29 +722,65 @@ func (c *Client) pingMonitor() {
select {
case <-c.done:
return
case <-c.pingDone:
return
case <-ticker.C:
if c.conn == nil {
return
}
c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
c.reconnect()
return
}
}
c.sendPing()
}
}
}
// StartPingMonitor starts the ping monitor goroutine.
// This should be called after the client is registered and connected.
// It is safe to call multiple times - only the first call will start the monitor.
func (c *Client) StartPingMonitor() {
c.pingStartedMux.Lock()
defer c.pingStartedMux.Unlock()
if c.pingStarted {
return
}
c.pingStarted = true
// Create a new pingDone channel for this ping monitor instance
c.pingDone = make(chan struct{})
// Send an initial ping immediately
go func() {
c.sendPing()
c.pingMonitor()
}()
}
// stopPingMonitor stops the ping monitor goroutine if it's running.
func (c *Client) stopPingMonitor() {
c.pingStartedMux.Lock()
defer c.pingStartedMux.Unlock()
if !c.pingStarted {
return
}
// Close the pingDone channel to stop the monitor
close(c.pingDone)
c.pingStarted = false
}
// GetConfigVersion returns the current config version
func (c *Client) GetConfigVersion() int {
c.configVersionMux.RLock()
defer c.configVersionMux.RUnlock()
return c.configVersion
}
// setConfigVersion updates the config version if the new version is higher
func (c *Client) setConfigVersion(version int) {
c.configVersionMux.Lock()
defer c.configVersionMux.Unlock()
logger.Debug("websocket: setting config version to %d", version)
c.configVersion = version
}
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
func (c *Client) readPumpWithDisconnectDetection() {
defer func() {
@@ -633,26 +805,47 @@ func (c *Client) readPumpWithDisconnectDetection() {
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
// Check if we're shutting down before logging error
// Check if we're shutting down or explicitly disconnected before logging error
select {
case <-c.done:
// Expected during shutdown, don't log as error
logger.Debug("WebSocket connection closed during shutdown")
logger.Debug("websocket: connection closed during shutdown")
return
default:
// Check if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: connection closed: client was explicitly disconnected")
return
}
// Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("WebSocket read error: %v", err)
logger.Error("websocket: read error: %v", err)
} else {
logger.Debug("WebSocket connection closed: %v", err)
logger.Debug("websocket: connection closed: %v", err)
}
return // triggers reconnect via defer
}
}
// Update config version from incoming message
c.setConfigVersion(msg.ConfigVersion)
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
// Mark that we're processing a message
c.processingMux.Lock()
c.processingMessage = true
c.processingMux.Unlock()
c.processingWg.Add(1)
handler(msg)
// Mark that we're done processing
c.processingWg.Done()
c.processingMux.Lock()
c.processingMessage = false
c.processingMux.Unlock()
}
c.handlersMux.RUnlock()
}
@@ -666,6 +859,12 @@ func (c *Client) reconnect() {
c.conn = nil
}
// Don't reconnect if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
return
}
// Only reconnect if we're not shutting down
select {
case <-c.done:
@@ -683,7 +882,7 @@ func (c *Client) setConnected(status bool) {
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path)
logger.Info("websocket: Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path)
if err != nil {