Compare commits

...

94 Commits
1.3.0 ... main

Author SHA1 Message Date
Owen
5527bff671 Merge branch 'dev' 2026-02-06 15:17:21 -08:00
Owen
af973b2440 Support prt records 2026-02-06 15:17:01 -08:00
Owen
dd9bff9a4b Fix peer names clearing 2026-02-02 18:03:29 -08:00
Owen
1be5e454ba Default override dns to true
Ref #59
2026-02-02 10:03:22 -08:00
Owen
4850b1b332 Handle cross platform close
Former-commit-id: 89932bb736c7f4b3eb9bb2384b0cf6bd27872c1c
2026-01-31 17:50:31 -08:00
Owen
1ff74f7173 Dont go unregistered when low power mode
Former-commit-id: f55fc8fb39f8efc9d5438465f655dc2d734223c3
2026-01-31 17:15:30 -08:00
Owen
4a25a0d413 Dont go unregistered when low power mode
Former-commit-id: 0938564038
2026-01-31 16:58:05 -08:00
Owen
7fc3c7088e Lowercase all domains before matching
Former-commit-id: 8f8872aa47
2026-01-30 14:53:25 -08:00
Owen
1869e70894 Merge branch 'dev'
Former-commit-id: 43cc56a961
2026-01-30 10:58:00 -08:00
Owen
79783cc3dc Merge branch 'main' of github.com:fosrl/olm
Former-commit-id: 0b31f4e5d1
2026-01-30 10:57:40 -08:00
Owen
584298e3bd Fix terminate due to inactivity 2026-01-27 20:19:41 -08:00
miloschwartz
f683afa647 improve override-dns and tunnel-dns descriptions 2026-01-27 17:53:34 -08:00
Owen
ba2631d388 Prevent crashing on close before connect
Former-commit-id: ea461e0bfb
2026-01-23 14:47:54 -08:00
Owen Schwartz
6ae4e2b691 Merge pull request #87 from fosrl/dev
1.4.0

Former-commit-id: 1212217421
2026-01-23 10:25:03 -08:00
Owen
51eee9dcf5 Bump newt
Former-commit-id: f4885e9c4d
2026-01-23 10:23:42 -08:00
Owen
660e9e0e35 Merge branch 'main' into dev
Former-commit-id: b5580036d3
2026-01-23 10:22:21 -08:00
Owen
4ef6089053 Comment out local newt
Former-commit-id: c4ef1e724e
2026-01-23 10:19:38 -08:00
Owen
c4e297cc96 Handle properly stopping and starting the ping
Former-commit-id: 34c7717767
2026-01-20 11:30:06 -08:00
Owen
e3f5497176 Add stale bot
Former-commit-id: 313dee9ba8
2026-01-19 17:12:15 -08:00
dependabot[bot]
6a5dcc01a6 Bump actions/checkout from 5.0.0 to 6.0.1
Bumps [actions/checkout](https://github.com/actions/checkout) from 5.0.0 to 6.0.1.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](08c6903cd8...8e8c483db8)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: 6.0.1
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: e19b33e2fa
2026-01-19 17:08:10 -08:00
dependabot[bot]
18b6d3bb0f Bump the patch-updates group across 1 directory with 3 updates
Bumps the patch-updates group with 3 updates in the / directory: [github.com/fosrl/newt](https://github.com/fosrl/newt), [github.com/godbus/dbus/v5](https://github.com/godbus/dbus) and [github.com/miekg/dns](https://github.com/miekg/dns).


Updates `github.com/fosrl/newt` from 1.8.0 to 1.8.1
- [Release notes](https://github.com/fosrl/newt/releases)
- [Commits](https://github.com/fosrl/newt/compare/1.8.0...1.8.1)

Updates `github.com/godbus/dbus/v5` from 5.2.0 to 5.2.2
- [Release notes](https://github.com/godbus/dbus/releases)
- [Commits](https://github.com/godbus/dbus/compare/v5.2.0...v5.2.2)

Updates `github.com/miekg/dns` from 1.1.68 to 1.1.70
- [Commits](https://github.com/miekg/dns/compare/v1.1.68...v1.1.70)

---
updated-dependencies:
- dependency-name: github.com/fosrl/newt
  dependency-version: 1.8.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: patch-updates
- dependency-name: github.com/godbus/dbus/v5
  dependency-version: 5.2.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: patch-updates
- dependency-name: github.com/miekg/dns
  dependency-version: 1.1.70
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: patch-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: 69f25032cb
2026-01-19 17:08:00 -08:00
dependabot[bot]
ccbfdc5265 Bump docker/metadata-action from 5.9.0 to 5.10.0
Bumps [docker/metadata-action](https://github.com/docker/metadata-action) from 5.9.0 to 5.10.0.
- [Release notes](https://github.com/docker/metadata-action/releases)
- [Commits](318604b99e...c299e40c65)

---
updated-dependencies:
- dependency-name: docker/metadata-action
  dependency-version: 5.10.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: 225779c665
2026-01-19 17:06:44 -08:00
dependabot[bot]
ab04537278 Bump softprops/action-gh-release from 2.4.2 to 2.5.0
Bumps [softprops/action-gh-release](https://github.com/softprops/action-gh-release) from 2.4.2 to 2.5.0.
- [Release notes](https://github.com/softprops/action-gh-release/releases)
- [Changelog](https://github.com/softprops/action-gh-release/blob/master/CHANGELOG.md)
- [Commits](5be0e66d93...a06a81a03e)

---
updated-dependencies:
- dependency-name: softprops/action-gh-release
  dependency-version: 2.5.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: a7f029e232
2026-01-19 17:06:23 -08:00
dependabot[bot]
29c36c9837 Bump docker/setup-buildx-action from 3.11.1 to 3.12.0
Bumps [docker/setup-buildx-action](https://github.com/docker/setup-buildx-action) from 3.11.1 to 3.12.0.
- [Release notes](https://github.com/docker/setup-buildx-action/releases)
- [Commits](e468171a9d...8d2750c68a)

---
updated-dependencies:
- dependency-name: docker/setup-buildx-action
  dependency-version: 3.12.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: af4e74de81
2026-01-19 17:05:50 -08:00
dependabot[bot]
c47e9bf547 Bump actions/cache from 4.3.0 to 5.0.2
Bumps [actions/cache](https://github.com/actions/cache) from 4.3.0 to 5.0.2.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](0057852bfa...8b402f58fb)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-version: 5.0.2
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: f87d043d59
2026-01-19 17:05:42 -08:00
dependabot[bot]
abb682c935 Bump the minor-updates group across 1 directory with 2 updates
Bumps the minor-updates group with 2 updates in the / directory: [golang.org/x/sys](https://github.com/golang/sys) and software.sslmate.com/src/go-pkcs12.


Updates `golang.org/x/sys` from 0.38.0 to 0.40.0
- [Commits](https://github.com/golang/sys/compare/v0.38.0...v0.40.0)

Updates `software.sslmate.com/src/go-pkcs12` from 0.6.0 to 0.7.0

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-version: 0.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-updates
- dependency-name: software.sslmate.com/src/go-pkcs12
  dependency-version: 0.7.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: ae1436c5d1
2026-01-19 17:05:19 -08:00
Owen
79e8a4a8bb Dont start holepunching if we rebind while in low power mode
Former-commit-id: 4a5ebd41f3
2026-01-19 15:57:20 -08:00
Owen
f2e81c024a Set fingerprint earlier
Former-commit-id: ef36f7ca82
2026-01-19 15:05:29 -08:00
Owen
6d10650e70 Send an initial ping so we get online faster in the dashboard
Former-commit-id: 41e4eb24a2
2026-01-18 15:14:11 -08:00
Owen
a81c683c66 Reorder websocket disconnect message
Former-commit-id: 592a0d60c6
2026-01-18 14:49:42 -08:00
Owen
25cb50901e Quiet up logs again
Former-commit-id: 112283191c
2026-01-18 12:18:48 -08:00
Owen
a8e0844758 Send disconnecting message when stopping
Former-commit-id: 1fb6e2a00d
2026-01-18 11:55:09 -08:00
Owen
8b9ee6f26a Move power mode to the api from signal
Former-commit-id: 5d8ea92ef0
2026-01-18 11:46:18 -08:00
Owen
82e8fcc3a7 Merge branch 'bubble-errors-up' into dev
Former-commit-id: 61846f9ec4
2026-01-18 11:38:20 -08:00
Owen
e2b7777ba7 Merge branch 'rebind' into dev
Former-commit-id: 2139aeaa85
2026-01-18 11:37:43 -08:00
Owen
4e4d1a39f6 Try to close the socket first
Former-commit-id: ed4775bd26
2026-01-17 17:35:00 -08:00
Owen
17dc1b0be1 Dont start the ping until we are connected
Former-commit-id: 43c8a14fda
2026-01-17 17:32:01 -08:00
Owen
a06436eeab Add rebind endpoints for the shared socket
Former-commit-id: 6fd0984b13
2026-01-17 17:05:29 -08:00
Lokowitz
a83cc2a3a3 clean up dependabot
Former-commit-id: a37f0514c4
2026-01-17 15:43:06 -08:00
Lokowitz
d56537d0fd add docker build dev
Former-commit-id: b983216808
2026-01-17 15:43:06 -08:00
Lokowitz
31bb483e40 add qemu
Former-commit-id: 172eb97aa1
2026-01-17 15:43:06 -08:00
Lokowitz
cd91ae6e3a update test
Former-commit-id: b034f81ed9
2026-01-17 15:43:06 -08:00
Lokowitz
a9ec1e61d3 fix test
Former-commit-id: 076d01b48c
2026-01-17 15:43:06 -08:00
Owen
a13010c4af Update docs for metadata
Former-commit-id: 9d77a1daf7
2026-01-16 17:33:40 -08:00
Owen
cfac3cdd53 Use the right duration
Former-commit-id: c921f08bd5
2026-01-16 15:17:41 -08:00
Owen
5ecba61718 Use the right duration
Former-commit-id: 352b122166
2026-01-16 15:17:20 -08:00
Owen
2ea12ce258 Set the error on terminate as well
Former-commit-id: 8ff58e6efc
2026-01-16 14:59:13 -08:00
Owen
0b46289136 Add error can be sent from cloud to display in api
Former-commit-id: 2167f22713
2026-01-16 14:19:02 -08:00
Owen
71044165d0 Include fingerprint and posture info in ping
Former-commit-id: f061596e5b
2026-01-16 12:16:51 -08:00
Owen
eafd816159 Clean up log messages
Former-commit-id: 0231591f36
2026-01-16 12:02:02 -08:00
Owen
e1a687407e Set the ping inteval to 30 seconds
Former-commit-id: 737ffca15d
2026-01-15 21:59:18 -08:00
Owen
bd8031651e Message syncing works
Former-commit-id: 1650624a55
2026-01-15 21:25:53 -08:00
Owen
a63439543d Merge branch 'dev' into msg-delivery
Former-commit-id: d6b9170e79
2026-01-15 16:41:00 -08:00
Owen
90cd6e7f6e Merge branch 'power-state' into dev
Former-commit-id: e2a071e6dc
2026-01-15 16:39:41 -08:00
Owen
ea4a63c9b3 Merge branch 'dev' of github.com:fosrl/olm into dev
Former-commit-id: 1c21071ee1
2026-01-15 16:37:09 -08:00
Owen
e047330ffd Handle and test config version bugs
Former-commit-id: 285f8ce530
2026-01-15 16:36:11 -08:00
Owen
9dcc0796a6 Small clean up and move ping to client.go
Former-commit-id: af33218792
2026-01-15 14:20:12 -08:00
Varun Narravula
4b6999e06a feat(ping): send fingerprint and posture checks as part of ping/register
Former-commit-id: 70a7e83291
2026-01-15 12:13:36 -08:00
Varun Narravula
69952ee5c5 feat(api): add fingerprint + posture fields to client state
Former-commit-id: 566084683a
2026-01-15 12:13:36 -08:00
Owen
3710880ce0 Merge branch 'power-state' into msg-delivery
Former-commit-id: bda6606098
2026-01-14 17:51:42 -08:00
Owen
17b75bf58f Dont get token each time
Former-commit-id: 07dfc651f1
2026-01-14 16:51:04 -08:00
Owen
3ba1714524 Power state getting set correctly
Former-commit-id: 0895156efd
2026-01-14 16:38:40 -08:00
Owen
3470da76fc Update resetting intervals
Former-commit-id: 303c2dc0b7
2026-01-14 12:32:29 -08:00
Owen
c86df2c041 Refactor operation
Former-commit-id: 4f09d122bb
2026-01-14 11:58:12 -08:00
Owen
0e8315b149 Merge branch 'dev' into power-state
Former-commit-id: e9728efee3
2026-01-14 11:19:46 -08:00
Owen
2ab9790588 Reduce the pings
Former-commit-id: 5c6ad1ea75
2026-01-14 11:12:10 -08:00
Owen
1ecb97306f Add back AddDevice function
Former-commit-id: cae0ffa2e1
2026-01-13 21:38:37 -08:00
Varun Narravula
15e96a779c refactor(olm): convert global state into an olm instance
Former-commit-id: b755f77d95
2026-01-13 20:52:10 -08:00
miloschwartz
dada0cc124 add low power state for testing
Former-commit-id: 996fe59999
2026-01-13 14:30:02 -08:00
Owen
9c0b4fcd5f Fix error checking
Former-commit-id: 231808476b
2026-01-13 11:51:51 -08:00
Owen
8a788ef238 Merge branch 'dev' of github.com:fosrl/olm into dev
Former-commit-id: 8c5c8d3966
2026-01-12 17:12:45 -08:00
Owen
20e0c18845 Try to reduce cpu when idle
Former-commit-id: ba91478b89
2026-01-12 12:29:42 -08:00
Owen
5b637bb4ca Add expo backoff
Former-commit-id: faae551aca
2026-01-12 12:20:59 -08:00
Varun Narravula
c565a46a6f feat(logger): configure log file path thorugh global options
Former-commit-id: 577d89f4fb
2026-01-11 13:49:39 -08:00
Varun Narravula
7b7eae617a chore: format files using gofmt
Former-commit-id: 5cfa0dfb97
2026-01-11 13:49:39 -08:00
miloschwartz
1ed27fec1a set mtu to 0 on darwin
Former-commit-id: fbe686961e
2026-01-01 17:38:01 -05:00
Owen
83edde3449 Fix build on darwin
Former-commit-id: fbeb5be88d
2025-12-31 18:01:25 -05:00
Owen
1b43f029a9 Dont pass in dns proxy to override
Former-commit-id: 51dd927f9b
2025-12-31 15:42:51 -05:00
Owen
aeb908b68c Exiting the middle device works now?
Former-commit-id: d76b3c366f
2025-12-31 11:33:00 -05:00
Owen
f08b17c7bd Middle device working but not closing
Former-commit-id: c85fcc434b
2025-12-31 11:22:09 -05:00
Owen
cce8742490 Try to make the tun replacable
Former-commit-id: 6be0958887
2025-12-30 21:38:07 -05:00
Owen
c56696bab1 Use a different method on android
Former-commit-id: adf4c21f7b
2025-12-30 16:59:36 -05:00
Owen
7bb004cf50 Update docs
Former-commit-id: 543ca05eb9
2025-12-29 22:15:01 -05:00
Owen
28910ce188 Add stub
Former-commit-id: ece4239aaa
2025-12-29 17:50:15 -05:00
miloschwartz
f8dc134210 add content-length header to status payload
Former-commit-id: 8152d4133f
2025-12-29 17:28:12 -05:00
Varun Narravula
148f5fde23 fix(ci): add back missing docker build local image rule
Former-commit-id: 6d2afb4c72
2025-12-24 10:08:40 -05:00
Owen
b76259bc31 Add sync message
Former-commit-id: d01f180941
2025-12-24 10:06:25 -05:00
Owen
88cc57bcef Update mod
Former-commit-id: 1b474ebc1c
2025-12-23 18:00:15 -05:00
Owen
385c64c364 Dont run on v tags
Former-commit-id: 69a00b6231
2025-12-23 17:54:04 -05:00
Owen
0b05497c25 Merge branch 'dev' into msg-delivery
Former-commit-id: 4deb3e07b0
2025-12-23 15:44:02 -05:00
Owen
4e3e824276 Fix latest
Former-commit-id: 6fcd8ac6cb
2025-12-22 21:32:59 -05:00
Owen
effc1a31ac Update readme
Former-commit-id: 44282226b4
2025-12-22 17:24:51 -05:00
Owen
dde79bb2dc Fix go mod
Former-commit-id: e355d8db5f
2025-12-21 20:57:20 -05:00
Owen
3822b1a065 Add version and send it down
Former-commit-id: 52273a81c8
2025-12-19 16:45:11 -05:00
37 changed files with 4051 additions and 1574 deletions

View File

@@ -5,20 +5,10 @@ updates:
schedule: schedule:
interval: "daily" interval: "daily"
groups: groups:
dev-patch-updates: patch-updates:
dependency-type: "development"
update-types: update-types:
- "patch" - "patch"
dev-minor-updates: minor-updates:
dependency-type: "development"
update-types:
- "minor"
prod-patch-updates:
dependency-type: "production"
update-types:
- "patch"
prod-minor-updates:
dependency-type: "production"
update-types: update-types:
- "minor" - "minor"

View File

@@ -11,7 +11,9 @@ permissions:
on: on:
push: push:
tags: tags:
- "*" - "[0-9]+.[0-9]+.[0-9]+"
- "[0-9]+.[0-9]+.[0-9]+-rc.[0-9]+"
workflow_dispatch: workflow_dispatch:
inputs: inputs:
version: version:
@@ -46,7 +48,7 @@ jobs:
contents: write contents: write
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with: with:
fetch-depth: 0 fetch-depth: 0
@@ -90,7 +92,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with: with:
fetch-depth: 0 fetch-depth: 0
@@ -102,7 +104,7 @@ jobs:
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0 uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
- name: Set up 1.2.0 Buildx - name: Set up 1.2.0 Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1 uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
- name: Log in to Docker Hub - name: Log in to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
@@ -232,7 +234,7 @@ jobs:
- name: Cache Go modules - name: Cache Go modules
if: ${{ hashFiles('**/go.sum') != '' }} if: ${{ hashFiles('**/go.sum') != '' }}
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
with: with:
path: | path: |
~/.cache/go-build ~/.cache/go-build
@@ -267,13 +269,13 @@ jobs:
} >> "$GITHUB_ENV" } >> "$GITHUB_ENV"
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # v5.9.0 uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5.10.0
with: with:
images: ${{ env.IMAGE_LIST }} images: ${{ env.IMAGE_LIST }}
tags: | tags: |
type=semver,pattern={{version}},value=${{ env.TAG }} type=semver,pattern={{version}},value=${{ env.TAG }}
type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }} type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }}
type=raw,value=latest,enable=${{ env.PUBLISH_LATEST == 'true' && env.IS_RC != 'true' }} type=raw,value=latest,enable=${{ env.IS_RC != 'true' }}
flavor: | flavor: |
latest=false latest=false
labels: | labels: |
@@ -597,7 +599,7 @@ jobs:
shell: bash shell: bash
- name: Create GitHub Release - name: Create GitHub Release
uses: softprops/action-gh-release@5be0e66d93ac7ed76da52eca8bb058f665c3a5fe # v2.4.2 uses: softprops/action-gh-release@a06a81a03ee405af7f2048a818ed3f03bbf83c7b # v2.5.0
with: with:
tag_name: ${{ env.TAG }} tag_name: ${{ env.TAG }}
generate_release_notes: true generate_release_notes: true

37
.github/workflows/stale-bot.yml vendored Normal file
View File

@@ -0,0 +1,37 @@
name: Mark and Close Stale Issues
on:
schedule:
- cron: '0 0 * * *'
workflow_dispatch: # Allow manual trigger
permissions:
contents: write # only for delete-branch option
issues: write
pull-requests: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
days-before-stale: 14
days-before-close: 14
stale-issue-message: 'This issue has been automatically marked as stale due to 14 days of inactivity. It will be closed in 14 days if no further activity occurs.'
close-issue-message: 'This issue has been automatically closed due to inactivity. If you believe this is still relevant, please open a new issue with up-to-date information.'
stale-issue-label: 'stale'
exempt-issue-labels: 'needs investigating, networking, new feature, reverse proxy, bug, api, authentication, documentation, enhancement, help wanted, good first issue, question'
exempt-all-issue-assignees: true
only-labels: ''
exempt-pr-labels: ''
days-before-pr-stale: -1
days-before-pr-close: -1
operations-per-run: 100
remove-stale-when-updated: true
delete-branch: false
enable-statistics: true

View File

@@ -1,5 +1,8 @@
name: Run Tests name: Run Tests
permissions:
contents: read
on: on:
pull_request: pull_request:
branches: branches:
@@ -7,11 +10,12 @@ on:
- dev - dev
jobs: jobs:
test: build-go:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Checkout repository
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- name: Set up Go - name: Set up Go
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0 uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
@@ -21,5 +25,18 @@ jobs:
- name: Build binaries - name: Build binaries
run: make go-build-release run: make go-build-release
build-docker:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- name: Set up QEMU
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
- name: Set up 1.2.0 Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
- name: Build Docker image - name: Build Docker image
run: make docker-build-release run: make docker-build-dev

95
API.md
View File

@@ -1,10 +1,10 @@
## HTTP API ## API
Olm can be controlled with an embedded HTTP server when using `--enable-http`. This allows you to start it as a daemon and trigger it with the following endpoints. The API can listen on either a TCP address or a Unix socket/Windows named pipe. Olm can be controlled with an embedded API server when using `--enable-api`. This allows you to start it as a daemon and trigger it with the following endpoints. The API can listen on either a TCP address or a Unix socket/Windows named pipe.
### Socket vs TCP ### Socket vs TCP
By default, when `--enable-http` is used, Olm listens on a TCP address (configured via `--http-addr`, default `:9452`). Alternatively, Olm can listen on a Unix socket (Linux/macOS) or Windows named pipe for local-only communication with better security. When `--enable-api` is used, Olm can listen on a TCP address when configured via `--http-addr` (like `:9452`). Alternatively, Olm can listen on a Unix socket (Linux/macOS) or Windows named pipe for local-only communication with better security when using `--socket-path` (like `/var/run/olm.sock`).
**Unix Socket (Linux/macOS):** **Unix Socket (Linux/macOS):**
- Socket path example: `/var/run/olm/olm.sock` - Socket path example: `/var/run/olm/olm.sock`
@@ -46,7 +46,18 @@ Initiates a new connection request to a Pangolin server.
"tlsClientCert": "string", "tlsClientCert": "string",
"pingInterval": "3s", "pingInterval": "3s",
"pingTimeout": "5s", "pingTimeout": "5s",
"orgId": "string" "orgId": "string",
"fingerprint": {
"username": "string",
"hostname": "string",
"platform": "string",
"osVersion": "string",
"kernelVersion": "string",
"arch": "string",
"deviceModel": "string",
"serialNumber": "string"
},
"postures": {}
} }
``` ```
@@ -67,6 +78,16 @@ Initiates a new connection request to a Pangolin server.
- `pingInterval`: Interval for pinging the server (default: 3s) - `pingInterval`: Interval for pinging the server (default: 3s)
- `pingTimeout`: Timeout for each ping (default: 5s) - `pingTimeout`: Timeout for each ping (default: 5s)
- `orgId`: Organization ID to connect to - `orgId`: Organization ID to connect to
- `fingerprint`: Device fingerprinting information (should be set before connecting)
- `username`: Current username on the device
- `hostname`: Device hostname
- `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
- `osVersion`: Operating system version
- `kernelVersion`: Kernel version
- `arch`: System architecture (e.g., amd64, arm64)
- `deviceModel`: Device model identifier
- `serialNumber`: Device serial number
- `postures`: Device posture/security information
**Response:** **Response:**
- **Status Code:** `202 Accepted` - **Status Code:** `202 Accepted`
@@ -205,6 +226,56 @@ Switches to a different organization while maintaining the connection.
--- ---
### PUT /metadata
Updates device fingerprinting and posture information. This endpoint can be called at any time to update metadata, but it's recommended to provide this information in the initial `/connect` request or immediately before connecting.
**Request Body:**
```json
{
"fingerprint": {
"username": "string",
"hostname": "string",
"platform": "string",
"osVersion": "string",
"kernelVersion": "string",
"arch": "string",
"deviceModel": "string",
"serialNumber": "string"
},
"postures": {}
}
```
**Optional Fields:**
- `fingerprint`: Device fingerprinting information
- `username`: Current username on the device
- `hostname`: Device hostname
- `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
- `osVersion`: Operating system version
- `kernelVersion`: Kernel version
- `arch`: System architecture (e.g., amd64, arm64)
- `deviceModel`: Device model identifier
- `serialNumber`: Device serial number
- `postures`: Device posture/security information (object with arbitrary key-value pairs)
**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`
```json
{
"status": "metadata updated"
}
```
**Error Responses:**
- `405 Method Not Allowed` - Non-PUT requests
- `400 Bad Request` - Invalid JSON
**Note:** It's recommended to call this endpoint BEFORE `/connect` to ensure fingerprinting information is available during the initial connection handshake.
---
### POST /exit ### POST /exit
Initiates a graceful shutdown of the Olm process. Initiates a graceful shutdown of the Olm process.
@@ -247,6 +318,22 @@ Simple health check endpoint to verify the API server is running.
## Usage Examples ## Usage Examples
### Update metadata before connecting (recommended)
```bash
curl -X PUT http://localhost:9452/metadata \
-H "Content-Type: application/json" \
-d '{
"fingerprint": {
"username": "john",
"hostname": "johns-laptop",
"platform": "macos",
"osVersion": "14.2.1",
"arch": "arm64",
"deviceModel": "MacBookPro18,3"
}
}'
```
### Connect to a peer ### Connect to a peer
```bash ```bash
curl -X POST http://localhost:9452/connect \ curl -X POST http://localhost:9452/connect \

View File

@@ -5,6 +5,9 @@ all: local
local: local:
CGO_ENABLED=0 go build -o ./bin/olm CGO_ENABLED=0 go build -o ./bin/olm
docker-build:
docker build -t fosrl/olm:latest .
docker-build-release: docker-build-release:
@if [ -z "$(tag)" ]; then \ @if [ -z "$(tag)" ]; then \
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \ echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
@@ -17,6 +20,12 @@ docker-build-release:
-f Dockerfile \ -f Dockerfile \
--push --push
docker-build-dev:
docker buildx build . \
--platform linux/arm/v7,linux/arm64,linux/amd64 \
-t fosrl/olm:latest \
-f Dockerfile
.PHONY: go-build-release \ .PHONY: go-build-release \
go-build-release-linux-arm64 go-build-release-linux-arm32-v7 \ go-build-release-linux-arm64 go-build-release-linux-arm32-v7 \
go-build-release-linux-arm32-v6 go-build-release-linux-amd64 \ go-build-release-linux-arm32-v6 go-build-release-linux-amd64 \

View File

@@ -20,13 +20,7 @@ When Olm receives WireGuard control messages, it will use the information encode
## Hole Punching ## Hole Punching
In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to newt. If you want to disable hole punching, use the `--disable-holepunch` flag. Hole punching attempts to orchestrate a NAT hole punch between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil. In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to Newt. Hole punching attempts to orchestrate a NAT traversal between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil.
Right now, basic NAT hole punching is supported. We plan to add:
- [ ] Birthday paradox
- [ ] UPnP
- [ ] LAN detection
## Build ## Build

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"strconv"
"sync" "sync"
"time" "time"
@@ -32,7 +33,12 @@ type ConnectionRequest struct {
// SwitchOrgRequest defines the structure for switching organizations // SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct { type SwitchOrgRequest struct {
OrgID string `json:"orgId"` OrgID string `json:"org_id"`
}
// PowerModeRequest represents a request to change power mode
type PowerModeRequest struct {
Mode string `json:"mode"` // "normal" or "low"
} }
// PeerStatus represents the status of a peer connection // PeerStatus represents the status of a peer connection
@@ -48,11 +54,18 @@ type PeerStatus struct {
HolepunchConnected bool `json:"holepunchConnected"` HolepunchConnected bool `json:"holepunchConnected"`
} }
// OlmError holds error information from registration failures
type OlmError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// StatusResponse is returned by the status endpoint // StatusResponse is returned by the status endpoint
type StatusResponse struct { type StatusResponse struct {
Connected bool `json:"connected"` Connected bool `json:"connected"`
Registered bool `json:"registered"` Registered bool `json:"registered"`
Terminated bool `json:"terminated"` Terminated bool `json:"terminated"`
OlmError *OlmError `json:"error,omitempty"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
Agent string `json:"agent,omitempty"` Agent string `json:"agent,omitempty"`
OrgID string `json:"orgId,omitempty"` OrgID string `json:"orgId,omitempty"`
@@ -60,25 +73,37 @@ type StatusResponse struct {
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"` NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
} }
type MetadataChangeRequest struct {
Fingerprint map[string]any `json:"fingerprint"`
Postures map[string]any `json:"postures"`
}
// API represents the HTTP server and its state // API represents the HTTP server and its state
type API struct { type API struct {
addr string addr string
socketPath string socketPath string
listener net.Listener listener net.Listener
server *http.Server server *http.Server
onConnect func(ConnectionRequest) error
onSwitchOrg func(SwitchOrgRequest) error onConnect func(ConnectionRequest) error
onDisconnect func() error onSwitchOrg func(SwitchOrgRequest) error
onExit func() error onMetadataChange func(MetadataChangeRequest) error
onDisconnect func() error
onExit func() error
onRebind func() error
onPowerMode func(PowerModeRequest) error
statusMu sync.RWMutex statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus peerStatuses map[int]*PeerStatus
connectedAt time.Time connectedAt time.Time
isConnected bool isConnected bool
isRegistered bool isRegistered bool
isTerminated bool isTerminated bool
version string olmError *OlmError
agent string
orgID string version string
agent string
orgID string
} }
// NewAPI creates a new HTTP server that listens on a TCP address // NewAPI creates a new HTTP server that listens on a TCP address
@@ -101,28 +126,49 @@ func NewAPISocket(socketPath string) *API {
return s return s
} }
func NewAPIStub() *API {
s := &API{
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// SetHandlers sets the callback functions for handling API requests // SetHandlers sets the callback functions for handling API requests
func (s *API) SetHandlers( func (s *API) SetHandlers(
onConnect func(ConnectionRequest) error, onConnect func(ConnectionRequest) error,
onSwitchOrg func(SwitchOrgRequest) error, onSwitchOrg func(SwitchOrgRequest) error,
onMetadataChange func(MetadataChangeRequest) error,
onDisconnect func() error, onDisconnect func() error,
onExit func() error, onExit func() error,
onRebind func() error,
onPowerMode func(PowerModeRequest) error,
) { ) {
s.onConnect = onConnect s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg s.onSwitchOrg = onSwitchOrg
s.onMetadataChange = onMetadataChange
s.onDisconnect = onDisconnect s.onDisconnect = onDisconnect
s.onExit = onExit s.onExit = onExit
s.onRebind = onRebind
s.onPowerMode = onPowerMode
} }
// Start starts the HTTP server // Start starts the HTTP server
func (s *API) Start() error { func (s *API) Start() error {
if s.socketPath == "" && s.addr == "" {
return fmt.Errorf("either socketPath or addr must be provided to start the API server")
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg) mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/metadata", s.handleMetadataChange)
mux.HandleFunc("/disconnect", s.handleDisconnect) mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/exit", s.handleExit)
mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode)
s.server = &http.Server{ s.server = &http.Server{
Handler: mux, Handler: mux,
@@ -160,7 +206,7 @@ func (s *API) Stop() error {
// Close the server first, which will also close the listener gracefully // Close the server first, which will also close the listener gracefully
if s.server != nil { if s.server != nil {
s.server.Close() _ = s.server.Close()
} }
// Clean up socket file if using Unix socket // Clean up socket file if using Unix socket
@@ -226,9 +272,6 @@ func (s *API) SetConnectionStatus(isConnected bool) {
if isConnected { if isConnected {
s.connectedAt = time.Now() s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
} }
} }
@@ -236,6 +279,27 @@ func (s *API) SetRegistered(registered bool) {
s.statusMu.Lock() s.statusMu.Lock()
defer s.statusMu.Unlock() defer s.statusMu.Unlock()
s.isRegistered = registered s.isRegistered = registered
// Clear any registration error when successfully registered
if registered {
s.olmError = nil
}
}
// SetOlmError sets the registration error
func (s *API) SetOlmError(code string, message string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.olmError = &OlmError{
Code: code,
Message: message,
}
}
// ClearOlmError clears any registration error
func (s *API) ClearOlmError() {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.olmError = nil
} }
func (s *API) SetTerminated(terminated bool) { func (s *API) SetTerminated(terminated bool) {
@@ -345,7 +409,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
// Return a success response // Return a success response
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted) w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{ _ = json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted", "status": "connection request accepted",
}) })
} }
@@ -358,12 +422,12 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
} }
s.statusMu.RLock() s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{ resp := StatusResponse{
Connected: s.isConnected, Connected: s.isConnected,
Registered: s.isRegistered, Registered: s.isRegistered,
Terminated: s.isTerminated, Terminated: s.isTerminated,
OlmError: s.olmError,
Version: s.version, Version: s.version,
Agent: s.agent, Agent: s.agent,
OrgID: s.orgID, OrgID: s.orgID,
@@ -371,8 +435,18 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
NetworkSettings: network.GetSettings(), NetworkSettings: network.GetSettings(),
} }
s.statusMu.RUnlock()
data, err := json.Marshal(resp)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp) w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
} }
// handleHealth handles the /health endpoint // handleHealth handles the /health endpoint
@@ -384,7 +458,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{ _ = json.NewEncoder(w).Encode(map[string]string{
"status": "ok", "status": "ok",
}) })
} }
@@ -401,7 +475,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
// Return a success response first // Return a success response first
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{ _ = json.NewEncoder(w).Encode(map[string]string{
"status": "shutdown initiated", "status": "shutdown initiated",
}) })
@@ -450,7 +524,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
// Return a success response // Return a success response
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{ _ = json.NewEncoder(w).Encode(map[string]string{
"status": "org switch request accepted", "status": "org switch request accepted",
}) })
} }
@@ -484,16 +558,43 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
// Return a success response // Return a success response
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{ _ = json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated", "status": "disconnect initiated",
}) })
} }
// handleMetadataChange handles the /metadata endpoint
func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req MetadataChangeRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
logger.Info("Received metadata change request via API: %v", req)
_ = s.onMetadataChange(req)
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "metadata updated",
})
}
func (s *API) GetStatus() StatusResponse { func (s *API) GetStatus() StatusResponse {
return StatusResponse{ return StatusResponse{
Connected: s.isConnected, Connected: s.isConnected,
Registered: s.isRegistered, Registered: s.isRegistered,
Terminated: s.isTerminated, Terminated: s.isTerminated,
OlmError: s.olmError,
Version: s.version, Version: s.version,
Agent: s.agent, Agent: s.agent,
OrgID: s.orgID, OrgID: s.orgID,
@@ -501,3 +602,74 @@ func (s *API) GetStatus() StatusResponse {
NetworkSettings: network.GetSettings(), NetworkSettings: network.GetSettings(),
} }
} }
// handleRebind handles the /rebind endpoint
// This triggers a socket rebind, which is necessary when network connectivity changes
// (e.g., WiFi to cellular transition on macOS/iOS) and the old socket becomes stale.
func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received rebind request via API")
// Call the rebind handler if set
if s.onRebind != nil {
if err := s.onRebind(); err != nil {
http.Error(w, fmt.Sprintf("Rebind failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "Rebind handler not configured", http.StatusNotImplemented)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "socket rebound successfully",
})
}
// handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req PowerModeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate power mode
if req.Mode != "normal" && req.Mode != "low" {
http.Error(w, "Invalid power mode: must be 'normal' or 'low'", http.StatusBadRequest)
return
}
logger.Info("Received power mode change request via API: mode=%s", req.Mode)
// Call the power mode handler if set
if s.onPowerMode != nil {
if err := s.onPowerMode(req); err != nil {
http.Error(w, fmt.Sprintf("Power mode change failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "Power mode handler not configured", http.StatusNotImplemented)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": fmt.Sprintf("power mode changed to %s successfully", req.Mode),
})
}

View File

@@ -89,6 +89,7 @@ func DefaultConfig() *OlmConfig {
PingInterval: "3s", PingInterval: "3s",
PingTimeout: "5s", PingTimeout: "5s",
DisableHolepunch: false, DisableHolepunch: false,
OverrideDNS: true,
TunnelDNS: false, TunnelDNS: false,
// DoNotCreateNewClient: false, // DoNotCreateNewClient: false,
sources: make(map[string]string), sources: make(map[string]string),
@@ -324,9 +325,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping") serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests") serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching") serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings") serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "When enabled, the client uses custom DNS servers to resolve internal resources and aliases. This overrides your system's default DNS settings. Queries that cannot be resolved as a Pangolin resource will be forwarded to your configured Upstream DNS Server. (default false)")
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections") serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "Use tunnel for DNS traffic") serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "When enabled, DNS queries are routed through the tunnel for remote resolution. To ensure queries are tunneled correctly, you must define the DNS server as a Pangolin resource and enter its address as an Upstream DNS Server. (default false)")
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client") // serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
version := serviceFlags.Bool("version", false, "Print the version") version := serviceFlags.Bool("version", false, "Print the version")

View File

@@ -1,9 +1,12 @@
package device package device
import ( import (
"io"
"net/netip" "net/netip"
"os" "os"
"sync" "sync"
"sync/atomic"
"time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@@ -18,14 +21,68 @@ type FilterRule struct {
Handler PacketHandler Handler PacketHandler
} }
// MiddleDevice wraps a TUN device with packet filtering capabilities // closeAwareDevice wraps a tun.Device along with a flag
type MiddleDevice struct { // indicating whether its Close method was called.
type closeAwareDevice struct {
isClosed atomic.Bool
tun.Device tun.Device
rules []FilterRule closeEventCh chan struct{}
mutex sync.RWMutex wg sync.WaitGroup
readCh chan readResult closeOnce sync.Once
injectCh chan []byte }
closed chan struct{}
func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice {
return &closeAwareDevice{
Device: tunDevice,
isClosed: atomic.Bool{},
closeEventCh: make(chan struct{}),
}
}
// redirectEvents redirects the Events() method of the underlying tun.Device
// to the given channel.
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case ev, ok := <-c.Device.Events():
if !ok {
return
}
if ev == tun.EventDown {
continue
}
select {
case out <- ev:
case <-c.closeEventCh:
return
}
case <-c.closeEventCh:
return
}
}
}()
}
// Close calls the underlying Device's Close method
// after setting isClosed to true.
func (c *closeAwareDevice) Close() (err error) {
c.closeOnce.Do(func() {
c.isClosed.Store(true)
close(c.closeEventCh)
err = c.Device.Close()
c.wg.Wait()
})
return err
}
func (c *closeAwareDevice) IsClosed() bool {
return c.isClosed.Load()
} }
type readResult struct { type readResult struct {
@@ -36,58 +93,136 @@ type readResult struct {
err error err error
} }
// MiddleDevice wraps a TUN device with packet filtering capabilities
// and supports swapping the underlying device.
type MiddleDevice struct {
devices []*closeAwareDevice
mu sync.Mutex
cond *sync.Cond
rules []FilterRule
rulesMutex sync.RWMutex
readCh chan readResult
injectCh chan []byte
closed atomic.Bool
events chan tun.Event
}
// NewMiddleDevice creates a new filtered TUN device wrapper // NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice { func NewMiddleDevice(device tun.Device) *MiddleDevice {
d := &MiddleDevice{ d := &MiddleDevice{
Device: device, devices: make([]*closeAwareDevice, 0),
rules: make([]FilterRule, 0), rules: make([]FilterRule, 0),
readCh: make(chan readResult), readCh: make(chan readResult, 16),
injectCh: make(chan []byte, 100), injectCh: make(chan []byte, 100),
closed: make(chan struct{}), events: make(chan tun.Event, 16),
} }
go d.pump() d.cond = sync.NewCond(&d.mu)
if device != nil {
d.AddDevice(device)
}
return d return d
} }
func (d *MiddleDevice) pump() { // AddDevice adds a new underlying TUN device, closing any previous one
func (d *MiddleDevice) AddDevice(device tun.Device) {
d.mu.Lock()
if d.closed.Load() {
d.mu.Unlock()
_ = device.Close()
return
}
var toClose *closeAwareDevice
if len(d.devices) > 0 {
toClose = d.devices[len(d.devices)-1]
}
cad := newCloseAwareDevice(device)
cad.redirectEvents(d.events)
d.devices = []*closeAwareDevice{cad}
// Start pump for the new device
go d.pump(cad)
d.cond.Broadcast()
d.mu.Unlock()
if toClose != nil {
logger.Debug("MiddleDevice: Closing previous device")
if err := toClose.Close(); err != nil {
logger.Debug("MiddleDevice: Error closing previous device: %v", err)
}
}
}
func (d *MiddleDevice) pump(dev *closeAwareDevice) {
const defaultOffset = 16 const defaultOffset = 16
batchSize := d.Device.BatchSize() batchSize := dev.BatchSize()
logger.Debug("MiddleDevice: pump started") logger.Debug("MiddleDevice: pump started for device")
// Recover from panic if readCh is closed while we're trying to send
defer func() {
if r := recover(); r != nil {
logger.Debug("MiddleDevice: pump recovered from panic (channel closed)")
}
}()
for { for {
// Check closed first with priority // Check if this device is closed
select { if dev.IsClosed() {
case <-d.closed: logger.Debug("MiddleDevice: pump exiting, device is closed")
logger.Debug("MiddleDevice: pump exiting due to closed channel") return
}
// Check if MiddleDevice itself is closed
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed")
return return
default:
} }
// Allocate buffers for reading // Allocate buffers for reading
// We allocate new buffers for each read to avoid race conditions
// since we pass them to the channel
bufs := make([][]byte, batchSize) bufs := make([][]byte, batchSize)
sizes := make([]int, batchSize) sizes := make([]int, batchSize)
for i := range bufs { for i := range bufs {
bufs[i] = make([]byte, 2048) // Standard MTU + headroom bufs[i] = make([]byte, 2048) // Standard MTU + headroom
} }
n, err := d.Device.Read(bufs, sizes, defaultOffset) n, err := dev.Read(bufs, sizes, defaultOffset)
// Check closed again after read returns // Check if device was closed during read
select { if dev.IsClosed() {
case <-d.closed: logger.Debug("MiddleDevice: pump exiting, device closed during read")
logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)") return
}
// Check if MiddleDevice was closed during read
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read")
return
}
// Try to send the result - check closed state first to avoid sending on closed channel
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, device closed before send")
return return
default:
} }
// Now try to send the result
select { select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
case <-d.closed: default:
logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)") // Channel full, check if we should exit
return if dev.IsClosed() || d.closed.Load() {
return
}
// Try again with blocking
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
case <-dev.closeEventCh:
return
}
} }
if err != nil { if err != nil {
@@ -99,16 +234,28 @@ func (d *MiddleDevice) pump() {
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN) // InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
func (d *MiddleDevice) InjectOutbound(packet []byte) { func (d *MiddleDevice) InjectOutbound(packet []byte) {
if d.closed.Load() {
return
}
// Use defer/recover to handle panic from sending on closed channel
// This can happen during shutdown race conditions
defer func() {
if r := recover(); r != nil {
logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)")
}
}()
select { select {
case d.injectCh <- packet: case d.injectCh <- packet:
case <-d.closed: default:
// Channel full, drop packet
logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full")
} }
} }
// AddRule adds a packet filtering rule // AddRule adds a packet filtering rule
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
d.mutex.Lock() d.rulesMutex.Lock()
defer d.mutex.Unlock() defer d.rulesMutex.Unlock()
d.rules = append(d.rules, FilterRule{ d.rules = append(d.rules, FilterRule{
DestIP: destIP, DestIP: destIP,
Handler: handler, Handler: handler,
@@ -117,8 +264,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
// RemoveRule removes all rules for a given destination IP // RemoveRule removes all rules for a given destination IP
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.mutex.Lock() d.rulesMutex.Lock()
defer d.mutex.Unlock() defer d.rulesMutex.Unlock()
newRules := make([]FilterRule, 0, len(d.rules)) newRules := make([]FilterRule, 0, len(d.rules))
for _, rule := range d.rules { for _, rule := range d.rules {
if rule.DestIP != destIP { if rule.DestIP != destIP {
@@ -130,18 +277,120 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
// Close stops the device // Close stops the device
func (d *MiddleDevice) Close() error { func (d *MiddleDevice) Close() error {
select { if !d.closed.CompareAndSwap(false, true) {
case <-d.closed: return nil // already closed
// Already closed
return nil
default:
logger.Debug("MiddleDevice: Closing, signaling closed channel")
close(d.closed)
} }
logger.Debug("MiddleDevice: Closing underlying TUN device")
err := d.Device.Close() d.mu.Lock()
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err) devices := d.devices
return err d.devices = nil
d.cond.Broadcast()
d.mu.Unlock()
// Close underlying devices first - this causes the pump goroutines to exit
// when their read operations return errors
var lastErr error
logger.Debug("MiddleDevice: Closing %d devices", len(devices))
for _, device := range devices {
if err := device.Close(); err != nil {
logger.Debug("MiddleDevice: Error closing device: %v", err)
lastErr = err
}
}
// Now close channels to unblock any remaining readers
// The pump should have exited by now, but close channels to be safe
close(d.readCh)
close(d.injectCh)
close(d.events)
return lastErr
}
// Events returns the events channel
func (d *MiddleDevice) Events() <-chan tun.Event {
return d.events
}
// File returns the underlying file descriptor
func (d *MiddleDevice) File() *os.File {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return nil
}
continue
}
file := dev.File()
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return file
}
}
// MTU returns the MTU of the underlying device
func (d *MiddleDevice) MTU() (int, error) {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
mtu, err := dev.MTU()
if err == nil {
return mtu, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return 0, err
}
}
// Name returns the name of the underlying device
func (d *MiddleDevice) Name() (string, error) {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return "", io.EOF
}
continue
}
name, err := dev.Name()
if err == nil {
return name, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return "", err
}
}
// BatchSize returns the batch size
func (d *MiddleDevice) BatchSize() int {
dev := d.peekLast()
if dev == nil {
return 1
}
return dev.BatchSize()
} }
// extractDestIP extracts destination IP from packet (fast path) // extractDestIP extracts destination IP from packet (fast path)
@@ -176,156 +425,239 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
// Read intercepts packets going UP from the TUN device (towards WireGuard) // Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
// Check if already closed first (non-blocking) for {
select { if d.closed.Load() {
case <-d.closed: logger.Debug("MiddleDevice: Read returning io.EOF, device closed")
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)") return 0, io.EOF
return 0, os.ErrClosed
default:
}
// Now block waiting for data
select {
case res := <-d.readCh:
if res.err != nil {
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
return 0, res.err
} }
// Copy packets from result to provided buffers // Wait for a device to be available
count := 0 dev := d.peekLast()
for i := 0; i < res.n && i < len(bufs); i++ { if dev == nil {
// Handle offset mismatch if necessary if !d.waitForDevice() {
// We assume the pump used defaultOffset (16) return 0, io.EOF
// If caller asks for different offset, we need to shift
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
// Calculate where the packet data starts and ends in src
pktData := src[srcOffset : srcOffset+srcSize]
// Ensure dest buffer is large enough
if len(bufs[i]) < offset+len(pktData) {
continue // Skip if buffer too small
} }
copy(bufs[i][offset:], pktData)
sizes[i] = len(pktData)
count++
}
n = count
case pkt := <-d.injectCh:
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil // Buffer too small
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
case <-d.closed:
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
return 0, os.ErrClosed // Signal that device is closed
}
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
if len(rules) == 0 {
return n, nil
}
// Process packets and filter out handled ones
writeIdx := 0
for readIdx := 0; readIdx < n; readIdx++ {
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
continue continue
} }
// Check if packet matches any rule // Now block waiting for data from readCh or injectCh
handled := false select {
for _, rule := range rules { case res, ok := <-d.readCh:
if rule.DestIP == destIP { if !ok {
if rule.Handler(packet) { // Channel closed, device is shutting down
// Packet was handled and should be dropped return 0, io.EOF
handled = true }
break if res.err != nil {
// Check if device was swapped
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
return 0, res.err
}
// Copy packets from result to provided buffers
count := 0
for i := 0; i < res.n && i < len(bufs); i++ {
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
pktData := src[srcOffset : srcOffset+srcSize]
if len(bufs[i]) < offset+len(pktData) {
continue
}
copy(bufs[i][offset:], pktData)
sizes[i] = len(pktData)
count++
}
n = count
case pkt, ok := <-d.injectCh:
if !ok {
// Channel closed, device is shutting down
return 0, io.EOF
}
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
}
// Apply filtering rules
d.rulesMutex.RLock()
rules := d.rules
d.rulesMutex.RUnlock()
if len(rules) == 0 {
return n, nil
}
// Process packets and filter out handled ones
writeIdx := 0
for readIdx := 0; readIdx < n; readIdx++ {
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
destIP, ok := extractDestIP(packet)
if !ok {
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
continue
}
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
handled = true
break
}
} }
} }
}
if !handled { if !handled {
// Keep packet if writeIdx != readIdx {
if writeIdx != readIdx { bufs[writeIdx] = bufs[readIdx]
bufs[writeIdx] = bufs[readIdx] sizes[writeIdx] = sizes[readIdx]
sizes[writeIdx] = sizes[readIdx] }
writeIdx++
} }
writeIdx++
} }
}
return writeIdx, err return writeIdx, nil
}
} }
// Write intercepts packets going DOWN to the TUN device (from WireGuard) // Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) { func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock() for {
rules := d.rules if d.closed.Load() {
d.mutex.RUnlock() return 0, io.EOF
}
if len(rules) == 0 { dev := d.peekLast()
return d.Device.Write(bufs, offset) if dev == nil {
} if !d.waitForDevice() {
return 0, io.EOF
// Filter packets going down }
filteredBufs := make([][]byte, 0, len(bufs))
for _, buf := range bufs {
if len(buf) <= offset {
continue continue
} }
packet := buf[offset:] d.rulesMutex.RLock()
destIP, ok := extractDestIP(packet) rules := d.rules
if !ok { d.rulesMutex.RUnlock()
// Can't parse, keep packet
filteredBufs = append(filteredBufs, buf)
continue
}
// Check if packet matches any rule var filteredBufs [][]byte
handled := false if len(rules) == 0 {
for _, rule := range rules { filteredBufs = bufs
if rule.DestIP == destIP { } else {
if rule.Handler(packet) { filteredBufs = make([][]byte, 0, len(bufs))
// Packet was handled and should be dropped for _, buf := range bufs {
handled = true if len(buf) <= offset {
break continue
}
packet := buf[offset:]
destIP, ok := extractDestIP(packet)
if !ok {
filteredBufs = append(filteredBufs, buf)
continue
}
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
handled = true
break
}
}
}
if !handled {
filteredBufs = append(filteredBufs, buf)
} }
} }
} }
if !handled { if len(filteredBufs) == 0 {
filteredBufs = append(filteredBufs, buf) return len(bufs), nil
} }
}
if len(filteredBufs) == 0 { n, err := dev.Write(filteredBufs, offset)
return len(bufs), nil // All packets were handled if err == nil {
} return n, nil
}
return d.Device.Write(filteredBufs, offset) if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
} }
func (d *MiddleDevice) waitForDevice() bool {
d.mu.Lock()
defer d.mu.Unlock()
for len(d.devices) == 0 && !d.closed.Load() {
d.cond.Wait()
}
return !d.closed.Load()
}
func (d *MiddleDevice) peekLast() *closeAwareDevice {
d.mu.Lock()
defer d.mu.Unlock()
if len(d.devices) == 0 {
return nil
}
return d.devices[len(d.devices)-1]
}
// WriteToTun writes packets directly to the underlying TUN device,
// bypassing WireGuard. This is useful for sending packets that should
// appear to come from the TUN interface (e.g., DNS responses from a proxy).
// Unlike Write(), this does not go through packet filtering rules.
func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) {
for {
if d.closed.Load() {
return 0, io.EOF
}
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
n, err := dev.Write(bufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}

View File

@@ -1,4 +1,4 @@
//go:build !windows //go:build darwin
package device package device
@@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
} }
file := os.NewFile(uintptr(dupTunFd), "/dev/tun") file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
device, err := tun.CreateTUNFromFile(file, mtuInt) device, err := tun.CreateTUNFromFile(file, 0)
if err != nil { if err != nil {
file.Close() file.Close()
return nil, err return nil, err

50
device/tun_linux.go Normal file
View File

@@ -0,0 +1,50 @@
//go:build linux
package device
import (
"net"
"os"
"runtime"
"github.com/fosrl/newt/logger"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
if runtime.GOOS == "android" { // otherwise we get a permission denied
theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
return theTun, err
}
dupTunFd, err := unix.Dup(int(tunFd))
if err != nil {
logger.Error("Unable to dup tun fd: %v", err)
return nil, err
}
err = unix.SetNonblock(dupTunFd, true)
if err != nil {
unix.Close(dupTunFd)
return nil, err
}
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
device, err := tun.CreateTUNFromFile(file, mtuInt)
if err != nil {
file.Close()
return nil, err
}
return device, nil
}
func UapiOpen(interfaceName string) (*os.File, error) {
return ipc.UAPIOpen(interfaceName)
}
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, fileUAPI)
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/fosrl/newt/util" "github.com/fosrl/newt/util"
"github.com/fosrl/olm/device" "github.com/fosrl/olm/device"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@@ -34,18 +33,17 @@ type DNSProxy struct {
ep *channel.Endpoint ep *channel.Endpoint
proxyIP netip.Addr proxyIP netip.Addr
upstreamDNS []string upstreamDNS []string
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
mtu int mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
recordStore *DNSRecordStore // Local DNS records recordStore *DNSRecordStore // Local DNS records
// Tunnel DNS fields - for sending queries over WireGuard // Tunnel DNS fields - for sending queries over WireGuard
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries) tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
tunnelEp *channel.Endpoint tunnelEp *channel.Endpoint
tunnelActivePorts map[uint16]bool tunnelActivePorts map[uint16]bool
tunnelPortsLock sync.Mutex tunnelPortsLock sync.Mutex
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@@ -53,7 +51,7 @@ type DNSProxy struct {
} }
// NewDNSProxy creates a new DNS proxy // NewDNSProxy creates a new DNS proxy
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
proxyIP, err := PickIPFromSubnet(utilitySubnet) proxyIP, err := PickIPFromSubnet(utilitySubnet)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err) return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
@@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
proxy := &DNSProxy{ proxy := &DNSProxy{
proxyIP: proxyIP, proxyIP: proxyIP,
mtu: mtu, mtu: mtu,
tunDevice: tunDevice,
middleDevice: middleDevice, middleDevice: middleDevice,
upstreamDNS: upstreamDns, upstreamDNS: upstreamDns,
tunnelDNS: tunnelDns, tunnelDNS: tunnelDns,
@@ -383,7 +380,7 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
// Check if we have local records for this query // Check if we have local records for this query
var response *dns.Msg var response *dns.Msg
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypePTR {
response = p.checkLocalRecords(msg, question) response = p.checkLocalRecords(msg, question)
} }
@@ -413,6 +410,34 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
// checkLocalRecords checks if we have local records for the query // checkLocalRecords checks if we have local records for the query
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg { func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
// Handle PTR queries
if question.Qtype == dns.TypePTR {
if ptrDomain, ok := p.recordStore.GetPTRRecord(question.Name); ok {
logger.Debug("Found local PTR record for %s -> %s", question.Name, ptrDomain)
// Create response message
response := new(dns.Msg)
response.SetReply(query)
response.Authoritative = true
// Add PTR answer record
rr := &dns.PTR{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
Ptr: ptrDomain,
}
response.Answer = append(response.Answer, rr)
return response
}
return nil
}
// Handle A and AAAA queries
var recordType RecordType var recordType RecordType
if question.Qtype == dns.TypeA { if question.Qtype == dns.TypeA {
recordType = RecordTypeA recordType = RecordTypeA
@@ -602,12 +627,12 @@ func (p *DNSProxy) runTunnelPacketSender() {
defer p.wg.Done() defer p.wg.Done()
logger.Debug("DNS tunnel packet sender goroutine started") logger.Debug("DNS tunnel packet sender goroutine started")
ticker := time.NewTicker(1 * time.Millisecond)
defer ticker.Stop()
for { for {
select { // Use blocking ReadContext instead of polling - much more CPU efficient
case <-p.ctx.Done(): // This will block until a packet is available or context is cancelled
pkt := p.tunnelEp.ReadContext(p.ctx)
if pkt == nil {
// Context was cancelled or endpoint closed
logger.Debug("DNS tunnel packet sender exiting") logger.Debug("DNS tunnel packet sender exiting")
// Drain any remaining packets // Drain any remaining packets
for { for {
@@ -618,36 +643,28 @@ func (p *DNSProxy) runTunnelPacketSender() {
pkt.DecRef() pkt.DecRef()
} }
return return
case <-ticker.C:
// Try to read packets
for i := 0; i < 10; i++ {
pkt := p.tunnelEp.Read()
if pkt == nil {
break
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
p.middleDevice.InjectOutbound(buf)
}
pkt.DecRef()
}
} }
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
p.middleDevice.InjectOutbound(buf)
}
pkt.DecRef()
} }
} }
@@ -660,18 +677,12 @@ func (p *DNSProxy) runPacketSender() {
const offset = 16 const offset = 16
for { for {
select { // Use blocking ReadContext instead of polling - much more CPU efficient
case <-p.ctx.Done(): // This will block until a packet is available or context is cancelled
return pkt := p.ep.ReadContext(p.ctx)
default:
}
// Read packets from netstack endpoint
pkt := p.ep.Read()
if pkt == nil { if pkt == nil {
// No packet available, small sleep to avoid busy loop // Context was cancelled or endpoint closed
time.Sleep(1 * time.Millisecond) return
continue
} }
// Extract packet data as slices // Extract packet data as slices
@@ -694,9 +705,9 @@ func (p *DNSProxy) runPacketSender() {
pos += len(slice) pos += len(slice)
} }
// Write packet to TUN device // Write packet to TUN device via MiddleDevice
// offset=16 indicates packet data starts at position 16 in the buffer // offset=16 indicates packet data starts at position 16 in the buffer
_, err := p.tunDevice.Write([][]byte{buf}, offset) _, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
if err != nil { if err != nil {
logger.Error("Failed to write DNS response to TUN: %v", err) logger.Error("Failed to write DNS response to TUN: %v", err)
} }

View File

@@ -1,6 +1,7 @@
package dns package dns
import ( import (
"fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
@@ -14,15 +15,17 @@ type RecordType uint16
const ( const (
RecordTypeA RecordType = RecordType(dns.TypeA) RecordTypeA RecordType = RecordType(dns.TypeA)
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA) RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
RecordTypePTR RecordType = RecordType(dns.TypePTR)
) )
// DNSRecordStore manages local DNS records for A and AAAA queries // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
type DNSRecordStore struct { type DNSRecordStore struct {
mu sync.RWMutex mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
ptrRecords map[string]string // IP address string -> domain name
} }
// NewDNSRecordStore creates a new DNS record store // NewDNSRecordStore creates a new DNS record store
@@ -32,6 +35,7 @@ func NewDNSRecordStore() *DNSRecordStore {
aaaaRecords: make(map[string][]net.IP), aaaaRecords: make(map[string][]net.IP),
aWildcards: make(map[string][]net.IP), aWildcards: make(map[string][]net.IP),
aaaaWildcards: make(map[string][]net.IP), aaaaWildcards: make(map[string][]net.IP),
ptrRecords: make(map[string]string),
} }
} }
@@ -39,6 +43,7 @@ func NewDNSRecordStore() *DNSRecordStore {
// domain should be in FQDN format (e.g., "example.com.") // domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
// Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -48,8 +53,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase // Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards // Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
@@ -60,6 +65,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.aWildcards[domain] = append(s.aWildcards[domain], ip) s.aWildcards[domain] = append(s.aWildcards[domain], ip)
} else { } else {
s.aRecords[domain] = append(s.aRecords[domain], ip) s.aRecords[domain] = append(s.aRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
} }
} else if ip.To16() != nil { } else if ip.To16() != nil {
// IPv6 address // IPv6 address
@@ -67,6 +74,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip) s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
} else { } else {
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
} }
} else { } else {
return &net.ParseError{Type: "IP address", Text: ip.String()} return &net.ParseError{Type: "IP address", Text: ip.String()}
@@ -75,8 +84,30 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
return nil return nil
} }
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
// ip should be a valid IPv4 or IPv6 address
// domain should be in FQDN format (e.g., "example.com.")
func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Store PTR record using IP string as key
s.ptrRecords[ip.String()] = domain
return nil
}
// RemoveRecord removes a specific DNS record mapping // RemoveRecord removes a specific DNS record mapping
// If ip is nil, removes all records for the domain (including wildcards) // If ip is nil, removes all records for the domain (including wildcards)
// Automatically removes corresponding PTR records for non-wildcard domains
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -86,8 +117,8 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase // Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards // Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
@@ -98,6 +129,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
delete(s.aWildcards, domain) delete(s.aWildcards, domain)
delete(s.aaaaWildcards, domain) delete(s.aaaaWildcards, domain)
} else { } else {
// For non-wildcard domains, remove PTR records for all IPs
if ips, ok := s.aRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
if ips, ok := s.aaaaRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
delete(s.aRecords, domain) delete(s.aRecords, domain)
delete(s.aaaaRecords, domain) delete(s.aaaaRecords, domain)
} }
@@ -119,6 +167,10 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
if len(s.aRecords[domain]) == 0 { if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain) delete(s.aRecords, domain)
} }
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
} }
} }
} else if ip.To16() != nil { } else if ip.To16() != nil {
@@ -136,11 +188,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
if len(s.aaaaRecords[domain]) == 0 { if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain) delete(s.aaaaRecords, domain)
} }
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
} }
} }
} }
} }
// RemovePTRRecord removes a PTR record for an IP address
func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.ptrRecords, ip.String())
}
// GetRecords returns all IP addresses for a domain and record type // GetRecords returns all IP addresses for a domain and record type
// First checks for exact matches, then checks wildcard patterns // First checks for exact matches, then checks wildcard patterns
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
@@ -148,7 +212,7 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN // Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain) domain = strings.ToLower(dns.Fqdn(domain))
var records []net.IP var records []net.IP
switch recordType { switch recordType {
@@ -198,6 +262,26 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
return records return records
} }
// GetPTRRecord returns the domain name for a PTR record query
// domain should be in reverse DNS format (e.g., "1.0.0.127.in-addr.arpa.")
func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
// Convert reverse DNS format to IP address
ip := reverseDNSToIP(domain)
if ip == nil {
return "", false
}
// Look up the PTR record
if ptrDomain, ok := s.ptrRecords[ip.String()]; ok {
return ptrDomain, true
}
return "", false
}
// HasRecord checks if a domain has any records of the specified type // HasRecord checks if a domain has any records of the specified type
// Checks both exact matches and wildcard patterns // Checks both exact matches and wildcard patterns
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
@@ -205,7 +289,7 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN // Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain) domain = strings.ToLower(dns.Fqdn(domain))
switch recordType { switch recordType {
case RecordTypeA: case RecordTypeA:
@@ -235,6 +319,21 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
return false return false
} }
// HasPTRRecord checks if a PTR record exists for the given reverse DNS domain
func (s *DNSRecordStore) HasPTRRecord(domain string) bool {
s.mu.RLock()
defer s.mu.RUnlock()
// Convert reverse DNS format to IP address
ip := reverseDNSToIP(domain)
if ip == nil {
return false
}
_, ok := s.ptrRecords[ip.String()]
return ok
}
// Clear removes all records from the store // Clear removes all records from the store
func (s *DNSRecordStore) Clear() { func (s *DNSRecordStore) Clear() {
s.mu.Lock() s.mu.Lock()
@@ -244,6 +343,7 @@ func (s *DNSRecordStore) Clear() {
s.aaaaRecords = make(map[string][]net.IP) s.aaaaRecords = make(map[string][]net.IP)
s.aWildcards = make(map[string][]net.IP) s.aWildcards = make(map[string][]net.IP)
s.aaaaWildcards = make(map[string][]net.IP) s.aaaaWildcards = make(map[string][]net.IP)
s.ptrRecords = make(map[string]string)
} }
// removeIP is a helper function to remove a specific IP from a slice // removeIP is a helper function to remove a specific IP from a slice
@@ -322,4 +422,76 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool {
} }
return matchWildcardInternal(pattern, domain, pi+1, di+1) return matchWildcardInternal(pattern, domain, pi+1, di+1)
}
// reverseDNSToIP converts a reverse DNS query name to an IP address
// Supports both IPv4 (in-addr.arpa) and IPv6 (ip6.arpa) formats
func reverseDNSToIP(domain string) net.IP {
// Normalize to lowercase and ensure FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Check for IPv4 reverse DNS (in-addr.arpa)
if strings.HasSuffix(domain, ".in-addr.arpa.") {
// Remove the suffix
ipPart := strings.TrimSuffix(domain, ".in-addr.arpa.")
// Split by dots and reverse
parts := strings.Split(ipPart, ".")
if len(parts) != 4 {
return nil
}
// Reverse the octets
reversed := make([]string, 4)
for i := 0; i < 4; i++ {
reversed[i] = parts[3-i]
}
// Parse as IP
return net.ParseIP(strings.Join(reversed, "."))
}
// Check for IPv6 reverse DNS (ip6.arpa)
if strings.HasSuffix(domain, ".ip6.arpa.") {
// Remove the suffix
ipPart := strings.TrimSuffix(domain, ".ip6.arpa.")
// Split by dots and reverse
parts := strings.Split(ipPart, ".")
if len(parts) != 32 {
return nil
}
// Reverse the nibbles and group into 16-bit hex values
reversed := make([]string, 32)
for i := 0; i < 32; i++ {
reversed[i] = parts[31-i]
}
// Join into IPv6 format (groups of 4 nibbles separated by colons)
var ipv6Parts []string
for i := 0; i < 32; i += 4 {
ipv6Parts = append(ipv6Parts, reversed[i]+reversed[i+1]+reversed[i+2]+reversed[i+3])
}
// Parse as IP
return net.ParseIP(strings.Join(ipv6Parts, ":"))
}
return nil
}
// IPToReverseDNS converts an IP address to reverse DNS format
// Returns the domain name for PTR queries (e.g., "1.0.0.127.in-addr.arpa.")
func IPToReverseDNS(ip net.IP) string {
if ip4 := ip.To4(); ip4 != nil {
// IPv4: reverse octets and append .in-addr.arpa.
return dns.Fqdn(fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa",
ip4[3], ip4[2], ip4[1], ip4[0]))
}
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
// IPv6: expand to 32 nibbles, reverse, and append .ip6.arpa.
var nibbles []string
for i := 15; i >= 0; i-- {
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]&0x0f))
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]>>4))
}
return dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
}
return ""
} }

View File

@@ -37,7 +37,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "autoco.internal.", domain: "autoco.internal.",
expected: false, expected: false,
}, },
// Question mark wildcard tests // Question mark wildcard tests
{ {
name: "host-0?.autoco.internal matches host-01.autoco.internal", name: "host-0?.autoco.internal matches host-01.autoco.internal",
@@ -63,7 +63,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-012.autoco.internal.", domain: "host-012.autoco.internal.",
expected: false, expected: false,
}, },
// Combined wildcard tests // Combined wildcard tests
{ {
name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal", name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal",
@@ -83,7 +83,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-01.autoco.internal.", domain: "host-01.autoco.internal.",
expected: false, expected: false,
}, },
// Multiple asterisks // Multiple asterisks
{ {
name: "*.*. autoco.internal matches any.thing.autoco.internal", name: "*.*. autoco.internal matches any.thing.autoco.internal",
@@ -97,7 +97,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "single.autoco.internal.", domain: "single.autoco.internal.",
expected: false, expected: false,
}, },
// Asterisk in middle // Asterisk in middle
{ {
name: "host-*.autoco.internal matches host-anything.autoco.internal", name: "host-*.autoco.internal matches host-anything.autoco.internal",
@@ -111,7 +111,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-.autoco.internal.", domain: "host-.autoco.internal.",
expected: true, expected: true,
}, },
// Multiple question marks // Multiple question marks
{ {
name: "host-??.autoco.internal matches host-01.autoco.internal", name: "host-??.autoco.internal matches host-01.autoco.internal",
@@ -125,7 +125,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-1.autoco.internal.", domain: "host-1.autoco.internal.",
expected: false, expected: false,
}, },
// Exact match (no wildcards) // Exact match (no wildcards)
{ {
name: "exact.autoco.internal matches exact.autoco.internal", name: "exact.autoco.internal matches exact.autoco.internal",
@@ -139,7 +139,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "other.autoco.internal.", domain: "other.autoco.internal.",
expected: false, expected: false,
}, },
// Edge cases // Edge cases
{ {
name: "* matches anything", name: "* matches anything",
@@ -154,7 +154,7 @@ func TestWildcardMatching(t *testing.T) {
expected: true, expected: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := matchWildcard(tt.pattern, tt.domain) result := matchWildcard(tt.pattern, tt.domain)
@@ -167,21 +167,21 @@ func TestWildcardMatching(t *testing.T) {
func TestDNSRecordStoreWildcard(t *testing.T) { func TestDNSRecordStoreWildcard(t *testing.T) {
store := NewDNSRecordStore() store := NewDNSRecordStore()
// Add wildcard records // Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1") wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP) err := store.AddRecord("*.autoco.internal", wildcardIP)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Add exact record // Add exact record
exactIP := net.ParseIP("10.0.0.2") exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP) err = store.AddRecord("exact.autoco.internal", exactIP)
if err != nil { if err != nil {
t.Fatalf("Failed to add exact record: %v", err) t.Fatalf("Failed to add exact record: %v", err)
} }
// Test exact match takes precedence // Test exact match takes precedence
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
@@ -190,7 +190,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
if !ips[0].Equal(exactIP) { if !ips[0].Equal(exactIP) {
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
} }
// Test wildcard match // Test wildcard match
ips = store.GetRecords("host.autoco.internal.", RecordTypeA) ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
@@ -199,7 +199,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
if !ips[0].Equal(wildcardIP) { if !ips[0].Equal(wildcardIP) {
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
} }
// Test non-match (base domain) // Test non-match (base domain)
ips = store.GetRecords("autoco.internal.", RecordTypeA) ips = store.GetRecords("autoco.internal.", RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
@@ -209,14 +209,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
func TestDNSRecordStoreComplexWildcard(t *testing.T) { func TestDNSRecordStoreComplexWildcard(t *testing.T) {
store := NewDNSRecordStore() store := NewDNSRecordStore()
// Add complex wildcard pattern // Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1") ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1) err := store.AddRecord("*.host-0?.autoco.internal", ip1)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Test matching domain // Test matching domain
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
@@ -225,13 +225,13 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
if len(ips) > 0 && !ips[0].Equal(ip1) { if len(ips) > 0 && !ips[0].Equal(ip1) {
t.Errorf("Expected IP %v, got %v", ip1, ips[0]) t.Errorf("Expected IP %v, got %v", ip1, ips[0])
} }
// Test non-matching domain (missing prefix) // Test non-matching domain (missing prefix)
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
} }
// Test non-matching domain (wrong ? position) // Test non-matching domain (wrong ? position)
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
@@ -241,23 +241,23 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
func TestDNSRecordStoreRemoveWildcard(t *testing.T) { func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
store := NewDNSRecordStore() store := NewDNSRecordStore()
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Verify it exists // Verify it exists
ips := store.GetRecords("host.autoco.internal.", RecordTypeA) ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP before removal, got %d", len(ips)) t.Errorf("Expected 1 IP before removal, got %d", len(ips))
} }
// Remove wildcard record // Remove wildcard record
store.RemoveRecord("*.autoco.internal", nil) store.RemoveRecord("*.autoco.internal", nil)
// Verify it's gone // Verify it's gone
ips = store.GetRecords("host.autoco.internal.", RecordTypeA) ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
@@ -267,40 +267,40 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
func TestDNSRecordStoreMultipleWildcards(t *testing.T) { func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
store := NewDNSRecordStore() store := NewDNSRecordStore()
// Add multiple wildcard patterns that don't overlap // Add multiple wildcard patterns that don't overlap
ip1 := net.ParseIP("10.0.0.1") ip1 := net.ParseIP("10.0.0.1")
ip2 := net.ParseIP("10.0.0.2") ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3") ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1) err := store.AddRecord("*.prod.autoco.internal", ip1)
if err != nil { if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err) t.Fatalf("Failed to add first wildcard: %v", err)
} }
err = store.AddRecord("*.dev.autoco.internal", ip2) err = store.AddRecord("*.dev.autoco.internal", ip2)
if err != nil { if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err) t.Fatalf("Failed to add second wildcard: %v", err)
} }
// Add a broader wildcard that matches both // Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3) err = store.AddRecord("*.autoco.internal", ip3)
if err != nil { if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err) t.Fatalf("Failed to add third wildcard: %v", err)
} }
// Test domain matching only the prod pattern and the broad pattern // Test domain matching only the prod pattern and the broad pattern
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
if len(ips) != 2 { if len(ips) != 2 {
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
} }
// Test domain matching only the dev pattern and the broad pattern // Test domain matching only the dev pattern and the broad pattern
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
if len(ips) != 2 { if len(ips) != 2 {
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
} }
// Test domain matching only the broad pattern // Test domain matching only the broad pattern
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
@@ -310,14 +310,14 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
store := NewDNSRecordStore() store := NewDNSRecordStore()
// Add IPv6 wildcard record // Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1") ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip)
if err != nil { if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err) t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
} }
// Test wildcard match for IPv6 // Test wildcard match for IPv6
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
if len(ips) != 1 { if len(ips) != 1 {
@@ -330,21 +330,535 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
func TestHasRecordWildcard(t *testing.T) { func TestHasRecordWildcard(t *testing.T) {
store := NewDNSRecordStore() store := NewDNSRecordStore()
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Test HasRecord with wildcard match // Test HasRecord with wildcard match
if !store.HasRecord("host.autoco.internal.", RecordTypeA) { if !store.HasRecord("host.autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return true for wildcard match") t.Error("Expected HasRecord to return true for wildcard match")
} }
// Test HasRecord with non-match // Test HasRecord with non-match
if store.HasRecord("autoco.internal.", RecordTypeA) { if store.HasRecord("autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return false for base domain") t.Error("Expected HasRecord to return false for base domain")
} }
} }
func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
store := NewDNSRecordStore()
// Add record with mixed case
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip)
if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err)
}
// Test lookup with different cases
testCases := []string{
"myhost.autoco.internal.",
"MYHOST.AUTOCO.INTERNAL.",
"MyHost.AutoCo.Internal.",
"mYhOsT.aUtOcO.iNtErNaL.",
}
for _, domain := range testCases {
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
}
if len(ips) > 0 && !ips[0].Equal(ip) {
t.Errorf("Expected IP %v for domain %q, got %v", ip, domain, ips[0])
}
}
// Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP)
if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err)
}
wildcardTestCases := []string{
"host.example.com.",
"HOST.EXAMPLE.COM.",
"Host.Example.Com.",
"HoSt.ExAmPlE.CoM.",
}
for _, domain := range wildcardTestCases {
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
}
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
t.Errorf("Expected IP %v for wildcard domain %q, got %v", wildcardIP, domain, ips[0])
}
}
// Test removal with different case
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
}
// Test HasRecord with different case
if !store.HasRecord("HOST.EXAMPLE.COM.", RecordTypeA) {
t.Error("Expected HasRecord to return true for mixed case wildcard match")
}
}
func TestPTRRecordIPv4(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record for IPv4
ip := net.ParseIP("192.168.1.1")
domain := "host.example.com."
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Test reverse DNS lookup
reverseDomain := "1.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to be found")
}
if result != domain {
t.Errorf("Expected domain %q, got %q", domain, result)
}
// Test HasPTRRecord
if !store.HasPTRRecord(reverseDomain) {
t.Error("Expected HasPTRRecord to return true")
}
// Test non-existent PTR record
_, ok = store.GetPTRRecord("2.1.168.192.in-addr.arpa.")
if ok {
t.Error("Expected PTR record not to be found for different IP")
}
}
func TestPTRRecordIPv6(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record for IPv6
ip := net.ParseIP("2001:db8::1")
domain := "ipv6host.example.com."
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Test reverse DNS lookup
// 2001:db8::1 = 2001:0db8:0000:0000:0000:0000:0000:0001
// Reverse: 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
reverseDomain := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected IPv6 PTR record to be found")
}
if result != domain {
t.Errorf("Expected domain %q, got %q", domain, result)
}
// Test HasPTRRecord
if !store.HasPTRRecord(reverseDomain) {
t.Error("Expected HasPTRRecord to return true for IPv6")
}
}
func TestRemovePTRRecord(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record
ip := net.ParseIP("10.0.0.1")
domain := "test.example.com."
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Verify it exists
reverseDomain := "1.0.0.10.in-addr.arpa."
_, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to exist before removal")
}
// Remove PTR record
store.RemovePTRRecord(ip)
// Verify it's gone
_, ok = store.GetPTRRecord(reverseDomain)
if ok {
t.Error("Expected PTR record to be removed")
}
// Test HasPTRRecord after removal
if store.HasPTRRecord(reverseDomain) {
t.Error("Expected HasPTRRecord to return false after removal")
}
}
func TestIPToReverseDNS(t *testing.T) {
tests := []struct {
name string
ip string
expected string
}{
{
name: "IPv4 simple",
ip: "192.168.1.1",
expected: "1.1.168.192.in-addr.arpa.",
},
{
name: "IPv4 localhost",
ip: "127.0.0.1",
expected: "1.0.0.127.in-addr.arpa.",
},
{
name: "IPv4 with zeros",
ip: "10.0.0.1",
expected: "1.0.0.10.in-addr.arpa.",
},
{
name: "IPv6 simple",
ip: "2001:db8::1",
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
},
{
name: "IPv6 localhost",
ip: "::1",
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("Failed to parse IP: %s", tt.ip)
}
result := IPToReverseDNS(ip)
if result != tt.expected {
t.Errorf("IPToReverseDNS(%s) = %q, want %q", tt.ip, result, tt.expected)
}
})
}
}
func TestReverseDNSToIP(t *testing.T) {
tests := []struct {
name string
reverseDNS string
expectedIP string
shouldMatch bool
}{
{
name: "IPv4 simple",
reverseDNS: "1.1.168.192.in-addr.arpa.",
expectedIP: "192.168.1.1",
shouldMatch: true,
},
{
name: "IPv4 localhost",
reverseDNS: "1.0.0.127.in-addr.arpa.",
expectedIP: "127.0.0.1",
shouldMatch: true,
},
{
name: "IPv6 simple",
reverseDNS: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
expectedIP: "2001:db8::1",
shouldMatch: true,
},
{
name: "Invalid IPv4 format",
reverseDNS: "1.1.168.in-addr.arpa.",
expectedIP: "",
shouldMatch: false,
},
{
name: "Invalid IPv6 format",
reverseDNS: "1.0.0.0.ip6.arpa.",
expectedIP: "",
shouldMatch: false,
},
{
name: "Not a reverse DNS domain",
reverseDNS: "example.com.",
expectedIP: "",
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reverseDNSToIP(tt.reverseDNS)
if tt.shouldMatch {
if result == nil {
t.Errorf("reverseDNSToIP(%q) returned nil, expected IP", tt.reverseDNS)
return
}
expectedIP := net.ParseIP(tt.expectedIP)
if !result.Equal(expectedIP) {
t.Errorf("reverseDNSToIP(%q) = %v, want %v", tt.reverseDNS, result, expectedIP)
}
} else {
if result != nil {
t.Errorf("reverseDNSToIP(%q) = %v, expected nil", tt.reverseDNS, result)
}
}
})
}
}
func TestPTRRecordCaseInsensitive(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record with mixed case domain
ip := net.ParseIP("192.168.1.1")
domain := "MyHost.Example.Com"
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Test lookup with different cases in reverse DNS format
reverseDomain := "1.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to be found")
}
// Domain should be normalized to lowercase
if result != "myhost.example.com." {
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
}
// Test with uppercase reverse DNS
reverseDomainUpper := "1.1.168.192.IN-ADDR.ARPA."
result, ok = store.GetPTRRecord(reverseDomainUpper)
if !ok {
t.Error("Expected PTR record to be found with uppercase reverse DNS")
}
if result != "myhost.example.com." {
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
}
}
func TestClearPTRRecords(t *testing.T) {
store := NewDNSRecordStore()
// Add some PTR records
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.2")
store.AddPTRRecord(ip1, "host1.example.com.")
store.AddPTRRecord(ip2, "host2.example.com.")
// Add some A records too
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))
// Verify PTR records exist
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
t.Error("Expected PTR record to exist before clear")
}
// Clear all records
store.Clear()
// Verify PTR records are gone
if store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
t.Error("Expected PTR record to be cleared")
}
if store.HasPTRRecord("2.1.168.192.in-addr.arpa.") {
t.Error("Expected PTR record to be cleared")
}
// Verify A records are also gone
if store.HasRecord("test.example.com.", RecordTypeA) {
t.Error("Expected A record to be cleared")
}
}
func TestAutomaticPTRRecordOnAdd(t *testing.T) {
store := NewDNSRecordStore()
// Add an A record - should automatically add PTR record
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
if err != nil {
t.Fatalf("Failed to add A record: %v", err)
}
// Verify PTR record was automatically created
reverseDomain := "100.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to be automatically created")
}
if result != domain {
t.Errorf("Expected PTR to point to %q, got %q", domain, result)
}
// Add AAAA record - should also automatically add PTR record
domain6 := "ipv6host.example.com."
ip6 := net.ParseIP("2001:db8::1")
err = store.AddRecord(domain6, ip6)
if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err)
}
// Verify IPv6 PTR record was automatically created
reverseDomain6 := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
result6, ok := store.GetPTRRecord(reverseDomain6)
if !ok {
t.Error("Expected IPv6 PTR record to be automatically created")
}
if result6 != domain6 {
t.Errorf("Expected PTR to point to %q, got %q", domain6, result6)
}
}
func TestAutomaticPTRRecordOnRemove(t *testing.T) {
store := NewDNSRecordStore()
// Add an A record (with automatic PTR)
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain, ip)
// Verify PTR exists
reverseDomain := "100.1.168.192.in-addr.arpa."
if !store.HasPTRRecord(reverseDomain) {
t.Error("Expected PTR record to exist after adding A record")
}
// Remove the A record
store.RemoveRecord(domain, ip)
// Verify PTR was automatically removed
if store.HasPTRRecord(reverseDomain) {
t.Error("Expected PTR record to be automatically removed")
}
// Verify A record is also gone
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected A record to be removed, got %d records", len(ips))
}
}
func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
store := NewDNSRecordStore()
// Add multiple IPs for the same domain
domain := "host.example.com."
ip1 := net.ParseIP("192.168.1.100")
ip2 := net.ParseIP("192.168.1.101")
store.AddRecord(domain, ip1)
store.AddRecord(domain, ip2)
// Verify both PTR records exist
reverseDomain1 := "100.1.168.192.in-addr.arpa."
reverseDomain2 := "101.1.168.192.in-addr.arpa."
if !store.HasPTRRecord(reverseDomain1) {
t.Error("Expected first PTR record to exist")
}
if !store.HasPTRRecord(reverseDomain2) {
t.Error("Expected second PTR record to exist")
}
// Remove all records for the domain
store.RemoveRecord(domain, nil)
// Verify both PTR records were removed
if store.HasPTRRecord(reverseDomain1) {
t.Error("Expected first PTR record to be removed")
}
if store.HasPTRRecord(reverseDomain2) {
t.Error("Expected second PTR record to be removed")
}
}
func TestNoPTRForWildcardRecords(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record - should NOT create PTR record
domain := "*.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Verify no PTR record was created
reverseDomain := "100.1.168.192.in-addr.arpa."
_, ok := store.GetPTRRecord(reverseDomain)
if ok {
t.Error("Expected no PTR record for wildcard domain")
}
// Verify wildcard A record exists
if !store.HasRecord("host.example.com.", RecordTypeA) {
t.Error("Expected wildcard A record to exist")
}
}
func TestPTRRecordOverwrite(t *testing.T) {
store := NewDNSRecordStore()
// Add first domain with IP
domain1 := "host1.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain1, ip)
// Verify PTR points to first domain
reverseDomain := "100.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Fatal("Expected PTR record to exist")
}
if result != domain1 {
t.Errorf("Expected PTR to point to %q, got %q", domain1, result)
}
// Add second domain with same IP - should overwrite PTR
domain2 := "host2.example.com."
store.AddRecord(domain2, ip)
// Verify PTR now points to second domain (last one added)
result, ok = store.GetPTRRecord(reverseDomain)
if !ok {
t.Fatal("Expected PTR record to still exist")
}
if result != domain2 {
t.Errorf("Expected PTR to point to %q (overwritten), got %q", domain2, result)
}
// Remove first domain - PTR should remain pointing to second domain
store.RemoveRecord(domain1, ip)
result, ok = store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to still exist after removing first domain")
}
if result != domain2 {
t.Errorf("Expected PTR to still point to %q, got %q", domain2, result)
}
// Remove second domain - PTR should now be gone
store.RemoveRecord(domain2, ip)
_, ok = store.GetPTRRecord(reverseDomain)
if ok {
t.Error("Expected PTR record to be removed after removing second domain")
}
}

View File

@@ -0,0 +1,16 @@
//go:build android
package olm
import "net/netip"
// SetupDNSOverride is a no-op on Android
// Android handles DNS through the VpnService API at the Java/Kotlin layer
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
return nil
}
// RestoreDNSOverride is a no-op on Android
func RestoreDNSOverride() error {
return nil
}

View File

@@ -7,7 +7,6 @@ import (
"net/netip" "net/netip"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform" platform "github.com/fosrl/olm/dns/platform"
) )
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS // SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
// Uses scutil for DNS configuration // Uses scutil for DNS configuration
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
if dnsProxy == nil {
return fmt.Errorf("DNS proxy is nil")
}
var err error var err error
configurator, err = platform.NewDarwinDNSConfigurator() configurator, err = platform.NewDarwinDNSConfigurator()
if err != nil { if err != nil {
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
// Set new DNS servers to point to our proxy // Set new DNS servers to point to our proxy
newDNS := []netip.Addr{ newDNS := []netip.Addr{
dnsProxy.GetProxyIP(), proxyIp,
} }
logger.Info("Setting DNS servers to: %v", newDNS) logger.Info("Setting DNS servers to: %v", newDNS)

View File

@@ -0,0 +1,15 @@
//go:build ios
package olm
import "net/netip"
// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
return nil
}
// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func RestoreDNSOverride() error {
return nil
}

View File

@@ -7,7 +7,6 @@ import (
"net/netip" "net/netip"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform" platform "github.com/fosrl/olm/dns/platform"
) )
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD // SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability // Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
if dnsProxy == nil {
return fmt.Errorf("DNS proxy is nil")
}
var err error var err error
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability // Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
@@ -32,7 +27,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName) configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
if err == nil { if err == nil {
logger.Info("Using systemd-resolved DNS configurator") logger.Info("Using systemd-resolved DNS configurator")
return setDNS(dnsProxy, configurator) return setDNS(proxyIp, configurator)
} }
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err) logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
@@ -40,7 +35,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName) configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
if err == nil { if err == nil {
logger.Info("Using NetworkManager DNS configurator") logger.Info("Using NetworkManager DNS configurator")
return setDNS(dnsProxy, configurator) return setDNS(proxyIp, configurator)
} }
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err) logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
@@ -48,7 +43,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName) configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
if err == nil { if err == nil {
logger.Info("Using resolvconf DNS configurator") logger.Info("Using resolvconf DNS configurator")
return setDNS(dnsProxy, configurator) return setDNS(proxyIp, configurator)
} }
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err) logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
} }
@@ -60,11 +55,11 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
} }
logger.Info("Using file-based DNS configurator") logger.Info("Using file-based DNS configurator")
return setDNS(dnsProxy, configurator) return setDNS(proxyIp, configurator)
} }
// setDNS is a helper function to set DNS and log the results // setDNS is a helper function to set DNS and log the results
func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error {
// Get current DNS servers before changing // Get current DNS servers before changing
currentDNS, err := conf.GetCurrentDNS() currentDNS, err := conf.GetCurrentDNS()
if err != nil { if err != nil {
@@ -75,7 +70,7 @@ func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
// Set new DNS servers to point to our proxy // Set new DNS servers to point to our proxy
newDNS := []netip.Addr{ newDNS := []netip.Addr{
dnsProxy.GetProxyIP(), proxyIp,
} }
logger.Info("Setting DNS servers to: %v", newDNS) logger.Info("Setting DNS servers to: %v", newDNS)

View File

@@ -7,7 +7,6 @@ import (
"net/netip" "net/netip"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform" platform "github.com/fosrl/olm/dns/platform"
) )
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows // SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
// Uses registry-based configuration (automatically extracts interface GUID) // Uses registry-based configuration (automatically extracts interface GUID)
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
if dnsProxy == nil {
return fmt.Errorf("DNS proxy is nil")
}
var err error var err error
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName) configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
if err != nil { if err != nil {
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
// Set new DNS servers to point to our proxy // Set new DNS servers to point to our proxy
newDNS := []netip.Addr{ newDNS := []netip.Addr{
dnsProxy.GetProxyIP(), proxyIp,
} }
logger.Info("Setting DNS servers to: %v", newDNS) logger.Info("Setting DNS servers to: %v", newDNS)

View File

@@ -416,4 +416,4 @@ func (d *DarwinDNSConfigurator) clearState() error {
logger.Debug("Cleared DNS state file") logger.Debug("Cleared DNS state file")
return nil return nil
} }

68
go.mod
View File

@@ -4,74 +4,32 @@ go 1.25
require ( require (
github.com/Microsoft/go-winio v0.6.2 github.com/Microsoft/go-winio v0.6.2
github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06 github.com/fosrl/newt v1.9.0
github.com/godbus/dbus/v5 v5.2.0 github.com/godbus/dbus/v5 v5.2.2
github.com/gorilla/websocket v1.5.3 github.com/gorilla/websocket v1.5.3
github.com/miekg/dns v1.1.68 github.com/miekg/dns v1.1.70
golang.org/x/sys v0.38.0 golang.org/x/sys v0.40.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
software.sslmate.com/src/go-pkcs12 v0.6.0 software.sslmate.com/src/go-pkcs12 v0.7.0
) )
require ( require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v0.3.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.2+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/btree v1.1.3 // indirect github.com/google/btree v1.1.3 // indirect
github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-cmp v0.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_golang v1.23.2 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/otlptranslator v0.0.2 // indirect
github.com/prometheus/procfs v0.17.0 // indirect
github.com/vishvananda/netlink v1.3.1 // indirect github.com/vishvananda/netlink v1.3.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect github.com/vishvananda/netns v0.0.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect golang.org/x/crypto v0.46.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 // indirect
go.opentelemetry.io/otel v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/prometheus v0.60.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/sdk v1.38.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
golang.org/x/mod v0.30.0 // indirect golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.47.0 // indirect golang.org/x/net v0.48.0 // indirect
golang.org/x/sync v0.18.0 // indirect golang.org/x/sync v0.19.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/time v0.12.0 // indirect golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.39.0 // indirect golang.org/x/tools v0.40.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.76.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )
// To be used ONLY for local development
// replace github.com/fosrl/newt => ../newt

136
go.sum
View File

@@ -1,123 +1,39 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/containerd/errdefs v0.3.0 h1:FSZgGOeK4yuT/+DnF07/Olde/q4KBoMsaamhXxIMDp4=
github.com/containerd/errdefs v0.3.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 h1:VpuI42l4enih//6IFFQDln/B7WukfMePxIRIpXsNe/0=
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06 h1:xWuCn+gzX0W7bHs/cV/ykNBliisNzNomPR76E4M0dtI=
github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc h1:GN2Lv3MGO7AS6PrRoT6yV5+wkrOpcszoIsO4+4ds248= github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA=
github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/otlptranslator v0.0.2 h1:+1CdeLVrRQ6Psmhnobldo0kTp96Rj80DRXRd5OSnMEQ=
github.com/prometheus/otlptranslator v0.0.2/go.mod h1:P8AwMgdD7XEr6QRUJ2QWLpiAZTgTE2UYgjlu3svompI=
github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0=
github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg=
go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 h1:PeBoRj6af6xMI7qCupwFvTbbnd49V7n5YpG6pg8iDYQ=
go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0/go.mod h1:ingqBCtMCe8I4vpz/UVzCW6sxoqgZB37nao91mLQ3Bw=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 h1:vl9obrcoWVKp/lwl8tRE33853I8Xru9HFbw/skNeLs8=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0/go.mod h1:GAXRxmLJcVM3u22IjTg74zWBrRCKq8BnOqUVLodpcpw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
go.opentelemetry.io/otel/exporters/prometheus v0.60.0 h1:cGtQxGvZbnrWdC2GyjZi0PDKVSLWP/Jocix3QWfXtbo=
go.opentelemetry.io/otel/exporters/prometheus v0.60.0/go.mod h1:hkd1EekxNo69PTV4OWFGZcKQiIqg0RfuWExcPKFvepk=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4=
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
@@ -126,19 +42,7 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdI
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 h1:vlzZttNJGVqTsRFU9AmdnrcO1Znh8Ew9kCD//yjigk0=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A=
google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU= software.sslmate.com/src/go-pkcs12 v0.7.0 h1:Db8W44cB54TWD7stUFFSWxdfpdn6fZVcDl0w3R4RVM0=
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= software.sslmate.com/src/go-pkcs12 v0.7.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

13
main.go
View File

@@ -10,7 +10,7 @@ import (
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates" "github.com/fosrl/newt/updates"
"github.com/fosrl/olm/olm" olmpkg "github.com/fosrl/olm/olm"
) )
func main() { func main() {
@@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
} }
// Create a new olm.Config struct and copy values from the main config // Create a new olm.Config struct and copy values from the main config
olmConfig := olm.GlobalConfig{ olmConfig := olmpkg.OlmConfig{
LogLevel: config.LogLevel, LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI, EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr, HTTPAddr: config.HTTPAddr,
@@ -219,15 +219,20 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
Agent: "Olm CLI", Agent: "Olm CLI",
OnExit: cancel, // Pass cancel function directly to trigger shutdown OnExit: cancel, // Pass cancel function directly to trigger shutdown
OnTerminated: cancel, OnTerminated: cancel,
PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
}
olm, err := olmpkg.Init(ctx, olmConfig)
if err != nil {
logger.Fatal("Failed to initialize olm: %v", err)
} }
olm.Init(ctx, olmConfig)
if err := olm.StartApi(); err != nil { if err := olm.StartApi(); err != nil {
logger.Fatal("Failed to start API server: %v", err) logger.Fatal("Failed to start API server: %v", err)
} }
if config.ID != "" && config.Secret != "" && config.Endpoint != "" { if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
tunnelConfig := olm.TunnelConfig{ tunnelConfig := olmpkg.TunnelConfig{
Endpoint: config.Endpoint, Endpoint: config.Endpoint,
ID: config.ID, ID: config.ID,
Secret: config.Secret, Secret: config.Secret,

299
olm/connect.go Normal file
View File

@@ -0,0 +1,299 @@
package olm
import (
"encoding/json"
"fmt"
"os"
"runtime"
"strconv"
"strings"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
olmDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
dnsOverride "github.com/fosrl/olm/dns/override"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
)
// OlmErrorData represents the error data sent from the server
type OlmErrorData struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring connect message")
return
}
var wgData WgData
if o.registered {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if o.stopRegister != nil {
o.stopRegister()
o.stopRegister = nil
}
if o.updateRegister != nil {
o.updateRegister = nil
}
// if there is an existing tunnel then close it
if o.dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
o.dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
o.tdev, err = func() (tun.Device, error) {
if o.tunnelConfig.FileDescriptorTun != 0 {
return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU)
}
ifName := o.tunnelConfig.InterfaceName
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
ifName, err = network.FindUnusedUTUN()
if err != nil {
return nil, err
}
}
return tun.CreateTUN(ifName, o.tunnelConfig.MTU)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
// if config.FileDescriptorTun == 0 {
if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
o.tunnelConfig.InterfaceName = realInterfaceName
}
// }
// Wrap TUN device with packet filter for DNS proxy
o.middleDev = olmDevice.NewMiddleDevice(o.tdev)
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
// Use filtered device instead of raw TUN device
o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger))
if o.tunnelConfig.EnableUAPI {
fileUAPI, err := func() (*os.File, error) {
if o.tunnelConfig.FileDescriptorUAPI != 0 {
fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
}
return os.NewFile(uintptr(fd), ""), nil
}
return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := o.uapiListener.Accept()
if err != nil {
return
}
go o.dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
}
if err = o.dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
// Extract interface IP (strip CIDR notation if present)
interfaceIP := wgData.TunnelIP
if strings.Contains(interfaceIP, "/") {
interfaceIP = strings.Split(interfaceIP, "/")[0]
}
// Create and start DNS proxy
o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
}
if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil {
logger.Error("Failed to o.tunnelConfigure interface: %v", err)
}
if network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet
logger.Error("Failed to add route for utility subnet: %v", err)
}
// Create peer manager with integrated peer monitoring
o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
Device: o.dev,
DNSProxy: o.dnsProxy,
InterfaceName: o.tunnelConfig.InterfaceName,
PrivateKey: o.privateKey,
MiddleDev: o.middleDev,
LocalIP: interfaceIP,
SharedBind: o.sharedBind,
WSClient: o.websocket,
APIServer: o.apiServer,
})
for i := range wgData.Sites {
site := wgData.Sites[i]
var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
if err := o.peerManager.AddPeer(site); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
o.peerManager.Start()
if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime
logger.Error("Failed to start DNS proxy: %v", err)
}
if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
logger.Error("Failed to setup DNS override: %v", err)
return
}
network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()})
}
o.apiServer.SetRegistered(true)
o.registered = true
// Start ping monitor now that we are registered and connected
o.websocket.StartPingMonitor()
// Invoke onConnected callback if configured
if o.olmConfig.OnConnected != nil {
go o.olmConfig.OnConnected()
}
logger.Info("WireGuard device created.")
}
func (o *Olm) handleOlmError(msg websocket.WSMessage) {
logger.Debug("Received olm error message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring olm error message")
return
}
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling olm error data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling olm error data: %v", err)
return
}
logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message)
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
// Invoke onOlmError callback if configured
if o.olmConfig.OnOlmError != nil {
go o.olmConfig.OnOlmError(errorData.Code, errorData.Message)
}
}
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
logger.Info("Received terminate message")
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring terminate message")
return
}
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling terminate error data: %v", err)
} else {
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling terminate error data: %v", err)
} else {
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
if errorData.Code == "TERMINATED_INACTIVITY" {
logger.Info("Ignoring...")
return
}
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
}
}
o.apiServer.SetTerminated(true)
o.apiServer.SetConnectionStatus(false)
o.apiServer.SetRegistered(false)
o.apiServer.ClearPeerStatuses()
network.ClearNetworkSettings()
o.Close()
if o.olmConfig.OnTerminated != nil {
go o.olmConfig.OnTerminated()
}
}

365
olm/data.go Normal file
View File

@@ -0,0 +1,365 @@
package olm
import (
"encoding/json"
"time"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
)
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring add-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addSubnetsData peers.PeerAdd
if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil {
logger.Error("Error unmarshaling add-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId)
return
}
// Add new subnets
for _, subnet := range addSubnetsData.RemoteSubnets {
if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
}
}
// Add new aliases
for _, alias := range addSubnetsData.Aliases {
if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil {
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
}
}
}
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring remove-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeSubnetsData peers.RemovePeerData
if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil {
logger.Error("Error unmarshaling remove-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId)
return
}
// Remove subnets
for _, subnet := range removeSubnetsData.RemoteSubnets {
if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
}
}
// Remove aliases
for _, alias := range removeSubnetsData.Aliases {
if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil {
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
}
}
}
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring update-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateSubnetsData peers.UpdatePeerData
if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil {
logger.Error("Error unmarshaling update-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId)
return
}
// Add new subnets BEFORE removing old ones to preserve shared subnets
// This ensures that if an old and new subnet are the same on different peers,
// the route won't be temporarily removed
for _, subnet := range updateSubnetsData.NewRemoteSubnets {
if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
}
}
// Remove old subnets after new ones are added
for _, subnet := range updateSubnetsData.OldRemoteSubnets {
if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
}
}
// Add new aliases BEFORE removing old ones to preserve shared IP addresses
// This ensures that if an old and new alias share the same IP, the IP won't be
// temporarily removed from the allowed IPs list
for _, alias := range updateSubnetsData.NewAliases {
if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil {
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
}
}
// Remove old aliases after new ones are added
for _, alias := range updateSubnetsData.OldAliases {
if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil {
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
}
}
logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
}
// Handler for syncing peer configuration - reconciles expected state with actual state
func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Debug("Received sync message: %v", msg.Data)
if !o.registered {
logger.Warn("Not connected, ignoring sync request")
return
}
if o.peerManager == nil {
logger.Warn("Peer manager not initialized, ignoring sync request")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
var syncData SyncData
if err := json.Unmarshal(jsonData, &syncData); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
// Sync exit nodes for hole punching
o.syncExitNodes(syncData.ExitNodes)
// Build a map of expected peers from the incoming data
expectedPeers := make(map[int]peers.SiteConfig)
for _, site := range syncData.Sites {
expectedPeers[site.SiteId] = site
}
// Get all current peers
currentPeers := o.peerManager.GetAllPeers()
currentPeerMap := make(map[int]peers.SiteConfig)
for _, peer := range currentPeers {
currentPeerMap[peer.SiteId] = peer
}
// Find peers to remove (in current but not in expected)
for siteId := range currentPeerMap {
if _, exists := expectedPeers[siteId]; !exists {
logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId)
if err := o.peerManager.RemovePeer(siteId); err != nil {
logger.Error("Sync: Failed to remove peer %d: %v", siteId, err)
} else {
// Remove any exit nodes associated with this peer from hole punching
if o.holePunchManager != nil {
removed := o.holePunchManager.RemoveExitNodesByPeer(siteId)
if removed > 0 {
logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId)
}
}
}
}
}
// Find peers to add (in expected but not in current) and peers to update
for siteId, expectedSite := range expectedPeers {
if _, exists := currentPeerMap[siteId]; !exists {
// New peer - add it using the add flow (with holepunch)
logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch()
// // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
// logger.Error("Sync: Failed to add peer %d: %v", siteId, err)
// } else {
// logger.Info("Sync: Successfully added peer for site %d", siteId)
// }
// add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
}, 1*time.Second, 10)
} else {
// Existing peer - check if update is needed
currentSite := currentPeerMap[siteId]
needsUpdate := false
// Check if any fields have changed
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
needsUpdate = true
}
if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint {
needsUpdate = true
}
if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey {
needsUpdate = true
}
if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP {
needsUpdate = true
}
if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort {
needsUpdate = true
}
// Check remote subnets
if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) {
needsUpdate = true
}
// Check aliases
if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) {
needsUpdate = true
}
if needsUpdate {
logger.Info("Sync: Updating peer for site %d", siteId)
// Merge expected data with current data
siteConfig := currentSite
if expectedSite.Endpoint != "" {
siteConfig.Endpoint = expectedSite.Endpoint
}
if expectedSite.RelayEndpoint != "" {
siteConfig.RelayEndpoint = expectedSite.RelayEndpoint
}
if expectedSite.PublicKey != "" {
siteConfig.PublicKey = expectedSite.PublicKey
}
if expectedSite.ServerIP != "" {
siteConfig.ServerIP = expectedSite.ServerIP
}
if expectedSite.ServerPort != 0 {
siteConfig.ServerPort = expectedSite.ServerPort
}
if expectedSite.RemoteSubnets != nil {
siteConfig.RemoteSubnets = expectedSite.RemoteSubnets
}
if expectedSite.Aliases != nil {
siteConfig.Aliases = expectedSite.Aliases
}
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
logger.Error("Sync: Failed to update peer %d: %v", siteId, err)
} else {
// If the endpoint changed, trigger holepunch to refresh NAT mappings
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId)
o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval()
}
logger.Info("Sync: Successfully updated peer for site %d", siteId)
}
}
}
}
logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers))
}
// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager
func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) {
if o.holePunchManager == nil {
logger.Warn("Hole punch manager not initialized, skipping exit node sync")
return
}
// Build a map of expected exit nodes by endpoint
expectedExitNodeMap := make(map[string]SyncExitNode)
for _, exitNode := range expectedExitNodes {
expectedExitNodeMap[exitNode.Endpoint] = exitNode
}
// Get current exit nodes from hole punch manager
currentExitNodes := o.holePunchManager.GetExitNodes()
currentExitNodeMap := make(map[string]holepunch.ExitNode)
for _, exitNode := range currentExitNodes {
currentExitNodeMap[exitNode.Endpoint] = exitNode
}
// Find exit nodes to remove (in current but not in expected)
for endpoint := range currentExitNodeMap {
if _, exists := expectedExitNodeMap[endpoint]; !exists {
logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint)
o.holePunchManager.RemoveExitNode(endpoint)
}
}
// Find exit nodes to add (in expected but not in current)
for endpoint, expectedExitNode := range expectedExitNodeMap {
if _, exists := currentExitNodeMap[endpoint]; !exists {
logger.Info("Sync: Adding new exit node %s", endpoint)
relayPort := expectedExitNode.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
hpExitNode := holepunch.ExitNode{
Endpoint: expectedExitNode.Endpoint,
RelayPort: relayPort,
PublicKey: expectedExitNode.PublicKey,
SiteIds: expectedExitNode.SiteIds,
}
if o.holePunchManager.AddExitNode(hpExitNode) {
logger.Info("Sync: Successfully added exit node %s", endpoint)
}
o.holePunchManager.TriggerHolePunch()
}
}
logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap))
}

1393
olm/olm.go

File diff suppressed because it is too large Load Diff

10
olm/olm_unix.go Normal file
View File

@@ -0,0 +1,10 @@
//go:build !windows
package olm
import "syscall"
// closeFD closes a file descriptor in a platform-specific way
func closeFD(fd uint32) error {
return syscall.Close(int(fd))
}

10
olm/olm_windows.go Normal file
View File

@@ -0,0 +1,10 @@
//go:build windows
package olm
import "syscall"
// closeFD closes a file descriptor in a platform-specific way
func closeFD(fd uint32) error {
return syscall.Close(syscall.Handle(fd))
}

282
olm/peer.go Normal file
View File

@@ -0,0 +1,282 @@
package olm
import (
"encoding/json"
"time"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
)
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring add-peer message")
return
}
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var siteConfig peers.SiteConfig
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
}
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring remove-peer message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData peers.PeerRemove
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil {
logger.Error("Failed to remove peer: %v", err)
return
}
// Remove any exit nodes associated with this peer from hole punching
if o.holePunchManager != nil {
removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId)
if removed > 0 {
logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId)
}
}
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
}
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring update-peer message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData peers.SiteConfig
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Get existing peer from PeerManager
existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId)
if !exists {
logger.Warn("Peer with site ID %d not found", updateData.SiteId)
return
}
// Create updated site config by merging with existing data
siteConfig := existingPeer
if updateData.Endpoint != "" {
siteConfig.Endpoint = updateData.Endpoint
}
if updateData.RelayEndpoint != "" {
siteConfig.RelayEndpoint = updateData.RelayEndpoint
}
if updateData.PublicKey != "" {
siteConfig.PublicKey = updateData.PublicKey
}
if updateData.ServerIP != "" {
siteConfig.ServerIP = updateData.ServerIP
}
if updateData.ServerPort != 0 {
siteConfig.ServerPort = updateData.ServerPort
}
if updateData.RemoteSubnets != nil {
siteConfig.RemoteSubnets = updateData.RemoteSubnets
}
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// If the endpoint changed, trigger holepunch to refresh NAT mappings
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
_ = o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval()
}
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
}
func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if o.peerManager == nil {
logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData peers.RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err)
return
}
// Update HTTP server to mark this peer as using relay
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
}
func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
logger.Debug("Received unrelay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if o.peerManager == nil {
logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData peers.UnRelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
}
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
logger.Debug("Received peer-handshake message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring peer-handshake message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling handshake data: %v", err)
return
}
var handshakeData struct {
SiteId int `json:"siteId"`
ExitNode struct {
PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
} `json:"exitNode"`
}
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
logger.Error("Error unmarshaling handshake data: %v", err)
return
}
// Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists {
logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
return
}
relayPort := handshakeData.ExitNode.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
siteId := handshakeData.SiteId
exitNode := holepunch.ExitNode{
Endpoint: handshakeData.ExitNode.Endpoint,
RelayPort: relayPort,
PublicKey: handshakeData.ExitNode.PublicKey,
SiteIds: []int{siteId},
}
added := o.holePunchManager.AddExitNode(exitNode)
if added {
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
} else {
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
}
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
}, 1*time.Second, 10)
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}

View File

@@ -12,9 +12,22 @@ type WgData struct {
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
} }
type GlobalConfig struct { type SyncData struct {
Sites []peers.SiteConfig `json:"sites"`
ExitNodes []SyncExitNode `json:"exitNodes"`
}
type SyncExitNode struct {
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
PublicKey string `json:"publicKey"`
SiteIds []int `json:"siteIds"`
}
type OlmConfig struct {
// Logging // Logging
LogLevel string LogLevel string
LogFilePath string
// HTTP server // HTTP server
EnableAPI bool EnableAPI bool
@@ -23,11 +36,17 @@ type GlobalConfig struct {
Version string Version string
Agent string Agent string
WakeUpDebounce time.Duration
// Debugging
PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")
// Callbacks // Callbacks
OnRegistered func() OnRegistered func()
OnConnected func() OnConnected func()
OnTerminated func() OnTerminated func()
OnAuthError func(statusCode int, message string) // Called when auth fails (401/403) OnAuthError func(statusCode int, message string) // Called when auth fails (401/403)
OnOlmError func(code string, message string) // Called when registration fails
OnExit func() // Called when exit is requested via API OnExit func() // Called when exit is requested via API
} }
@@ -63,5 +82,8 @@ type TunnelConfig struct {
OverrideDNS bool OverrideDNS bool
TunnelDNS bool TunnelDNS bool
InitialFingerprint map[string]any
InitialPostures map[string]any
DisableRelay bool DisableRelay bool
} }

View File

@@ -1,55 +1,47 @@
package olm package olm
import ( import (
"time" "github.com/fosrl/olm/peers"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/olm/websocket"
) )
func sendPing(olm *websocket.Client) error { // slicesEqual compares two string slices for equality (order-independent)
err := olm.SendMessage("olm/ping", map[string]interface{}{ func slicesEqual(a, b []string) bool {
"timestamp": time.Now().Unix(), if len(a) != len(b) {
"userToken": olm.GetConfig().UserToken, return false
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
return err
} }
logger.Debug("Sent ping message") // Create a map to count occurrences in slice a
return nil counts := make(map[string]int)
} for _, v := range a {
counts[v]++
func keepSendingPing(olm *websocket.Client) {
// Send ping immediately on startup
if err := sendPing(olm); err != nil {
logger.Error("Failed to send initial ping: %v", err)
} else {
logger.Info("Sent initial ping message")
} }
// Check if slice b has the same elements
// Set up ticker for one minute intervals for _, v := range b {
ticker := time.NewTicker(1 * time.Minute) counts[v]--
defer ticker.Stop() if counts[v] < 0 {
return false
for {
select {
case <-stopPing:
logger.Info("Stopping ping messages")
return
case <-ticker.C:
if err := sendPing(olm); err != nil {
logger.Error("Failed to send periodic ping: %v", err)
}
} }
} }
return true
} }
func GetNetworkSettingsJSON() (string, error) { // aliasesEqual compares two Alias slices for equality (order-independent)
return network.GetJSON() func aliasesEqual(a, b []peers.Alias) bool {
} if len(a) != len(b) {
return false
func GetNetworkSettingsIncrementor() int { }
return network.GetIncrementor() // Create a map to count occurrences in slice a (using alias+address as key)
counts := make(map[string]int)
for _, v := range a {
key := v.Alias + "|" + v.AliasAddress
counts[key]++
}
// Check if slice b has the same elements
for _, v := range b {
key := v.Alias + "|" + v.AliasAddress
counts[key]--
if counts[key] < 0 {
return false
}
}
return true
} }

View File

@@ -50,6 +50,8 @@ type PeerManager struct {
// key is the CIDR string, value is a set of siteIds that want this IP // key is the CIDR string, value is a set of siteIds that want this IP
allowedIPClaims map[string]map[int]bool allowedIPClaims map[string]map[int]bool
APIServer *api.API APIServer *api.API
PersistentKeepalive int
} }
// NewPeerManager creates a new PeerManager with an internal PeerMonitor // NewPeerManager creates a new PeerManager with an internal PeerMonitor
@@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
return peer, ok return peer, ok
} }
// GetPeerMonitor returns the internal peer monitor instance
func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor {
pm.mu.RLock()
defer pm.mu.RUnlock()
return pm.peerMonitor
}
func (pm *PeerManager) GetAllPeers() []SiteConfig { func (pm *PeerManager) GetAllPeers() []SiteConfig {
pm.mu.RLock() pm.mu.RLock()
defer pm.mu.RUnlock() defer pm.mu.RUnlock()
@@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
wgConfig := siteConfig wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
return err return err
} }
@@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
return nil return nil
} }
// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once
// without recreating them. Returns a map of siteId to error for any peers that failed to update.
func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
pm.mu.RLock()
defer pm.mu.RUnlock()
pm.PersistentKeepalive = interval
errors := make(map[int]error)
for siteId, peer := range pm.peers {
err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
if err != nil {
errors[siteId] = err
}
}
if len(errors) == 0 {
return nil
}
return errors
}
func (pm *PeerManager) RemovePeer(siteId int) error { func (pm *PeerManager) RemovePeer(siteId int) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
@@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId) ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
wgConfig := promotedPeer wgConfig := promotedPeer
wgConfig.AllowedIps = ownedIPs wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
} }
} }
@@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
wgConfig := siteConfig wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
return err return err
} }
@@ -324,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId) promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
promotedWgConfig := promotedPeer promotedWgConfig := promotedPeer
promotedWgConfig.AllowedIps = promotedOwnedIPs promotedWgConfig.AllowedIps = promotedOwnedIPs
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
} }
} }

View File

@@ -31,8 +31,7 @@ type PeerMonitor struct {
monitors map[int]*Client monitors map[int]*Client
mutex sync.Mutex mutex sync.Mutex
running bool running bool
interval time.Duration timeout time.Duration
timeout time.Duration
maxAttempts int maxAttempts int
wsClient *websocket.Client wsClient *websocket.Client
@@ -42,7 +41,7 @@ type PeerMonitor struct {
stack *stack.Stack stack *stack.Stack
ep *channel.Endpoint ep *channel.Endpoint
activePorts map[uint16]bool activePorts map[uint16]bool
portsLock sync.Mutex portsLock sync.RWMutex
nsCtx context.Context nsCtx context.Context
nsCancel context.CancelFunc nsCancel context.CancelFunc
nsWg sync.WaitGroup nsWg sync.WaitGroup
@@ -50,17 +49,26 @@ type PeerMonitor struct {
// Holepunch testing fields // Holepunch testing fields
sharedBind *bind.SharedBind sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester holepunchTester *holepunch.HolepunchTester
holepunchInterval time.Duration
holepunchTimeout time.Duration holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{} holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
// Relay tracking fields // Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
holepunchMaxAttempts int // max consecutive failures before triggering relay holepunchMaxAttempts int // max consecutive failures before triggering relay
holepunchFailures map[int]int // siteID -> consecutive failure count holepunchFailures map[int]int // siteID -> consecutive failure count
// Exponential backoff fields for holepunch monitor
defaultHolepunchMinInterval time.Duration // Minimum interval (initial)
defaultHolepunchMaxInterval time.Duration
holepunchMinInterval time.Duration // Minimum interval (initial)
holepunchMaxInterval time.Duration // Maximum interval (cap for backoff)
holepunchBackoffMultiplier float64 // Multiplier for each stable check
holepunchStableCount map[int]int // siteID -> consecutive stable status count
holepunchCurrentInterval time.Duration // Current interval with backoff applied
// Rapid initial test fields // Rapid initial test fields
rapidTestInterval time.Duration // interval between rapid test attempts rapidTestInterval time.Duration // interval between rapid test attempts
rapidTestTimeout time.Duration // timeout for each rapid test attempt rapidTestTimeout time.Duration // timeout for each rapid test attempt
@@ -78,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{ pm := &PeerMonitor{
monitors: make(map[int]*Client), monitors: make(map[int]*Client),
interval: 2 * time.Second, // Default check interval (faster)
timeout: 3 * time.Second, timeout: 3 * time.Second,
maxAttempts: 3, maxAttempts: 3,
wsClient: wsClient, wsClient: wsClient,
@@ -88,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
nsCtx: ctx, nsCtx: ctx,
nsCancel: cancel, nsCancel: cancel,
sharedBind: sharedBind, sharedBind: sharedBind,
holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
holepunchTimeout: 2 * time.Second, // Faster timeout holepunchTimeout: 2 * time.Second, // Faster timeout
holepunchEndpoints: make(map[int]string), holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool), holepunchStatus: make(map[int]bool),
@@ -101,6 +107,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
apiServer: apiServer, apiServer: apiServer,
wgConnectionStatus: make(map[int]bool), wgConnectionStatus: make(map[int]bool),
// Exponential backoff settings for holepunch monitor
defaultHolepunchMinInterval: 2 * time.Second,
defaultHolepunchMaxInterval: 30 * time.Second,
holepunchMinInterval: 2 * time.Second,
holepunchMaxInterval: 30 * time.Second,
holepunchBackoffMultiplier: 1.5,
holepunchStableCount: make(map[int]int),
holepunchCurrentInterval: 2 * time.Second,
holepunchUpdateChan: make(chan struct{}, 1),
} }
if err := pm.initNetstack(); err != nil { if err := pm.initNetstack(); err != nil {
@@ -116,41 +131,75 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
} }
// SetInterval changes how frequently peers are checked // SetInterval changes how frequently peers are checked
func (pm *PeerMonitor) SetInterval(interval time.Duration) { func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()
pm.interval = interval
// Update interval for all existing monitors // Update interval for all existing monitors
for _, client := range pm.monitors { for _, client := range pm.monitors {
client.SetPacketInterval(interval) client.SetPacketInterval(minInterval, maxInterval)
} }
logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
} }
// SetTimeout changes the timeout for waiting for responses func (pm *PeerMonitor) ResetPeerInterval() {
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()
pm.timeout = timeout // Update interval for all existing monitors
// Update timeout for all existing monitors
for _, client := range pm.monitors { for _, client := range pm.monitors {
client.SetTimeout(timeout) client.ResetPacketInterval()
} }
} }
// SetMaxAttempts changes the maximum number of attempts for TestConnection // SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) SetMaxAttempts(attempts int) { func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
pm.holepunchMinInterval = minInterval
pm.holepunchMaxInterval = maxInterval
// Reset current interval to the new minimum
pm.holepunchCurrentInterval = minInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()
pm.maxAttempts = attempts return pm.holepunchMinInterval, pm.holepunchMaxInterval
}
// Update max attempts for all existing monitors func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
for _, client := range pm.monitors { pm.mutex.Lock()
client.SetMaxAttempts(attempts) pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
} }
} }
@@ -169,10 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
return err return err
} }
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
pm.monitors[siteID] = client pm.monitors[siteID] = client
pm.holepunchEndpoints[siteID] = holepunchEndpoint pm.holepunchEndpoints[siteID] = holepunchEndpoint
@@ -470,31 +515,59 @@ func (pm *PeerMonitor) stopHolepunchMonitor() {
logger.Info("Stopped holepunch connection monitor") logger.Info("Stopped holepunch connection monitor")
} }
// runHolepunchMonitor runs the holepunch monitoring loop // runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff
func (pm *PeerMonitor) runHolepunchMonitor() { func (pm *PeerMonitor) runHolepunchMonitor() {
ticker := time.NewTicker(pm.holepunchInterval) pm.mutex.Lock()
defer ticker.Stop() pm.holepunchCurrentInterval = pm.holepunchMinInterval
pm.mutex.Unlock()
// Do initial check immediately timer := time.NewTimer(0) // Fire immediately for initial check
pm.checkHolepunchEndpoints() defer timer.Stop()
for { for {
select { select {
case <-pm.holepunchStopChan: case <-pm.holepunchStopChan:
return return
case <-ticker.C: case <-pm.holepunchUpdateChan:
pm.checkHolepunchEndpoints() // Interval settings changed, reset to minimum
pm.mutex.Lock()
pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C:
anyStatusChanged := pm.checkHolepunchEndpoints()
pm.mutex.Lock()
if anyStatusChanged {
// Reset to minimum interval on any status change
pm.holepunchCurrentInterval = pm.holepunchMinInterval
} else {
// Apply exponential backoff when stable
newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier)
if newInterval > pm.holepunchMaxInterval {
newInterval = pm.holepunchMaxInterval
}
pm.holepunchCurrentInterval = newInterval
}
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
} }
} }
} }
// checkHolepunchEndpoints tests all holepunch endpoints // checkHolepunchEndpoints tests all holepunch endpoints
func (pm *PeerMonitor) checkHolepunchEndpoints() { // Returns true if any endpoint's status changed
func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
pm.mutex.Lock() pm.mutex.Lock()
// Check if we're still running before doing any work // Check if we're still running before doing any work
if !pm.running { if !pm.running {
pm.mutex.Unlock() pm.mutex.Unlock()
return return false
} }
endpoints := make(map[int]string, len(pm.holepunchEndpoints)) endpoints := make(map[int]string, len(pm.holepunchEndpoints))
for siteID, endpoint := range pm.holepunchEndpoints { for siteID, endpoint := range pm.holepunchEndpoints {
@@ -504,8 +577,10 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
maxAttempts := pm.holepunchMaxAttempts maxAttempts := pm.holepunchMaxAttempts
pm.mutex.Unlock() pm.mutex.Unlock()
anyStatusChanged := false
for siteID, endpoint := range endpoints { for siteID, endpoint := range endpoints {
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) // logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
result := pm.holepunchTester.TestEndpoint(endpoint, timeout) result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
pm.mutex.Lock() pm.mutex.Lock()
@@ -529,7 +604,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Unlock() pm.mutex.Unlock()
// Log status changes // Log status changes
if !exists || previousStatus != result.Success { statusChanged := !exists || previousStatus != result.Success
if statusChanged {
anyStatusChanged = true
if result.Success { if result.Success {
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT) logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
} else { } else {
@@ -562,7 +639,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Unlock() pm.mutex.Unlock()
if !stillRunning { if !stillRunning {
return // Stop processing if shutdown is in progress return anyStatusChanged // Stop processing if shutdown is in progress
} }
if !result.Success && !isRelayed && failureCount >= maxAttempts { if !result.Success && !isRelayed && failureCount >= maxAttempts {
@@ -579,6 +656,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
} }
} }
} }
return anyStatusChanged
} }
// GetHolepunchStatus returns the current holepunch status for all endpoints // GetHolepunchStatus returns the current holepunch status for all endpoints
@@ -650,55 +729,55 @@ func (pm *PeerMonitor) Close() {
logger.Debug("PeerMonitor: Cleanup complete") logger.Debug("PeerMonitor: Cleanup complete")
} }
// TestPeer tests connectivity to a specific peer // // TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { // func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
pm.mutex.Lock() // pm.mutex.Lock()
client, exists := pm.monitors[siteID] // client, exists := pm.monitors[siteID]
pm.mutex.Unlock() // pm.mutex.Unlock()
if !exists { // if !exists {
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) // return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
} // }
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) // ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
defer cancel() // defer cancel()
connected, rtt := client.TestConnection(ctx) // connected, rtt := client.TestPeerConnection(ctx)
return connected, rtt, nil // return connected, rtt, nil
} // }
// TestAllPeers tests connectivity to all peers // // TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct { // func (pm *PeerMonitor) TestAllPeers() map[int]struct {
Connected bool // Connected bool
RTT time.Duration // RTT time.Duration
} { // } {
pm.mutex.Lock() // pm.mutex.Lock()
peers := make(map[int]*Client, len(pm.monitors)) // peers := make(map[int]*Client, len(pm.monitors))
for siteID, client := range pm.monitors { // for siteID, client := range pm.monitors {
peers[siteID] = client // peers[siteID] = client
} // }
pm.mutex.Unlock() // pm.mutex.Unlock()
results := make(map[int]struct { // results := make(map[int]struct {
Connected bool // Connected bool
RTT time.Duration // RTT time.Duration
}) // })
for siteID, client := range peers { // for siteID, client := range peers {
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) // ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
connected, rtt := client.TestConnection(ctx) // connected, rtt := client.TestPeerConnection(ctx)
cancel() // cancel()
results[siteID] = struct { // results[siteID] = struct {
Connected bool // Connected bool
RTT time.Duration // RTT time.Duration
}{ // }{
Connected: connected, // Connected: connected,
RTT: rtt, // RTT: rtt,
} // }
} // }
return results // return results
} // }
// initNetstack initializes the gvisor netstack // initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error { func (pm *PeerMonitor) initNetstack() error {
@@ -770,9 +849,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool {
} }
// Check if we are listening on this port // Check if we are listening on this port
pm.portsLock.Lock() pm.portsLock.RLock()
active := pm.activePorts[uint16(port)] active := pm.activePorts[uint16(port)]
pm.portsLock.Unlock() pm.portsLock.RUnlock()
if !active { if !active {
return false return false
@@ -803,13 +882,12 @@ func (pm *PeerMonitor) runPacketSender() {
defer pm.nsWg.Done() defer pm.nsWg.Done()
logger.Debug("PeerMonitor: Packet sender goroutine started") logger.Debug("PeerMonitor: Packet sender goroutine started")
// Use a ticker to periodically check for packets without blocking indefinitely
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for { for {
select { // Use blocking ReadContext instead of polling - much more CPU efficient
case <-pm.nsCtx.Done(): // This will block until a packet is available or context is cancelled
pkt := pm.ep.ReadContext(pm.nsCtx)
if pkt == nil {
// Context was cancelled or endpoint closed
logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets") logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
// Drain any remaining packets before exiting // Drain any remaining packets before exiting
for { for {
@@ -821,36 +899,28 @@ func (pm *PeerMonitor) runPacketSender() {
} }
logger.Debug("PeerMonitor: Packet sender goroutine exiting") logger.Debug("PeerMonitor: Packet sender goroutine exiting")
return return
case <-ticker.C:
// Try to read packets in batches
for i := 0; i < 10; i++ {
pkt := pm.ep.Read()
if pkt == nil {
break
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
} }
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
} }
} }

View File

@@ -32,10 +32,19 @@ type Client struct {
monitorLock sync.Mutex monitorLock sync.Mutex
connLock sync.Mutex // Protects connection operations connLock sync.Mutex // Protects connection operations
shutdownCh chan struct{} shutdownCh chan struct{}
updateCh chan struct{}
packetInterval time.Duration packetInterval time.Duration
timeout time.Duration timeout time.Duration
maxAttempts int maxAttempts int
dialer Dialer dialer Dialer
// Exponential backoff fields
defaultMinInterval time.Duration // Default minimum interval (initial)
defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
minInterval time.Duration // Minimum interval (initial)
maxInterval time.Duration // Maximum interval (cap for backoff)
backoffMultiplier float64 // Multiplier for each stable check
stableCountToBackoff int // Number of stable checks before backing off
} }
// Dialer is a function that creates a connection // Dialer is a function that creates a connection
@@ -50,28 +59,59 @@ type ConnectionStatus struct {
// NewClient creates a new connection test client // NewClient creates a new connection test client
func NewClient(serverAddr string, dialer Dialer) (*Client, error) { func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
return &Client{ return &Client{
serverAddr: serverAddr, serverAddr: serverAddr,
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),
packetInterval: 2 * time.Second, updateCh: make(chan struct{}, 1),
timeout: 500 * time.Millisecond, // Timeout for individual packets packetInterval: 2 * time.Second,
maxAttempts: 3, // Default max attempts defaultMinInterval: 2 * time.Second,
dialer: dialer, defaultMaxInterval: 30 * time.Second,
minInterval: 2 * time.Second,
maxInterval: 30 * time.Second,
backoffMultiplier: 1.5,
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
}, nil }, nil
} }
// SetPacketInterval changes how frequently packets are sent in monitor mode // SetPacketInterval changes how frequently packets are sent in monitor mode
func (c *Client) SetPacketInterval(interval time.Duration) { func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
c.packetInterval = interval c.monitorLock.Lock()
c.packetInterval = minInterval
c.minInterval = minInterval
c.maxInterval = maxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
} }
// SetTimeout changes the timeout for waiting for responses func (c *Client) ResetPacketInterval() {
func (c *Client) SetTimeout(timeout time.Duration) { c.monitorLock.Lock()
c.timeout = timeout c.packetInterval = c.defaultMinInterval
} c.minInterval = c.defaultMinInterval
c.maxInterval = c.defaultMaxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// SetMaxAttempts changes the maximum number of attempts for TestConnection // Signal the goroutine to apply the new interval if running
func (c *Client) SetMaxAttempts(attempts int) { if monitorRunning && updateCh != nil {
c.maxAttempts = attempts select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
} }
// UpdateServerAddr updates the server address and resets the connection // UpdateServerAddr updates the server address and resets the connection
@@ -125,9 +165,10 @@ func (c *Client) ensureConnection() error {
return nil return nil
} }
// TestConnection checks if the connection to the server is working // TestPeerConnection checks if the connection to the server is working
// Returns true if connected, false otherwise // Returns true if connected, false otherwise
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
// logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
if err := c.ensureConnection(); err != nil { if err := c.ensureConnection(); err != nil {
logger.Warn("Failed to ensure connection: %v", err) logger.Warn("Failed to ensure connection: %v", err)
return false, 0 return false, 0
@@ -138,6 +179,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
binary.BigEndian.PutUint32(packet[0:4], magicHeader) binary.BigEndian.PutUint32(packet[0:4], magicHeader)
packet[4] = packetTypeRequest packet[4] = packetTypeRequest
// Reusable response buffer
responseBuffer := make([]byte, packetSize)
// Send multiple attempts as specified // Send multiple attempts as specified
for attempt := 0; attempt < c.maxAttempts; attempt++ { for attempt := 0; attempt < c.maxAttempts; attempt++ {
select { select {
@@ -157,20 +201,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
return false, 0 return false, 0
} }
// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
_, err := c.conn.Write(packet) _, err := c.conn.Write(packet)
if err != nil { if err != nil {
c.connLock.Unlock() c.connLock.Unlock()
logger.Info("Error sending packet: %v", err) logger.Info("Error sending packet: %v", err)
continue continue
} }
// logger.Debug("Successfully sent monitor packet")
// Set read deadline // Set read deadline
c.conn.SetReadDeadline(time.Now().Add(c.timeout)) c.conn.SetReadDeadline(time.Now().Add(c.timeout))
// Wait for response // Wait for response
responseBuffer := make([]byte, packetSize)
n, err := c.conn.Read(responseBuffer) n, err := c.conn.Read(responseBuffer)
c.connLock.Unlock() c.connLock.Unlock()
@@ -211,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel() defer cancel()
return c.TestConnection(ctx) return c.TestPeerConnection(ctx)
} }
// MonitorCallback is the function type for connection status change callbacks // MonitorCallback is the function type for connection status change callbacks
@@ -238,28 +279,61 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
go func() { go func() {
var lastConnected bool var lastConnected bool
firstRun := true firstRun := true
stableCount := 0
currentInterval := c.minInterval
ticker := time.NewTicker(c.packetInterval) timer := time.NewTimer(currentInterval)
defer ticker.Stop() defer timer.Stop()
for { for {
select { select {
case <-c.shutdownCh: case <-c.shutdownCh:
return return
case <-ticker.C: case <-c.updateCh:
// Interval settings changed, reset to minimum
c.monitorLock.Lock()
currentInterval = c.minInterval
c.monitorLock.Unlock()
// Reset backoff state
stableCount = 0
timer.Reset(currentInterval)
logger.Debug("Packet interval updated, reset to %v", currentInterval)
case <-timer.C:
ctx, cancel := context.WithTimeout(context.Background(), c.timeout) ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
connected, rtt := c.TestConnection(ctx) connected, rtt := c.TestPeerConnection(ctx)
cancel() cancel()
statusChanged := connected != lastConnected
// Callback if status changed or it's the first check // Callback if status changed or it's the first check
if connected != lastConnected || firstRun { if statusChanged || firstRun {
callback(ConnectionStatus{ callback(ConnectionStatus{
Connected: connected, Connected: connected,
RTT: rtt, RTT: rtt,
}) })
lastConnected = connected lastConnected = connected
firstRun = false firstRun = false
// Reset backoff on status change
stableCount = 0
currentInterval = c.minInterval
} else {
// Status is stable, increment counter
stableCount++
// Apply exponential backoff after stable threshold
if stableCount >= c.stableCountToBackoff {
newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier)
if newInterval > c.maxInterval {
newInterval = c.maxInterval
}
currentInterval = newInterval
}
} }
// Reset timer with current interval
timer.Reset(currentInterval)
} }
} }
}() }()

View File

@@ -11,7 +11,7 @@ import (
) )
// ConfigurePeer sets up or updates a peer within the WireGuard device // ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error { func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
var endpoint string var endpoint string
if relay && siteConfig.RelayEndpoint != "" { if relay && siteConfig.RelayEndpoint != "" {
endpoint = formatEndpoint(siteConfig.RelayEndpoint) endpoint = formatEndpoint(siteConfig.RelayEndpoint)
@@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
} }
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
configBuilder.WriteString("persistent_keepalive_interval=5\n") configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive))
config := configBuilder.String() config := configBuilder.String()
logger.Debug("Configuring peer with config: %s", config) logger.Debug("Configuring peer with config: %s", config)
@@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [
return nil return nil
} }
// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it
func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error {
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
configBuilder.WriteString("update_only=true\n")
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval))
config := configBuilder.String()
logger.Debug("Updating persistent keepalive for peer with config: %s", config)
err := dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err)
}
return nil
}
func formatEndpoint(endpoint string) string { func formatEndpoint(endpoint string) string {
if strings.Contains(endpoint, ":") { if strings.Contains(endpoint, ":") {
return endpoint return endpoint

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -54,8 +55,9 @@ type ExitNode struct {
} }
type WSMessage struct { type WSMessage struct {
Type string `json:"type"` Type string `json:"type"`
Data interface{} `json:"data"` Data interface{} `json:"data"`
ConfigVersion int `json:"configVersion,omitempty"`
} }
// this is not json anymore // this is not json anymore
@@ -77,6 +79,7 @@ type Client struct {
handlersMux sync.RWMutex handlersMux sync.RWMutex
reconnectInterval time.Duration reconnectInterval time.Duration
isConnected bool isConnected bool
isDisconnected bool // Flag to track if client is intentionally disconnected
reconnectMux sync.RWMutex reconnectMux sync.RWMutex
pingInterval time.Duration pingInterval time.Duration
pingTimeout time.Duration pingTimeout time.Duration
@@ -87,6 +90,19 @@ type Client struct {
clientType string // Type of client (e.g., "newt", "olm") clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig tlsConfig TLSConfig
configNeedsSave bool // Flag to track if config needs to be saved configNeedsSave bool // Flag to track if config needs to be saved
configVersion int // Latest config version received from server
configVersionMux sync.RWMutex
token string // Cached authentication token
exitNodes []ExitNode // Cached exit nodes from token response
tokenMux sync.RWMutex // Protects token and exitNodes
forceNewToken bool // Flag to force fetching a new token on next connection
processingMessage bool // Flag to track if a message is currently being processed
processingMux sync.RWMutex // Protects processingMessage
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
getPingData func() map[string]any // Callback to get additional ping data
pingStarted bool // Flag to track if ping monitor has been started
pingStartedMux sync.Mutex // Protects pingStarted
pingDone chan struct{} // Channel to stop the ping monitor independently
} }
type ClientOption func(*Client) type ClientOption func(*Client)
@@ -122,6 +138,13 @@ func WithTLSConfig(config TLSConfig) ClientOption {
} }
} }
// WithPingDataProvider sets a callback to provide additional data for ping messages
func WithPingDataProvider(fn func() map[string]any) ClientOption {
return func(c *Client) {
c.getPingData = fn
}
}
func (c *Client) OnConnect(callback func() error) { func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback c.onConnect = callback
} }
@@ -154,6 +177,7 @@ func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.
pingInterval: pingInterval, pingInterval: pingInterval,
pingTimeout: pingTimeout, pingTimeout: pingTimeout,
clientType: "olm", clientType: "olm",
pingDone: make(chan struct{}),
} }
// Apply options before loading config // Apply options before loading config
@@ -173,6 +197,9 @@ func (c *Client) GetConfig() *Config {
// Connect establishes the WebSocket connection // Connect establishes the WebSocket connection
func (c *Client) Connect() error { func (c *Client) Connect() error {
if c.isDisconnected {
c.isDisconnected = false
}
go c.connectWithRetry() go c.connectWithRetry()
return nil return nil
} }
@@ -205,9 +232,31 @@ func (c *Client) Close() error {
return nil return nil
} }
// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
func (c *Client) Disconnect() error {
c.isDisconnected = true
c.setConnected(false)
// Stop the ping monitor
c.stopPingMonitor()
// Wait for any message currently being processed to complete
c.processingWg.Wait()
if c.conn != nil {
c.writeMux.Lock()
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.writeMux.Unlock()
err := c.conn.Close()
c.conn = nil
return err
}
return nil
}
// SendMessage sends a message through the WebSocket connection // SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessage(messageType string, data interface{}) error { func (c *Client) SendMessage(messageType string, data interface{}) error {
if c.conn == nil { if c.isDisconnected || c.conn == nil {
return fmt.Errorf("not connected") return fmt.Errorf("not connected")
} }
@@ -216,14 +265,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
Data: data, Data: data,
} }
logger.Debug("Sending message: %s, data: %+v", messageType, data) logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
c.writeMux.Lock() c.writeMux.Lock()
defer c.writeMux.Unlock() defer c.writeMux.Unlock()
return c.conn.WriteJSON(msg) return c.conn.WriteJSON(msg)
} }
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) { func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
stopChan := make(chan struct{}) stopChan := make(chan struct{})
updateChan := make(chan interface{}) updateChan := make(chan interface{})
var dataMux sync.Mutex var dataMux sync.Mutex
@@ -231,30 +280,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
go func() { go func() {
count := 0 count := 0
maxAttempts := 10
err := c.SendMessage(messageType, currentData) // Send immediately send := func() {
if err != nil { if c.isDisconnected || c.conn == nil {
logger.Error("Failed to send initial message: %v", err) return
}
err := c.SendMessage(messageType, currentData)
if err != nil {
logger.Error("websocket: Failed to send message: %v", err)
}
count++
} }
count++
send() // Send immediately
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if count >= maxAttempts { if maxAttempts != -1 && count >= maxAttempts {
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return return
} }
dataMux.Lock() dataMux.Lock()
err = c.SendMessage(messageType, currentData) send()
dataMux.Unlock() dataMux.Unlock()
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case newData := <-updateChan: case newData := <-updateChan:
dataMux.Lock() dataMux.Lock()
// Merge newData into currentData if both are maps // Merge newData into currentData if both are maps
@@ -277,6 +328,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
case <-stopChan: case <-stopChan:
return return
} }
// Suspend sending if disconnected
for c.isDisconnected {
select {
case <-stopChan:
return
case <-time.After(500 * time.Millisecond):
}
}
} }
}() }()
return func() { return func() {
@@ -323,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
tlsConfig = &tls.Config{} tlsConfig = &tls.Config{}
} }
tlsConfig.InsecureSkipVerify = true tlsConfig.InsecureSkipVerify = true
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
} }
tokenData := map[string]interface{}{ tokenData := map[string]interface{}{
@@ -352,7 +411,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
req.Header.Set("X-CSRF-Token", "x-csrf-protection") req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// print out the request for debugging // print out the request for debugging
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
// Make the request // Make the request
client := &http.Client{} client := &http.Client{}
@@ -369,7 +428,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
// Return AuthError for 401/403 status codes // Return AuthError for 401/403 status codes
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
@@ -385,7 +444,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
var tokenResp TokenResponse var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.") logger.Error("websocket: Failed to decode token response.")
return "", nil, fmt.Errorf("failed to decode token response: %w", err) return "", nil, fmt.Errorf("failed to decode token response: %w", err)
} }
@@ -397,7 +456,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
return "", nil, fmt.Errorf("received empty token from server") return "", nil, fmt.Errorf("received empty token from server")
} }
logger.Debug("Received token: %s", tokenResp.Data.Token) logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
} }
@@ -411,7 +470,8 @@ func (c *Client) connectWithRetry() {
err := c.establishConnection() err := c.establishConnection()
if err != nil { if err != nil {
// Check if this is an auth error (401/403) // Check if this is an auth error (401/403)
if authErr, ok := err.(*AuthError); ok { var authErr *AuthError
if errors.As(err, &authErr) {
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr) logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
// Trigger auth error callback if set (this should terminate the tunnel) // Trigger auth error callback if set (this should terminate the tunnel)
if c.onAuthError != nil { if c.onAuthError != nil {
@@ -422,7 +482,7 @@ func (c *Client) connectWithRetry() {
continue continue
} }
// For other errors (5xx, network issues), continue retrying // For other errors (5xx, network issues), continue retrying
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval) time.Sleep(c.reconnectInterval)
continue continue
} }
@@ -432,15 +492,25 @@ func (c *Client) connectWithRetry() {
} }
func (c *Client) establishConnection() error { func (c *Client) establishConnection() error {
// Get token for authentication // Get token for authentication - reuse cached token unless forced to get new one
token, exitNodes, err := c.getToken() c.tokenMux.Lock()
if err != nil { needNewToken := c.token == "" || c.forceNewToken
return fmt.Errorf("failed to get token: %w", err) if needNewToken {
} token, exitNodes, err := c.getToken()
if err != nil {
c.tokenMux.Unlock()
return fmt.Errorf("failed to get token: %w", err)
}
c.token = token
c.exitNodes = exitNodes
c.forceNewToken = false
if c.onTokenUpdate != nil { if c.onTokenUpdate != nil {
c.onTokenUpdate(token, exitNodes) c.onTokenUpdate(token, exitNodes)
}
} }
token := c.token
c.tokenMux.Unlock()
// Parse the base URL to determine protocol and hostname // Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL) baseURL, err := url.Parse(c.baseURL)
@@ -475,7 +545,7 @@ func (c *Client) establishConnection() error {
// Use new TLS configuration method // Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
logger.Info("Setting up TLS configuration for WebSocket connection") logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
tlsConfig, err := c.setupTLS() tlsConfig, err := c.setupTLS()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup TLS configuration: %w", err) return fmt.Errorf("failed to setup TLS configuration: %w", err)
@@ -489,25 +559,38 @@ func (c *Client) establishConnection() error {
dialer.TLSClientConfig = &tls.Config{} dialer.TLSClientConfig = &tls.Config{}
} }
dialer.TLSClientConfig.InsecureSkipVerify = true dialer.TLSClientConfig.InsecureSkipVerify = true
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
} }
conn, _, err := dialer.Dial(u.String(), nil) conn, resp, err := dialer.Dial(u.String(), nil)
if err != nil { if err != nil {
// Check if this is an unauthorized error (401)
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
// Force getting a new token on next reconnect attempt
c.tokenMux.Lock()
c.forceNewToken = true
c.tokenMux.Unlock()
return &AuthError{
StatusCode: http.StatusUnauthorized,
Message: "WebSocket connection unauthorized",
}
}
return fmt.Errorf("failed to connect to WebSocket: %w", err) return fmt.Errorf("failed to connect to WebSocket: %w", err)
} }
c.conn = conn c.conn = conn
c.setConnected(true) c.setConnected(true)
// Start the ping monitor // Note: ping monitor is NOT started here - it will be started when
go c.pingMonitor() // StartPingMonitor() is called after registration completes
// Start the read pump with disconnect detection // Start the read pump with disconnect detection
go c.readPumpWithDisconnectDetection() go c.readPumpWithDisconnectDetection()
if c.onConnect != nil { if c.onConnect != nil {
if err := c.onConnect(); err != nil { if err := c.onConnect(); err != nil {
logger.Error("OnConnect callback failed: %v", err) logger.Error("websocket: OnConnect callback failed: %v", err)
} }
} }
@@ -520,9 +603,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Handle new separate certificate configuration // Handle new separate certificate configuration
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" { if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
logger.Info("Loading separate certificate files for mTLS") logger.Info("websocket: Loading separate certificate files for mTLS")
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile) logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile) logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
// Load client certificate and key // Load client certificate and key
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile) cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
@@ -533,7 +616,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Load CA certificates for remote validation if specified // Load CA certificates for remote validation if specified
if len(c.tlsConfig.CAFiles) > 0 { if len(c.tlsConfig.CAFiles) > 0 {
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles) logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
caCertPool := x509.NewCertPool() caCertPool := x509.NewCertPool()
for _, caFile := range c.tlsConfig.CAFiles { for _, caFile := range c.tlsConfig.CAFiles {
caCert, err := os.ReadFile(caFile) caCert, err := os.ReadFile(caFile)
@@ -559,13 +642,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Fallback to existing PKCS12 implementation for backward compatibility // Fallback to existing PKCS12 implementation for backward compatibility
if c.tlsConfig.PKCS12File != "" { if c.tlsConfig.PKCS12File != "" {
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)") logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
return c.setupPKCS12TLS() return c.setupPKCS12TLS()
} }
// Legacy fallback using config.TlsClientCert // Legacy fallback using config.TlsClientCert
if c.config.TlsClientCert != "" { if c.config.TlsClientCert != "" {
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)") logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
return loadClientCertificate(c.config.TlsClientCert) return loadClientCertificate(c.config.TlsClientCert)
} }
@@ -577,6 +660,59 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
return loadClientCertificate(c.tlsConfig.PKCS12File) return loadClientCertificate(c.tlsConfig.PKCS12File)
} }
// sendPing sends a single ping message
func (c *Client) sendPing() {
if c.isDisconnected || c.conn == nil {
return
}
// Skip ping if a message is currently being processed
c.processingMux.RLock()
isProcessing := c.processingMessage
c.processingMux.RUnlock()
if isProcessing {
logger.Debug("websocket: Skipping ping, message is being processed")
return
}
// Send application-level ping with config version
c.configVersionMux.RLock()
configVersion := c.configVersion
c.configVersionMux.RUnlock()
pingData := map[string]any{
"timestamp": time.Now().Unix(),
"userToken": c.config.UserToken,
}
if c.getPingData != nil {
for k, v := range c.getPingData() {
pingData[k] = v
}
}
pingMsg := WSMessage{
Type: "olm/ping",
Data: pingData,
ConfigVersion: configVersion,
}
logger.Debug("websocket: Sending ping: %+v", pingMsg)
c.writeMux.Lock()
err := c.conn.WriteJSON(pingMsg)
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("websocket: Ping failed: %v", err)
c.reconnect()
return
}
}
}
// pingMonitor sends pings at a short interval and triggers reconnect on failure // pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) pingMonitor() { func (c *Client) pingMonitor() {
ticker := time.NewTicker(c.pingInterval) ticker := time.NewTicker(c.pingInterval)
@@ -586,29 +722,65 @@ func (c *Client) pingMonitor() {
select { select {
case <-c.done: case <-c.done:
return return
case <-c.pingDone:
return
case <-ticker.C: case <-ticker.C:
if c.conn == nil { c.sendPing()
return
}
c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
c.reconnect()
return
}
}
} }
} }
} }
// StartPingMonitor starts the ping monitor goroutine.
// This should be called after the client is registered and connected.
// It is safe to call multiple times - only the first call will start the monitor.
func (c *Client) StartPingMonitor() {
c.pingStartedMux.Lock()
defer c.pingStartedMux.Unlock()
if c.pingStarted {
return
}
c.pingStarted = true
// Create a new pingDone channel for this ping monitor instance
c.pingDone = make(chan struct{})
// Send an initial ping immediately
go func() {
c.sendPing()
c.pingMonitor()
}()
}
// stopPingMonitor stops the ping monitor goroutine if it's running.
func (c *Client) stopPingMonitor() {
c.pingStartedMux.Lock()
defer c.pingStartedMux.Unlock()
if !c.pingStarted {
return
}
// Close the pingDone channel to stop the monitor
close(c.pingDone)
c.pingStarted = false
}
// GetConfigVersion returns the current config version
func (c *Client) GetConfigVersion() int {
c.configVersionMux.RLock()
defer c.configVersionMux.RUnlock()
return c.configVersion
}
// setConfigVersion updates the config version if the new version is higher
func (c *Client) setConfigVersion(version int) {
c.configVersionMux.Lock()
defer c.configVersionMux.Unlock()
logger.Debug("websocket: setting config version to %d", version)
c.configVersion = version
}
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error // readPumpWithDisconnectDetection reads messages and triggers reconnect on error
func (c *Client) readPumpWithDisconnectDetection() { func (c *Client) readPumpWithDisconnectDetection() {
defer func() { defer func() {
@@ -633,26 +805,47 @@ func (c *Client) readPumpWithDisconnectDetection() {
var msg WSMessage var msg WSMessage
err := c.conn.ReadJSON(&msg) err := c.conn.ReadJSON(&msg)
if err != nil { if err != nil {
// Check if we're shutting down before logging error // Check if we're shutting down or explicitly disconnected before logging error
select { select {
case <-c.done: case <-c.done:
// Expected during shutdown, don't log as error // Expected during shutdown, don't log as error
logger.Debug("WebSocket connection closed during shutdown") logger.Debug("websocket: connection closed during shutdown")
return return
default: default:
// Check if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: connection closed: client was explicitly disconnected")
return
}
// Unexpected error during normal operation // Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("WebSocket read error: %v", err) logger.Error("websocket: read error: %v", err)
} else { } else {
logger.Debug("WebSocket connection closed: %v", err) logger.Debug("websocket: connection closed: %v", err)
} }
return // triggers reconnect via defer return // triggers reconnect via defer
} }
} }
// Update config version from incoming message
c.setConfigVersion(msg.ConfigVersion)
c.handlersMux.RLock() c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok { if handler, ok := c.handlers[msg.Type]; ok {
// Mark that we're processing a message
c.processingMux.Lock()
c.processingMessage = true
c.processingMux.Unlock()
c.processingWg.Add(1)
handler(msg) handler(msg)
// Mark that we're done processing
c.processingWg.Done()
c.processingMux.Lock()
c.processingMessage = false
c.processingMux.Unlock()
} }
c.handlersMux.RUnlock() c.handlersMux.RUnlock()
} }
@@ -666,6 +859,12 @@ func (c *Client) reconnect() {
c.conn = nil c.conn = nil
} }
// Don't reconnect if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
return
}
// Only reconnect if we're not shutting down // Only reconnect if we're not shutting down
select { select {
case <-c.done: case <-c.done:
@@ -683,7 +882,7 @@ func (c *Client) setConnected(status bool) {
// LoadClientCertificate Helper method to load client certificates (PKCS12 format) // LoadClientCertificate Helper method to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) { func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path) logger.Info("websocket: Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file // Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path) p12Data, err := os.ReadFile(p12Path)
if err != nil { if err != nil {