Compare commits

...

203 Commits

Author SHA1 Message Date
Owen
3d891cfa97 Remove do not create client for now
Because its always created when the user joins the org


Former-commit-id: 8ebc678edb
2025-11-08 16:54:35 -08:00
Owen
78e3bb374a Split out hp
Former-commit-id: 29ed4fefbf
2025-11-07 21:59:07 -08:00
Owen
a61c7ca1ee Custom bind?
Former-commit-id: 6d8e298ebc
2025-11-07 21:39:28 -08:00
Owen
7696ba2e36 Add DoNotCreateNewClient
Former-commit-id: aedebb5579
2025-11-07 15:26:45 -08:00
Owen
235877c379 Add optional user token to validate
Former-commit-id: 5734684a21
2025-11-07 14:52:54 -08:00
Owen
befab0f8d1 Fix passing original arguments
Former-commit-id: 7e5b740514
2025-11-07 14:33:52 -08:00
Owen
914d080a57 Connecting disconnecting working
Former-commit-id: 553010f2ea
2025-11-07 14:31:13 -08:00
Owen
a274b4b38f Starting and stopping working
Former-commit-id: f23f2fb9aa
2025-11-07 14:20:36 -08:00
Owen
ce3c585514 Allow connecting and disconnecting
Former-commit-id: 596c4aa0da
2025-11-07 14:07:44 -08:00
Owen
963d8abad5 Add org id in the status
Former-commit-id: da1e4911bd
2025-11-03 20:54:55 -08:00
Owen
38eb56381f Update switching orgs
Former-commit-id: 690b133c7b
2025-11-03 20:33:06 -08:00
Owen
43b3822090 Allow pasing orgId to select org to connect
Former-commit-id: 46a4847cee
2025-11-03 16:54:38 -08:00
Owen
b0fb370c4d Remove status
Former-commit-id: 352ac8def6
2025-11-03 15:29:18 -08:00
Owen
99328ee76f Add registered to api
Former-commit-id: 9c496f7ca7
2025-11-03 15:16:12 -08:00
Owen
36fc3ea253 Add exit call
Former-commit-id: 4a89915826
2025-11-03 14:15:16 -08:00
Owen
a7979259f3 Make api availble over socket
Former-commit-id: e464af5302
2025-11-02 18:56:09 -08:00
Owen
ea6fa72bc0 Copy in config
Former-commit-id: 3505549331
2025-11-02 12:09:39 -08:00
Owen
f9adde6b1d Rename to run
Former-commit-id: 6f7e866e93
2025-11-01 18:39:53 -07:00
Owen
ba25586646 Import submodule
Former-commit-id: eaf94e6855
2025-11-01 18:37:53 -07:00
Owen
952ab63e8d Package?
Former-commit-id: 218e4f88bc
2025-11-01 18:34:00 -07:00
Owen
5e84f802ed Merge branch 'dev' of github.com:fosrl/olm into dev
Former-commit-id: ea05ac9c71
2025-10-27 21:42:52 -07:00
Owen
f40b0ff820 Use local websocket without config or otel
Former-commit-id: 87e2cf8405
2025-10-27 21:42:41 -07:00
Owen
95a4840374 Use local websocket without config or otel
Former-commit-id: a55803f8f3
2025-10-27 21:39:54 -07:00
Owen
27424170e4 Merge branch 'main' into dev
Former-commit-id: 2489283470
2025-10-27 21:27:21 -07:00
Owen Schwartz
a8ace6f64a Merge pull request #47 from gk1/feature/config-file-persistence
Updated the olm client to process config vars from cli,env,file in or…

Former-commit-id: 1a8ae792d0
2025-10-25 17:31:25 -07:00
Owen
3fa1073f49 Treat mtu as int and dont overwrite from websocket
Former-commit-id: 228bddcf79
2025-10-25 17:15:25 -07:00
Owen Schwartz
76d86c10ff Merge pull request #41 from fosrl/dependabot/go_modules/prod-minor-updates-b47bc3c6cb
Bump the prod-minor-updates group with 2 updates

Former-commit-id: 0f25ba98ac
2025-10-25 16:33:02 -07:00
gk1
2d34c6c8b2 Updated the olm client to process config vars from cli,env,file in order of precedence and persist them to file
Former-commit-id: 555c9dc9f4
2025-10-24 17:04:48 -07:00
dependabot[bot]
a7f3477bdd Bump the prod-minor-updates group with 2 updates
Bumps the prod-minor-updates group with 2 updates: [golang.org/x/crypto](https://github.com/golang/crypto) and [golang.org/x/sys](https://github.com/golang/sys).


Updates `golang.org/x/crypto` from 0.42.0 to 0.43.0
- [Commits](https://github.com/golang/crypto/compare/v0.42.0...v0.43.0)

Updates `golang.org/x/sys` from 0.36.0 to 0.37.0
- [Commits](https://github.com/golang/sys/compare/v0.36.0...v0.37.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.43.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
- dependency-name: golang.org/x/sys
  dependency-version: 0.37.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: ef011f29f7
2025-10-20 20:36:13 +00:00
Owen Schwartz
af0a72d296 feat(actions): Sync Images from Docker to GHCR (#46)
* feat(actions): Sync Images from Docker to GHCR

* fix(actions): Moved mirror action to workflows

Former-commit-id: 2b181db148
2025-10-20 13:16:31 -07:00
Marc Schäfer
d1e836e760 fix(actions): Moved mirror action to workflows
Former-commit-id: c94dc6af69
2025-10-20 22:15:15 +02:00
Marc Schäfer
8dd45c4ca2 feat(actions): Sync Images from Docker to GHCR
Former-commit-id: 6275531187
2025-10-20 22:11:42 +02:00
Owen
9db009058b Merge branch 'main' into dev
Former-commit-id: b12e3bd895
2025-10-19 15:12:47 -07:00
Owen
29c01deb05 Update domains
Former-commit-id: 8629c40e2f
2025-10-19 15:12:37 -07:00
Owen Schwartz
7224d9824d Add update checks, log rotation, and message timeouts (#42)
* Add update checker

* Add timeouts to hp

* Try to fix log rotation and service args

* Dont delete service args file

* GO update

Former-commit-id: 9f3eddbc9c
2025-10-08 17:48:50 -07:00
Owen
8afc28fdff GO update
Former-commit-id: c58f3ac92d
2025-10-08 17:46:40 -07:00
Owen
4ba2fb7b53 Merge branch 'main' into dev
Former-commit-id: d88bfd461e
2025-10-08 17:44:53 -07:00
Owen
2e6076923d Dont delete service args file
Former-commit-id: ff3b5c50fc
2025-10-08 17:35:14 -07:00
Owen
4c001dc751 Try to fix log rotation and service args
Former-commit-id: c5ece2f21f
2025-10-01 10:30:45 -07:00
miloschwartz
2b8e240752 update template
Former-commit-id: 27bffa062d
2025-09-29 16:38:29 -07:00
miloschwartz
bee490713d update templates
Former-commit-id: 3ae134af2d
2025-09-29 16:32:05 -07:00
miloschwartz
1cb7fd94ab add templates
Former-commit-id: c06384b700
2025-09-29 16:29:26 -07:00
Owen
dc9a547950 Add timeouts to hp
Former-commit-id: fa1d2b1f55
2025-09-29 14:55:19 -07:00
Owen
2be0933246 Add update checker
Former-commit-id: 2445ced83b
2025-09-29 14:37:07 -07:00
Owen
c0b1cd6bde Add iss file
Former-commit-id: 44802aae7c
2025-09-28 16:44:36 -07:00
Owen
dd00289f8e Remove the old peer when updating new peer
Former-commit-id: 74b166e82f
2025-09-28 12:25:33 -07:00
Owen
b23a02ee97 Merge branch 'main' into dev
Former-commit-id: 1054d70192
2025-09-28 11:41:26 -07:00
Owen
80f726cfea Update get olm script to work with sudo
Former-commit-id: 323d3cf15e
2025-09-28 11:41:07 -07:00
Owen
aa8828186f Add get olm script
Former-commit-id: 77a38e3dba
2025-09-28 11:33:28 -07:00
Owen
0a990d196d Merge branch 'main' into dev
Former-commit-id: 42bd8d5d4c
2025-09-26 09:38:40 -07:00
Owen Schwartz
00e8050949 Fix pulling config.json (#39)
Former-commit-id: 64e7a20915
2025-09-26 09:37:13 -07:00
Owen
18ee4c93fb Fix pulling config.json
Former-commit-id: 03db7649db
2025-09-25 17:41:00 -07:00
Owen Schwartz
8fa2da00b6 Merge pull request #30 from fosrl/dependabot/github_actions/actions/setup-go-6
Bump actions/setup-go from 5 to 6

Former-commit-id: 35d4f19bd8
2025-09-20 11:46:10 -04:00
Owen Schwartz
b851cd73c9 Merge pull request #34 from fosrl/dependabot/go_modules/golang.org/x/sys-0.36.0
Bump golang.org/x/sys from 0.35.0 to 0.36.0

Former-commit-id: a631b60604
2025-09-20 11:46:04 -04:00
Owen Schwartz
4c19d7ef6d Merge pull request #35 from fosrl/dependabot/go_modules/golang.org/x/crypto-0.42.0
Bump golang.org/x/crypto from 0.41.0 to 0.42.0

Former-commit-id: 1eb87c997d
2025-09-20 11:45:57 -04:00
Owen Schwartz
cbecb9a0ce Merge pull request #28 from kevin-gillet/126-stop-litteral-ipv6-from-being-resolved
fix: holepunch to only active peers and stop litteral ipv6 from being name resolved
Former-commit-id: 5d42fac1d1
2025-09-20 11:45:47 -04:00
dependabot[bot]
4fc8db08ba Bump golang.org/x/crypto from 0.41.0 to 0.42.0
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.41.0 to 0.42.0.
- [Commits](https://github.com/golang/crypto/compare/v0.41.0...v0.42.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.42.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: 6d9d012789
2025-09-19 20:20:13 +00:00
dependabot[bot]
7ca46e0a75 Bump golang.org/x/sys from 0.35.0 to 0.36.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.35.0 to 0.36.0.
- [Commits](https://github.com/golang/sys/compare/v0.35.0...v0.36.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-version: 0.36.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: 10b8ebd3c1
2025-09-19 20:20:10 +00:00
dependabot[bot]
a4ea5143af Bump actions/setup-go from 5 to 6
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 5 to 6.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: f9d51ebb88
2025-09-08 20:44:23 +00:00
FranceNuage
e9257b6423 fix: holepunch to only active peers and stop litteral ipv6 from being treated as hostname and be name resolved
Former-commit-id: 2b41d4c459
2025-09-06 02:39:43 +02:00
Owen Schwartz
3c9d3a1d2c Merge pull request #26 from kevin-gillet/25-fix-olm-ipv6-parsing
fix: add ipv6 endpoint formatter
Former-commit-id: 7448a3127d
2025-09-04 10:37:40 -07:00
FranceNuage
b426f14190 fix: remove comment
Former-commit-id: e669d543c4
2025-09-04 14:16:41 +02:00
FranceNuage
d48acfba39 fix: add ipv6 endpoint formatter
Former-commit-id: 5b443a41a3
2025-09-04 14:09:58 +02:00
Owen
35b48cd8e5 Fix ipv6 issue
Former-commit-id: 8c71647802
2025-09-01 17:16:43 -07:00
Owen
15bca53309 Add docs about compose
Former-commit-id: 5dbfeaa95e
2025-09-01 17:01:17 -07:00
Owen
898b599db5 Merge branch 'main' into dev
Former-commit-id: 892eaff480
2025-09-01 16:58:38 -07:00
Owen Schwartz
c07bba18bb Merge pull request #22 from Lokowitz/add-docker-image
added docker version of olm

Former-commit-id: 9aa4288bfe
2025-08-31 09:55:08 -07:00
Lokowitz
4c24d3b808 added build of docker image to test
Former-commit-id: 82555f409b
2025-08-31 07:33:41 +00:00
Lokowitz
ad4ab3d04f added docker version of olm
Former-commit-id: 0d8cacdb90
2025-08-31 07:22:32 +00:00
Owen
e21153fae1 Fix #9
Former-commit-id: dc3d252660
2025-08-30 21:35:33 -07:00
Owen Schwartz
41c3360e23 Merge pull request #21 from fosrl/dependabot/github_actions/actions/setup-go-5
Bump actions/setup-go from 4 to 5

Former-commit-id: 58b05bbb17
2025-08-30 15:29:16 -07:00
Owen Schwartz
1960d32443 Merge pull request #20 from fosrl/dependabot/github_actions/actions/checkout-5
Bump actions/checkout from 3 to 5

Former-commit-id: 81917195e1
2025-08-30 15:29:08 -07:00
dependabot[bot]
74b83b3303 Bump actions/setup-go from 4 to 5
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 4 to 5.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: c2e72a1c51
2025-08-30 22:28:07 +00:00
dependabot[bot]
c2c3470868 Bump actions/checkout from 3 to 5
Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 5.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v5)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: f19b6f1584
2025-08-30 22:28:03 +00:00
Owen Schwartz
2bda3dc3cc Merge pull request #19 from Lokowitz/update-versions-dependabot
Update versions dependabot

Former-commit-id: bb2226c872
2025-08-30 15:27:26 -07:00
Owen Schwartz
52573c8664 Merge pull request #17 from danohn/main
Fix UDP port conflict during NAT holepunching on macOS

Former-commit-id: 456c66e6f2
2025-08-30 15:15:36 -07:00
danohn
0d3c34e23f Update wait time to 500ms
Former-commit-id: 07d5ebdde1
2025-08-29 13:13:09 +10:00
danohn
891df5c74b Update wait timer to 200ms
Former-commit-id: 0765b4daca
2025-08-29 02:57:17 +00:00
Marvin
6f3f162d2b Update go.mod
Former-commit-id: d61d7b64fc
2025-08-28 17:27:12 +02:00
Marvin
f6fa5fd02c Update .go-version
Former-commit-id: d64a4b5973
2025-08-28 17:26:43 +02:00
Marvin
8f4e0ba29e Update test.yml
Former-commit-id: 27d687e91c
2025-08-28 17:26:22 +02:00
Marvin
32b7dc7c43 Update cicd.yml
Former-commit-id: d3b461c01d
2025-08-28 17:26:00 +02:00
Marvin
78d2ebe1de Update dependabot.yml
Former-commit-id: d696706a2e
2025-08-28 17:25:26 +02:00
Owen
014f8eb4e5 Merge branch 'main' into dev
Former-commit-id: 0d1fbd9605
2025-08-23 12:18:57 -07:00
Owen
cd42803291 Add note about config
Former-commit-id: b6e9aae692
2025-08-22 21:35:00 -07:00
Owen Schwartz
5c5b303994 Merge pull request #12 from fosrl/dependabot/docker/minor-updates-887f07f54c
Bump golang from 1.24-alpine to 1.25-alpine in the minor-updates group

Former-commit-id: ad73fc4aa8
2025-08-13 15:01:04 -07:00
Owen
968873da22 Remove container stuff for now
Former-commit-id: 6c41e3f3f6
2025-08-13 14:45:16 -07:00
Owen
e2772f918b Add some more fields to the http api
Former-commit-id: db766de127
2025-08-13 14:41:40 -07:00
dependabot[bot]
cdf6a31b67 Bump golang from 1.24-alpine to 1.25-alpine in the minor-updates group
Bumps the minor-updates group with 1 update: golang.


Updates `golang` from 1.24-alpine to 1.25-alpine

---
updated-dependencies:
- dependency-name: golang
  dependency-version: 1.25-alpine
  dependency-type: direct:production
  dependency-group: minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: e37282c120
2025-08-13 21:14:01 +00:00
Owen
2ce72065a7 Handle env correctly
Former-commit-id: b462b2c53b
2025-08-13 13:47:11 -07:00
Owen
b3e7aafb58 Send version and fall back to old hp
Former-commit-id: 4986859f2f
2025-08-13 12:35:21 -07:00
Owen
dd610ad850 Merge branch 'main' of github.com:fosrl/olm
Former-commit-id: 6c0d664f65
2025-08-13 12:21:22 -07:00
Owen Schwartz
b4b0a832e7 Merge pull request #10 from fosrl/dependabot/go_modules/prod-minor-updates-8d85d55a7a
Bump the prod-minor-updates group with 2 updates

Former-commit-id: 332e2bbb4c
2025-08-07 14:43:44 -07:00
dependabot[bot]
79963c1f66 Bump the prod-minor-updates group with 2 updates
Bumps the prod-minor-updates group with 2 updates: [golang.org/x/crypto](https://github.com/golang/crypto) and [golang.org/x/sys](https://github.com/golang/sys).


Updates `golang.org/x/crypto` from 0.40.0 to 0.41.0
- [Commits](https://github.com/golang/crypto/compare/v0.40.0...v0.41.0)

Updates `golang.org/x/sys` from 0.34.0 to 0.35.0
- [Commits](https://github.com/golang/sys/compare/v0.34.0...v0.35.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.41.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
- dependency-name: golang.org/x/sys
  dependency-version: 0.35.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: c6faa548b3
2025-08-07 21:10:02 +00:00
Owen Schwartz
5f95282161 Merge pull request #4 from fosrl/dependabot/docker/ubuntu-24.04
Bump ubuntu from 22.04 to 24.04

Former-commit-id: 2a6dca355b
2025-08-06 15:37:43 -07:00
Owen
1cca54f9d5 Add version & send hp no overlap
Former-commit-id: 2b5884b19b
2025-08-04 22:22:40 -07:00
Owen
219df22919 Hp to all exit nodes
Former-commit-id: b6fb17d849
2025-08-04 22:22:40 -07:00
dependabot[bot]
337d9934fd Bump ubuntu from 22.04 to 24.04
Bumps ubuntu from 22.04 to 24.04.

---
updated-dependencies:
- dependency-name: ubuntu
  dependency-version: '24.04'
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: 4a26d0c117
2025-07-30 16:59:13 +00:00
Owen Schwartz
bba4d72a78 Merge pull request #3 from Lokowitz/update-and-clean-go
sync go to latest version

Former-commit-id: bdf231d7d4
2025-07-30 09:58:48 -07:00
Owen Schwartz
fb5c793126 Merge pull request #2 from Lokowitz/add-dependabot
add dependabot

Former-commit-id: 9a5032e351
2025-07-30 09:58:18 -07:00
Owen Schwartz
1821dbb672 Merge pull request #1 from Lokowitz/add-test
add simple test

Former-commit-id: 83869f57af
2025-07-30 09:57:03 -07:00
Marvin
4fda6fe031 modified: .github/workflows/cicd.yml
new file:   .go-version
	modified:   Dockerfile
	modified:   go.mod
	modified:   go.sum


Former-commit-id: e51590509f
2025-07-30 09:02:42 +00:00
Marvin
f286f0faf6 Create dependabot.yml
Former-commit-id: 5f4de2a5f6
2025-07-30 10:56:54 +02:00
Marvin
cba3d607bf Create test.yml
Former-commit-id: 1e7cfa95d6
2025-07-30 10:56:30 +02:00
Owen
5ca12834a1 Fix typo
Former-commit-id: fcb745ca77
2025-07-29 09:54:11 -07:00
Owen
9d41154daa Delete buildx
Former-commit-id: c45ac94518
2025-07-29 09:49:35 -07:00
Owen
63933b57fc Update cicd
Former-commit-id: 25b58e868b
2025-07-29 09:47:33 -07:00
Owen
c25d77597d Add warning
Former-commit-id: 6420434821
2025-07-28 22:49:45 -07:00
Owen
ad080046a1 Fix what happens if there are no sites
Former-commit-id: bc855bc4c5
2025-07-28 22:40:49 -07:00
Owen
c1f7cf93a5 Update readme
Former-commit-id: b99096cde1
2025-07-28 12:34:50 -07:00
Owen
f3f112fc42 Merge branch 'clients-fr' into dev
Former-commit-id: 86c2b66b19
2025-07-28 11:56:23 -07:00
Owen
612a9ddb15 Dont kick off the process again on the ws
Former-commit-id: f1b3abdffc
2025-07-28 11:55:03 -07:00
Owen
ad1fa2e59a Handle remote routing
Former-commit-id: 516eae6d96
2025-07-27 14:49:57 -07:00
Owen
29235f6100 Reconnect to newt
Former-commit-id: ee7948f3d5
2025-07-24 20:45:17 -07:00
Owen
848ac6b0c4 Holepunch but relay by default
Former-commit-id: 5302f9da34
2025-07-24 14:44:12 -07:00
Owen
6ab66e6c36 Fix not sending id in punch causing issue
Former-commit-id: 2a832420df
2025-07-24 12:44:13 -07:00
Owen
d7f29d4709 Polish
Former-commit-id: 5068ca6af8
2025-07-24 12:36:47 -07:00
Owen
6fb2b68e21 Service is better?
Former-commit-id: 442098be0c
2025-07-24 11:55:05 -07:00
Owen
8d72e77d57 Service working?
Former-commit-id: 7802248085
2025-07-24 11:06:53 -07:00
Owen
25a9b83496 Service starting with logs
Former-commit-id: e4c030516b
2025-07-23 22:05:35 -07:00
Owen
4d33016389 Logging better?
Former-commit-id: 5fedc2bef1
2025-07-23 21:59:05 -07:00
Owen
0f717aec01 Service is throwing now missing cli args
Former-commit-id: 0807b72fe0
2025-07-23 21:03:47 -07:00
Owen
4c58cd6eff Working windows service
Former-commit-id: a85f83cc20
2025-07-23 20:35:00 -07:00
Owen
3ad36f95e1 Reuse the connection 2025-07-23 14:49:24 -07:00
Owen
b58e7c9fad Holepunch to the right endpoint 2025-07-21 17:04:29 -07:00
Owen
85a8a737e8 Handle relay endpoint dynamically now 2025-07-18 21:41:36 -07:00
Owen
8e83a83294 Update to use newt websocket 2025-07-18 16:58:52 -07:00
Owen
c1ef56001f Add test mode 2025-05-13 11:39:07 -04:00
Owen
c04e727bd3 Add propper logging - remove windows routing? 2025-05-11 10:28:01 -04:00
Owen
becc214078 Add routes 2025-05-03 17:10:46 -04:00
Owen
13e7f55b30 Interface comes up 2025-05-03 16:41:13 -04:00
Owen
0be3ee7eee Fix order of when to add peer monitoring 2025-04-21 10:54:25 -04:00
Owen
0b1724a3f3 Macos working again 2025-04-20 20:57:44 -04:00
Owen
e606264deb Add http server 2025-04-20 16:56:03 -04:00
Owen
5497eb8a4e Print incoming message 2025-04-20 14:49:25 -04:00
Owen
2159371371 Fix olm initially missing private key 2025-04-18 16:56:21 -04:00
Owen
ad8a94fdc8 newts not being on when olm is started 2025-04-13 21:28:25 -04:00
Owen
61b7feef80 Add ping for connectivity monitoring 2025-04-13 17:27:43 -04:00
Owen
4cb31df3c8 Fix concurrancy problem and add +1 back 2025-04-12 17:51:29 -04:00
Owen
3b0eef6d60 Relaying and basic peer detection working 2025-04-11 20:52:10 -04:00
Owen
5d305f1d03 Add peer monitor 2025-04-08 21:57:53 -04:00
Owen
31e5d4e3bd Remove server 2025-04-04 11:35:33 -04:00
Owen
8fb9468d08 Add debug holepunch messahe 2025-04-03 20:47:56 -04:00
Owen
5bbd5016aa Able to connect to multiple newts 2025-04-01 12:48:53 -04:00
Owen
8c40b8c578 Adjustments on the road to testing 2025-04-01 11:24:18 -04:00
Owen
e35f7c2d36 Add wgtester 2025-03-31 18:11:04 -04:00
Owen
aeb8f203a4 Send relay message 2025-03-27 22:13:16 -04:00
Owen
73bd036e58 Add windows to the makefile 2025-03-23 20:57:38 -04:00
Owen
eb6b310304 Include on windows as well 2025-03-15 21:49:04 -04:00
Owen
3d70ff190f Unix: handle encrypted messages 2025-03-15 21:46:54 -04:00
Owen
7e2d7b93a1 Windows "working" 2025-02-27 21:53:40 -05:00
Owen
f50ff67057 Strip trailing slashes 2025-02-24 10:38:05 -05:00
Owen
76d5e95fbf Remove unsuported 2025-02-24 10:09:02 -05:00
Owen
b6db70e285 Rudamentary check for p2p connectivity 2025-02-23 20:17:39 -05:00
Owen
3819823d95 Basic relay working! 2025-02-23 16:50:11 -05:00
Owen
b2830e8473 HP works! 2025-02-22 12:53:46 -05:00
Owen
43a43b429d Initial hp working maybe? 2025-02-22 11:20:30 -05:00
Owen
1593f22691 Merge branch 'main' into holepunch 2025-02-22 00:15:01 -05:00
Owen
4883402393 Add comment 2025-02-22 00:14:49 -05:00
miloschwartz
02eab1ff52 support bring up net interface on darwin 2025-02-22 00:11:25 -05:00
Owen
e0ca38bb35 Add static port and udp holepunch 2025-02-21 22:27:50 -05:00
Owen
8d46ae3aa2 Configure cross platform interface; mac needs fixing 2025-02-21 20:59:34 -05:00
Owen
7424caca8a Working! 2025-02-21 18:50:58 -05:00
Owen
9d9f10a799 Remove ping 2025-02-21 17:12:53 -05:00
Owen
a42d2b75dd Initiall connection seems to be working 2025-02-21 16:20:15 -05:00
Owen
8b09545cf6 Fix small issues; add back ws 2025-02-21 16:11:57 -05:00
Owen
eb77be09e2 Rename to olm 2025-02-21 12:31:25 -05:00
Owen
ad01296c41 Tidy 2025-02-20 22:07:18 -05:00
Owen
b553209712 Import copied libs 2025-02-20 22:06:00 -05:00
Owen
c5098f0cd0 Rename to client 2025-02-20 22:04:06 -05:00
Owen
5ec1aac0d1 Fix S1025 2025-02-20 21:58:33 -05:00
Owen
313ef42883 Remove dev 2025-02-20 21:56:22 -05:00
Owen
085c98668d Move to userspace wg 2025-02-20 21:54:53 -05:00
Owen
6107d20e26 Replace netstack and remove proxy 2025-02-20 21:24:39 -05:00
Owen
66edae4288 Clean up implementation 2025-02-20 21:01:44 -05:00
Owen
f69a7f647d Move wg into more of a class 2025-02-20 20:37:31 -05:00
Owen
e8bd55bed9 Copy in gerbil wg config 2025-02-20 20:04:01 -05:00
Owen
b23eda9c06 Add arm32 go binary as well 2025-02-15 17:59:59 -05:00
Owen
76503f3f2c Fix typo 2025-02-15 17:52:51 -05:00
Owen
9c3112f9bd Merge branch 'dev' 2025-02-10 21:42:29 -05:00
Owen
462af30d16 Add systemd service; Closes #12 2025-02-10 21:41:59 -05:00
Owen
fa6038eb38 Move message to debug to reduce confusion 2025-02-06 20:21:04 -05:00
Owen Schwartz
f346b6cc5d Bump actions/upload-artifact 2025-01-30 10:18:46 -05:00
Owen Schwartz
f20b9ebb14 Merge pull request #9 from fosrl/dev
CICD, --version & Bug Fixes
2025-01-30 10:14:31 -05:00
Owen Schwartz
39bfe5b230 Insert version CICD 2025-01-29 22:31:14 -05:00
Milo Schwartz
a1a3dd9ba2 Merge branch 'dev' of https://github.com/fosrl/newt into dev 2025-01-29 22:23:42 -05:00
Milo Schwartz
7b1492f327 add cicd 2025-01-29 22:23:03 -05:00
Owen Schwartz
4e50819785 Add --version check 2025-01-29 22:19:18 -05:00
Owen Schwartz
f8dccbec80 Fix save config 2025-01-29 22:15:28 -05:00
Owen Schwartz
0c5c59cf00 Fix removing udp sockets 2025-01-27 21:28:22 -05:00
Owen Schwartz
868bb55f87 Fix windows build in release 2025-01-20 21:40:55 -05:00
Owen Schwartz
5b4245402a Merge pull request #6 from fosrl/dev
Proxy Manager Rewrite
2025-01-20 21:15:31 -05:00
Owen Schwartz
f7a705e6f8 Remove starts 2025-01-20 21:13:09 -05:00
Owen Schwartz
3a63657822 Rewrite proxy manager 2025-01-20 21:11:06 -05:00
Owen Schwartz
759780508a Resolve TCP hanging but port is in use issue 2025-01-19 22:46:00 -05:00
Owen Schwartz
533886f2e4 Standarize makefile release 2025-01-16 07:41:56 -05:00
Owen Schwartz
79f8745909 Merge pull request #5 from fosrl/dev
Add tip and MTU set to 1280
2025-01-15 22:41:30 -05:00
Owen Schwartz
7b663027ac Add tip 2025-01-15 21:57:14 -05:00
Owen Schwartz
e90e55d982 Allow chaning mtu; set default low 2025-01-13 22:51:36 -05:00
Owen Schwartz
a46fb23cdd Add all arches and log level 2025-01-13 21:22:17 -05:00
Milo Schwartz
10982b47a5 fix typos in readme 2025-01-09 16:44:25 -05:00
Milo Schwartz
ab12098c9c Merge pull request #4 from fosrl/dev
add security policy
2025-01-08 21:57:45 -05:00
Milo Schwartz
446eb4d6f1 add security policy 2025-01-08 21:36:03 -05:00
48 changed files with 7529 additions and 1376 deletions

View File

@@ -1,6 +1,6 @@
.gitignore .gitignore
.dockerignore .dockerignore
newt olm
*.json *.json
README.md README.md
Makefile Makefile

View File

@@ -0,0 +1,47 @@
body:
- type: textarea
attributes:
label: Summary
description: A clear and concise summary of the requested feature.
validations:
required: true
- type: textarea
attributes:
label: Motivation
description: |
Why is this feature important?
Explain the problem this feature would solve or what use case it would enable.
validations:
required: true
- type: textarea
attributes:
label: Proposed Solution
description: |
How would you like to see this feature implemented?
Provide as much detail as possible about the desired behavior, configuration, or changes.
validations:
required: true
- type: textarea
attributes:
label: Alternatives Considered
description: Describe any alternative solutions or workarounds you've thought about.
validations:
required: false
- type: textarea
attributes:
label: Additional Context
description: Add any other context, mockups, or screenshots about the feature request here.
validations:
required: false
- type: markdown
attributes:
value: |
Before submitting, please:
- Check if there is an existing issue for this feature.
- Clearly explain the benefit and use case.
- Be as specific as possible to help contributors evaluate and implement.

51
.github/ISSUE_TEMPLATE/1.bug_report.yml vendored Normal file
View File

@@ -0,0 +1,51 @@
name: Bug Report
description: Create a bug report
labels: []
body:
- type: textarea
attributes:
label: Describe the Bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
attributes:
label: Environment
description: Please fill out the relevant details below for your environment.
value: |
- OS Type & Version: (e.g., Ubuntu 22.04)
- Pangolin Version:
- Gerbil Version:
- Traefik Version:
- Newt Version:
- Olm Version: (if applicable)
validations:
required: true
- type: textarea
attributes:
label: To Reproduce
description: |
Steps to reproduce the behavior, please provide a clear description of how to reproduce the issue, based on the linked minimal reproduction. Screenshots can be provided in the issue body below.
If using code blocks, make sure syntax highlighting is correct and double-check that the rendered preview is not broken.
validations:
required: true
- type: textarea
attributes:
label: Expected Behavior
description: A clear and concise description of what you expected to happen.
validations:
required: true
- type: markdown
attributes:
value: |
Before posting the issue go through the steps you've written down to make sure the steps provided are detailed and clear.
- type: markdown
attributes:
value: |
Contributors should be able to follow the steps provided in order to reproduce the bug.

8
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,8 @@
blank_issues_enabled: false
contact_links:
- name: Need help or have questions?
url: https://github.com/orgs/fosrl/discussions
about: Ask questions, get help, and discuss with other community members
- name: Request a Feature
url: https://github.com/orgs/fosrl/discussions/new?category=feature-requests
about: Feature requests should be opened as discussions so others can upvote and comment

40
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
version: 2
updates:
- package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "daily"
groups:
dev-patch-updates:
dependency-type: "development"
update-types:
- "patch"
dev-minor-updates:
dependency-type: "development"
update-types:
- "minor"
prod-patch-updates:
dependency-type: "production"
update-types:
- "patch"
prod-minor-updates:
dependency-type: "production"
update-types:
- "minor"
- package-ecosystem: "docker"
directory: "/"
schedule:
interval: "daily"
groups:
patch-updates:
update-types:
- "patch"
minor-updates:
update-types:
- "minor"
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"

60
.github/workflows/cicd.yml vendored Normal file
View File

@@ -0,0 +1,60 @@
name: CI/CD Pipeline
on:
push:
tags:
- "*"
jobs:
release:
name: Build and Release
runs-on: amd64-runner
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
- name: Extract tag name
id: get-tag
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
- name: Install Go
uses: actions/setup-go@v6
with:
go-version: 1.25
- name: Update version in main.go
run: |
TAG=${{ env.TAG }}
if [ -f main.go ]; then
sed -i 's/version_replaceme/'"$TAG"'/' main.go
echo "Updated main.go with version $TAG"
else
echo "main.go not found"
fi
- name: Build and push Docker images
run: |
TAG=${{ env.TAG }}
make docker-build-release tag=$TAG
- name: Build binaries
run: |
make go-build-release
- name: Upload artifacts from /bin
uses: actions/upload-artifact@v4
with:
name: binaries
path: bin/

132
.github/workflows/mirror.yaml vendored Normal file
View File

@@ -0,0 +1,132 @@
name: Mirror & Sign (Docker Hub to GHCR)
on:
workflow_dispatch: {}
permissions:
contents: read
packages: write
id-token: write # for keyless OIDC
env:
SOURCE_IMAGE: docker.io/fosrl/olm
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
jobs:
mirror-and-dual-sign:
runs-on: amd64-runner
steps:
- name: Install skopeo + jq
run: |
sudo apt-get update -y
sudo apt-get install -y skopeo jq
skopeo --version
- name: Install cosign
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
- name: Input check
run: |
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
echo "Source : ${SOURCE_IMAGE}"
echo "Target : ${DEST_IMAGE}"
# Auth for skopeo (containers-auth)
- name: Skopeo login to GHCR
run: |
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
# Auth for cosign (docker-config)
- name: Docker login to GHCR (for cosign)
run: |
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
- name: List source tags
run: |
set -euo pipefail
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
| jq -r '.Tags[]' | sort -u > src-tags.txt
echo "Found source tags: $(wc -l < src-tags.txt)"
head -n 20 src-tags.txt || true
- name: List destination tags (skip existing)
run: |
set -euo pipefail
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
else
: > dst-tags.txt
fi
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
- name: Mirror, dual-sign, and verify
env:
# keyless
COSIGN_YES: "true"
# key-based
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
# verify
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
run: |
set -euo pipefail
copied=0; skipped=0; v_ok=0; errs=0
issuer="https://token.actions.githubusercontent.com"
id_regex="^https://github.com/${{ github.repository }}/.+"
while read -r tag; do
[ -z "$tag" ] && continue
if grep -Fxq "$tag" dst-tags.txt; then
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
skipped=$((skipped+1))
continue
fi
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
if ! skopeo copy --all --retry-times 3 \
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
errs=$((errs+1)); continue
fi
copied=$((copied+1))
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
ref="${DEST_IMAGE}@${digest}"
echo "==> cosign sign (keyless) --recursive ${ref}"
if ! cosign sign --recursive "${ref}"; then
echo "::warning title=Keyless sign failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign sign (key) --recursive ${ref}"
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
echo "::warning title=Key sign failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign verify (public key) ${ref}"
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
echo "::warning title=Verify(pubkey) failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign verify (keyless policy) ${ref}"
if ! cosign verify \
--certificate-oidc-issuer "${issuer}" \
--certificate-identity-regexp "${id_regex}" \
"${ref}" -o text; then
echo "::warning title=Verify(keyless) failed::${ref}"
errs=$((errs+1))
else
v_ok=$((v_ok+1))
fi
done < src-tags.txt
echo "---- Summary ----"
echo "Copied : $copied"
echo "Skipped : $skipped"
echo "Verified OK : $v_ok"
echo "Errors : $errs"

28
.github/workflows/test.yml vendored Normal file
View File

@@ -0,0 +1,28 @@
name: Run Tests
on:
pull_request:
branches:
- main
- dev
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: 1.25
- name: Build go
run: go build
- name: Build Docker image
run: make build
- name: Build binaries
run: make go-build-release

3
.gitignore vendored
View File

@@ -1 +1,2 @@
newt .DS_Store
bin/

1
.go-version Normal file
View File

@@ -0,0 +1 @@
1.25

View File

@@ -4,11 +4,7 @@ Contributions are welcome!
Please see the contribution and local development guide on the docs page before getting started: Please see the contribution and local development guide on the docs page before getting started:
https://docs.fossorial.io/development https://docs.pangolin.net/development/contributing
For ideas about what features to work on and our future plans, please see the roadmap:
https://docs.fossorial.io/roadmap
### Licensing Considerations ### Licensing Considerations

View File

@@ -1,4 +1,4 @@
FROM golang:1.23.1-alpine AS builder FROM golang:1.25-alpine AS builder
# Set the working directory inside the container # Set the working directory inside the container
WORKDIR /app WORKDIR /app
@@ -13,15 +13,15 @@ RUN go mod download
COPY . . COPY . .
# Build the application # Build the application
RUN CGO_ENABLED=0 GOOS=linux go build -o /newt RUN CGO_ENABLED=0 GOOS=linux go build -o /olm
# Start a new stage from scratch # Start a new stage from scratch
FROM ubuntu:22.04 AS runner FROM alpine:3.22 AS runner
RUN apt-get update && apt-get install ca-certificates -y && rm -rf /var/lib/apt/lists/* RUN apk --no-cache add ca-certificates
# Copy the pre-built binary file from the previous stage and the entrypoint script # Copy the pre-built binary file from the previous stage and the entrypoint script
COPY --from=builder /newt /usr/local/bin/ COPY --from=builder /olm /usr/local/bin/
COPY entrypoint.sh / COPY entrypoint.sh /
RUN chmod +x /entrypoint.sh RUN chmod +x /entrypoint.sh
@@ -30,4 +30,4 @@ RUN chmod +x /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"] ENTRYPOINT ["/entrypoint.sh"]
# Command to run the executable # Command to run the executable
CMD ["newt"] CMD ["olm"]

View File

@@ -1,17 +1,26 @@
all: build push all: go-build-release
build: docker-build-release:
docker build -t fosrl/newt:latest . @if [ -z "$(tag)" ]; then \
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
push: exit 1; \
docker push fosrl/newt:latest fi
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:latest -f Dockerfile --push .
test: docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push .
docker run fosrl/newt:latest
local: local:
CGO_ENABLED=0 go build -o newt CGO_ENABLED=0 go build -o bin/olm
build:
docker build -t fosrl/olm:latest .
go-build-release:
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/olm_linux_arm64
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/olm_linux_amd64
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/olm_darwin_arm64
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/olm_darwin_amd64
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/olm_windows_amd64.exe
clean: clean:
rm newt rm olm

303
README.md
View File

@@ -1,46 +1,59 @@
# Newt # Olm
Newt is a fully user space [WireGuard](https://www.wireguard.com/) tunnel client and TCP/UDP proxy, designed to securely expose private resources controlled by Pangolin. By using Newt, you don't need to manage complex WireGuard tunnels and NATing. Olm is a [WireGuard](https://www.wireguard.com/) tunnel client designed to securely connect your computer to Newt sites running on remote networks.
### Installation and Documentation ### Installation and Documentation
Newt is used with Pangolin and Gerbil as part of the larger system. See documentation below: Olm is used with Pangolin and Newt as part of the larger system. See documentation below:
- [Installation Instructions](https://docs.fossorial.io) - [Full Documentation](https://docs.pangolin.net)
- [Full Documentation](https://docs.fossorial.io)
## Preview
<img src="public/screenshots/preview.png" alt="Preview"/>
_Sample output of a Newt container connected to Pangolin and hosting various resource target proxies._
## Key Functions ## Key Functions
### Registers with Pangolin ### Registers with Pangolin
Using the Newt ID and a secret the client will make HTTP requests to Pangolin to receive a session token. Using that token it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket. Using the Olm ID and a secret, the olm will make HTTP requests to Pangolin to receive a session token. Using that token, it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket.
### Receives WireGuard Control Messages ### Receives WireGuard Control Messages
When Newt receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel using [netstack](https://github.com/WireGuard/wireguard-go/blob/master/tun/netstack/examples/http_server.go) fully in user space. It will ping over the tunnel to ensure the peer on the Gerbil side is brought up. When Olm receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel on your computer to a remote Newt. It will ping over the tunnel to ensure the peer is brought up.
### Receives Proxy Control Messages
When Newt receives WireGuard control messages, it will use the information encoded to crate local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets.
## CLI Args ## CLI Args
- `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket. - `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket.
- `id`: Newt ID generated by Pangolin to identify the client. - `id`: Olm ID generated by Pangolin to identify the olm.
- `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands. - `secret`: A unique secret (not shared and kept private) used to authenticate the olm ID with the websocket in order to receive commands.
- `dns`: DNS server to use to resolve the endpoint - `mtu` (optional): MTU for the internal WG interface. Default: 1280
- `log-level` (optional): The log level to use. Default: INFO - `dns` (optional): DNS server to use to resolve the endpoint. Default: 8.8.8.8
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: INFO
- `ping-interval` (optional): Interval for pinging the server. Default: 3s
- `ping-timeout` (optional): Timeout for each ping. Default: 5s
- `interface` (optional): Name of the WireGuard interface. Default: olm
- `enable-http` (optional): Enable HTTP server for receiving connection requests. Default: false
- `http-addr` (optional): HTTP server address (e.g., ':9452'). Default: :9452
- `holepunch` (optional): Enable hole punching. Default: false
Example: ## Environment Variables
All CLI arguments can also be set via environment variables:
- `PANGOLIN_ENDPOINT`: Equivalent to `--endpoint`
- `OLM_ID`: Equivalent to `--id`
- `OLM_SECRET`: Equivalent to `--secret`
- `MTU`: Equivalent to `--mtu`
- `DNS`: Equivalent to `--dns`
- `LOG_LEVEL`: Equivalent to `--log-level`
- `INTERFACE`: Equivalent to `--interface`
- `HTTP_ADDR`: Equivalent to `--http-addr`
- `PING_INTERVAL`: Equivalent to `--ping-interval`
- `PING_TIMEOUT`: Equivalent to `--ping-timeout`
- `HOLEPUNCH`: Set to "true" to enable hole punching (equivalent to `--holepunch`)
- `CONFIG_FILE`: Set to the location of a JSON file to load secret values
Examples:
```bash ```bash
./newt \ olm \
--id 31frd0uzbjvp721 \ --id 31frd0uzbjvp721 \
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \ --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
--endpoint https://example.com --endpoint https://example.com
@@ -50,40 +63,230 @@ You can also run it with Docker compose. For example, a service in your `docker-
```yaml ```yaml
services: services:
newt: olm:
image: fosrl/newt image: fosrl/olm
container_name: newt container_name: olm
restart: unless-stopped restart: unless-stopped
environment: network_mode: host
- PANGOLIN_ENDPOINT=https://example.com devices:
- NEWT_ID=2ix2t8xk22ubpfy - /dev/net/tun:/dev/net/tun
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2 environment:
- PANGOLIN_ENDPOINT=https://example.com
- OLM_ID=31frd0uzbjvp721
- OLM_SECRET=h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
``` ```
You can also pass the CLI args to the container: You can also pass the CLI args to the container:
```yaml ```yaml
services: services:
newt: olm:
image: fosrl/newt image: fosrl/olm
container_name: newt container_name: olm
restart: unless-stopped restart: unless-stopped
command: network_mode: host
- --id 31frd0uzbjvp721 devices:
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 - /dev/net/tun:/dev/net/tun
- --endpoint https://example.com command:
- --id 31frd0uzbjvp721
- --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6
- --endpoint https://example.com
```
**Docker Configuration Notes:**
- `network_mode: host` brings the olm network interface to the host system, allowing the WireGuard tunnel to function properly
- `devices: - /dev/net/tun:/dev/net/tun` is required to give the container access to the TUN device for creating WireGuard interfaces
## Loading secrets from files
You can use `CONFIG_FILE` to define a location of a config file to store the credentials between runs.
```
$ cat ~/.config/olm-client/config.json
{
"id": "spmzu8rbpzj1qq6",
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
"endpoint": "https://app.pangolin.net",
"tlsClientCert": ""
}
```
This file is also written to when newt first starts up. So you do not need to run every time with --id and secret if you have run it once!
Default locations:
- **macOS**: `~/Library/Application Support/olm-client/config.json`
- **Windows**: `%PROGRAMDATA%\olm\olm-client\config.json`
- **Linux/Others**: `~/.config/olm-client/config.json`
## Hole Punching
In the default mode, olm "relays" traffic through Gerbil in the cloud to get down to newt. This is a little more reliable. Support for NAT hole punching is also EXPERIMENTAL right now using the `--holepunch` flag. This will attempt to orchestrate a NAT hole punch between the two sites so that traffic flows directly. This will save data costs and speed. If it fails it should fall back to relaying.
Right now, basic NAT hole punching is supported. We plan to add:
- [ ] Birthday paradox
- [ ] UPnP
- [ ] LAN detection
## Windows Service
On Windows, olm has to be installed and run as a Windows service. When running it with the cli args live above it will attempt to install and run the service to function like a cli tool. You can also run the following:
### Service Management Commands
```
# Install the service
olm.exe install
# Start the service
olm.exe start
# Stop the service
olm.exe stop
# Check service status
olm.exe status
# Remove the service
olm.exe remove
# Run in debug mode (console output) with our without id & secret
olm.exe debug
# Show help
olm.exe help
```
Note running the service requires credentials in `%PROGRAMDATA%\olm\olm-client\config.json`.
### Service Configuration
When running as a service, Olm will read configuration from environment variables or you can modify the service to include command-line arguments:
1. Install the service: `olm.exe install`
2. Set the credentials in `%PROGRAMDATA%\olm\olm-client\config.json`. Hint: if you run olm once with --id and --secret this file will be populated!
3. Start the service: `olm.exe start`
### Service Logs
When running as a service, logs are written to:
- Windows Event Log (Application log, source: "OlmWireguardService")
- Log files in: `%PROGRAMDATA%\olm\logs\olm.log`
You can view the Windows Event Log using Event Viewer or PowerShell:
```powershell
Get-EventLog -LogName Application -Source "OlmWireguardService" -Newest 10
```
## HTTP Endpoints
Olm can be controlled with an embedded http server when using `--enable-http`. This allows you to start it as a daemon and trigger it with the following endpoints:
### POST /connect
Initiates a new connection request.
**Request Body:**
```json
{
"id": "string",
"secret": "string",
"endpoint": "string"
}
```
**Required Fields:**
- `id`: Connection identifier
- `secret`: Authentication secret
- `endpoint`: Target endpoint URL
**Response:**
- **Status Code:** `202 Accepted`
- **Content-Type:** `application/json`
```json
{
"status": "connection request accepted"
}
```
**Error Responses:**
- `405 Method Not Allowed` - Non-POST requests
- `400 Bad Request` - Invalid JSON or missing required fields
### GET /status
Returns the current connection status and peer information.
**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`
```json
{
"status": "connected",
"connected": true,
"tunnelIP": "100.89.128.3/20",
"version": "version_replaceme",
"peers": {
"10": {
"siteId": 10,
"connected": true,
"rtt": 145338339,
"lastSeen": "2025-08-13T14:39:17.208334428-07:00",
"endpoint": "p.fosrl.io:21820",
"isRelay": true
},
"8": {
"siteId": 8,
"connected": false,
"rtt": 0,
"lastSeen": "2025-08-13T14:39:19.663823645-07:00",
"endpoint": "p.fosrl.io:21820",
"isRelay": true
}
}
}
```
**Fields:**
- `status`: Overall connection status ("connected" or "disconnected")
- `connected`: Boolean connection state
- `tunnelIP`: IP address and subnet of the tunnel (when connected)
- `version`: Olm version string
- `peers`: Map of peer statuses by site ID
- `siteId`: Peer site identifier
- `connected`: Boolean peer connection state
- `rtt`: Peer round-trip time (integer, nanoseconds)
- `lastSeen`: Last time peer was seen (RFC3339 timestamp)
- `endpoint`: Peer endpoint address
- `isRelay`: Whether the peer is relayed (true) or direct (false)
**Error Responses:**
- `405 Method Not Allowed` - Non-GET requests
## Usage Examples
### Connect to a peer
```bash
curl -X POST http://localhost:8080/connect \
-H "Content-Type: application/json" \
-d '{
"id": "31frd0uzbjvp721",
"secret": "h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6",
"endpoint": "https://example.com"
}'
```
### Check connection status
```bash
curl http://localhost:8080/status
``` ```
## Build ## Build
### Container
Ensure Docker is installed.
```bash
make
```
### Binary ### Binary
Make sure to have Go 1.23.1 installed. Make sure to have Go 1.23.1 installed.
@@ -94,8 +297,8 @@ make local
## Licensing ## Licensing
Newt is dual licensed under the AGPLv3 and the Fossorial Commercial license. For inquiries about commercial licensing, please contact us. Olm is dual licensed under the AGPLv3 and the Fossorial Commercial license. For inquiries about commercial licensing, please contact us.
## Contributions ## Contributions
Please see [CONTRIBUTIONS](./CONTRIBUTING.md) in the repository for guidelines and best practices. Please see [CONTRIBUTIONS](./CONTRIBUTING.md) in the repository for guidelines and best practices.

14
SECURITY.md Normal file
View File

@@ -0,0 +1,14 @@
# Security Policy
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
- Description and location of the vulnerability.
- Potential impact of the vulnerability.
- Steps to reproduce the vulnerability.
- Potential solutions to fix the vulnerability.
- Your name/handle and a link for recognition (optional).
We aim to address the issue as soon as possible.

411
api/api.go Normal file
View File

@@ -0,0 +1,411 @@
package api
import (
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
// ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
}
// SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct {
OrgID string `json:"orgId"`
}
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
Connected bool `json:"connected"`
RTT time.Duration `json:"rtt"`
LastSeen time.Time `json:"lastSeen"`
Endpoint string `json:"endpoint,omitempty"`
IsRelay bool `json:"isRelay"`
PeerIP string `json:"peerAddress,omitempty"`
}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Connected bool `json:"connected"`
Registered bool `json:"registered"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
// API represents the HTTP server and its state
type API struct {
addr string
socketPath string
listener net.Listener
server *http.Server
connectionChan chan ConnectionRequest
switchOrgChan chan SwitchOrgRequest
shutdownChan chan struct{}
disconnectChan chan struct{}
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
isRegistered bool
tunnelIP string
version string
orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
func NewAPI(addr string) *API {
s := &API{
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
func NewAPISocket(socketPath string) *API {
s := &API{
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// Start starts the HTTP server
func (s *API) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit)
s.server = &http.Server{
Handler: mux,
}
var err error
if s.socketPath != "" {
// Use platform-specific socket listener
s.listener, err = createSocketListener(s.socketPath)
if err != nil {
return fmt.Errorf("failed to create socket listener: %w", err)
}
logger.Info("Starting HTTP server on socket %s", s.socketPath)
} else {
// Use TCP listener
s.listener, err = net.Listen("tcp", s.addr)
if err != nil {
return fmt.Errorf("failed to create TCP listener: %w", err)
}
logger.Info("Starting HTTP server on %s", s.addr)
}
go func() {
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server error: %v", err)
}
}()
return nil
}
// Stop stops the HTTP server
func (s *API) Stop() error {
logger.Info("Stopping api server")
// Close the server first, which will also close the listener gracefully
if s.server != nil {
s.server.Close()
}
// Clean up socket file if using Unix socket
if s.socketPath != "" {
cleanupSocket(s.socketPath)
}
return nil
}
// GetConnectionChannel returns the channel for receiving connection requests
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// GetSwitchOrgChannel returns the channel for receiving org switch requests
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
return s.switchOrgChan
}
// GetShutdownChannel returns the channel for receiving shutdown requests
func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
// GetDisconnectChannel returns the channel for receiving disconnect requests
func (s *API) GetDisconnectChannel() <-chan struct{} {
return s.disconnectChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Connected = connected
status.RTT = rtt
status.LastSeen = time.Now()
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// SetConnectionStatus sets the overall connection status
func (s *API) SetConnectionStatus(isConnected bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isConnected = isConnected
if isConnected {
s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
}
}
func (s *API) SetRegistered(registered bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isRegistered = registered
}
// SetTunnelIP sets the tunnel IP address
func (s *API) SetTunnelIP(tunnelIP string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.tunnelIP = tunnelIP
}
// SetVersion sets the olm version
func (s *API) SetVersion(version string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.version = version
}
// SetOrgID sets the organization ID
func (s *API) SetOrgID(orgID string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.orgID = orgID
}
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// handleConnect handles the /connect endpoint
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// if we are already connected, reject new connection requests
s.statusMu.RLock()
alreadyConnected := s.isConnected
s.statusMu.RUnlock()
if alreadyConnected {
http.Error(w, "Already connected to a server. Disconnect first before connecting again.", http.StatusConflict)
return
}
var req ConnectionRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
return
}
// Send the request to the main goroutine
s.connectionChan <- req
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
// handleStatus handles the /status endpoint
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
TunnelIP: s.tunnelIP,
Version: s.version,
OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// handleExit handles the /exit endpoint
func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received exit request via API")
// Send shutdown signal
select {
case s.shutdownChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "shutdown initiated",
})
}
// handleSwitchOrg handles the /switch-org endpoint
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req SwitchOrgRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.OrgID == "" {
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
return
}
logger.Info("Received org switch request to orgId: %s", req.OrgID)
// Send the request to the main goroutine
select {
case s.switchOrgChan <- req:
// Signal sent successfully
default:
// Channel already has a pending request
http.Error(w, "Org switch already in progress", http.StatusConflict)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "org switch request accepted",
})
}
// handleDisconnect handles the /disconnect endpoint
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// if we are already disconnected, reject new disconnect requests
s.statusMu.RLock()
alreadyDisconnected := !s.isConnected
s.statusMu.RUnlock()
if alreadyDisconnected {
http.Error(w, "Not currently connected to a server.", http.StatusConflict)
return
}
logger.Info("Received disconnect request via API")
// Send disconnect signal
select {
case s.disconnectChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}

50
api/api_unix.go Normal file
View File

@@ -0,0 +1,50 @@
//go:build !windows
// +build !windows
package api
import (
"fmt"
"net"
"os"
"path/filepath"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Unix domain socket listener
func createSocketListener(socketPath string) (net.Listener, error) {
// Ensure the directory exists
dir := filepath.Dir(socketPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create socket directory: %w", err)
}
// Remove existing socket file if it exists
if err := os.RemoveAll(socketPath); err != nil {
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
}
// Set socket permissions to allow access
if err := os.Chmod(socketPath, 0666); err != nil {
listener.Close()
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
}
logger.Debug("Created Unix socket at %s", socketPath)
return listener, nil
}
// cleanupSocket removes the Unix socket file
func cleanupSocket(socketPath string) {
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
} else {
logger.Debug("Removed Unix socket at %s", socketPath)
}
}

41
api/api_windows.go Normal file
View File

@@ -0,0 +1,41 @@
//go:build windows
// +build windows
package api
import (
"fmt"
"net"
"github.com/Microsoft/go-winio"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Windows named pipe listener
func createSocketListener(pipePath string) (net.Listener, error) {
// Ensure the pipe path has the correct format
if pipePath[0] != '\\' {
pipePath = `\\.\pipe\` + pipePath
}
// Create a pipe configuration that allows everyone to write
config := &winio.PipeConfig{
// Set security descriptor to allow everyone full access
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
}
// Create a named pipe listener using go-winio with the configuration
listener, err := winio.ListenPipe(pipePath, config)
if err != nil {
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
}
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
return listener, nil
}
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
func cleanupSocket(pipePath string) {
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
}

378
bind/shared_bind.go Normal file
View File

@@ -0,0 +1,378 @@
//go:build !js
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// Endpoint represents a network endpoint for the SharedBind
type Endpoint struct {
AddrPort netip.AddrPort
}
// ClearSrc implements the wgConn.Endpoint interface
func (e *Endpoint) ClearSrc() {}
// DstIP implements the wgConn.Endpoint interface
func (e *Endpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// SrcIP implements the wgConn.Endpoint interface
func (e *Endpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
// DstToBytes implements the wgConn.Endpoint interface
func (e *Endpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
// DstToString implements the wgConn.Endpoint interface
func (e *Endpoint) DstToString() string {
return e.AddrPort.String()
}
// SrcToString implements the wgConn.Endpoint interface
func (e *Endpoint) SrcToString() string {
return ""
}
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
// and hole punch senders. It wraps a single UDP connection and implements
// reference counting to prevent premature closure.
type SharedBind struct {
mu sync.RWMutex
// The underlying UDP connection
udpConn *net.UDPConn
// IPv4 and IPv6 packet connections for advanced features
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
// Reference counting to prevent closing while in use
refCount atomic.Int32
closed atomic.Bool
// Channels for receiving data
recvFuncs []wgConn.ReceiveFunc
// Port binding information
port uint16
}
// New creates a new SharedBind from an existing UDP connection.
// The SharedBind takes ownership of the connection and will close it
// when all references are released.
func New(udpConn *net.UDPConn) (*SharedBind, error) {
if udpConn == nil {
return nil, fmt.Errorf("udpConn cannot be nil")
}
bind := &SharedBind{
udpConn: udpConn,
}
// Initialize reference count to 1 (the creator holds the first reference)
bind.refCount.Store(1)
// Get the local port
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
bind.port = uint16(addr.Port)
}
return bind, nil
}
// AddRef increments the reference count. Call this when sharing
// the bind with another component.
func (b *SharedBind) AddRef() {
newCount := b.refCount.Add(1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
}
// Release decrements the reference count. When it reaches zero,
// the underlying UDP connection is closed.
func (b *SharedBind) Release() error {
newCount := b.refCount.Add(-1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
if newCount < 0 {
// This should never happen with proper usage
b.refCount.Store(0)
return fmt.Errorf("SharedBind reference count went negative")
}
if newCount == 0 {
return b.closeConnection()
}
return nil
}
// closeConnection actually closes the UDP connection
func (b *SharedBind) closeConnection() error {
if !b.closed.CompareAndSwap(false, true) {
// Already closed
return nil
}
b.mu.Lock()
defer b.mu.Unlock()
var err error
if b.udpConn != nil {
err = b.udpConn.Close()
b.udpConn = nil
}
b.ipv4PC = nil
b.ipv6PC = nil
return err
}
// GetUDPConn returns the underlying UDP connection.
// The caller must not close this connection directly.
func (b *SharedBind) GetUDPConn() *net.UDPConn {
b.mu.RLock()
defer b.mu.RUnlock()
return b.udpConn
}
// GetRefCount returns the current reference count (for debugging)
func (b *SharedBind) GetRefCount() int32 {
return b.refCount.Load()
}
// IsClosed returns whether the bind is closed
func (b *SharedBind) IsClosed() bool {
return b.closed.Load()
}
// WriteToUDP writes data to a specific UDP address.
// This is thread-safe and can be used by hole punch senders.
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
return conn.WriteToUDP(data, addr)
}
// Close implements the WireGuard Bind interface.
// It decrements the reference count and closes the connection if no references remain.
func (b *SharedBind) Close() error {
return b.Release()
}
// Open implements the WireGuard Bind interface.
// Since the connection is already open, this just sets up the receive functions.
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
if b.closed.Load() {
return nil, 0, net.ErrClosed
}
b.mu.Lock()
defer b.mu.Unlock()
if b.udpConn == nil {
return nil, 0, net.ErrClosed
}
// Set up IPv4 and IPv6 packet connections for advanced features
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
}
// Create receive functions
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
// Add IPv4 receive function
if b.ipv4PC != nil || runtime.GOOS != "linux" {
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
}
// Add IPv6 receive function if needed
// For now, we focus on IPv4 for hole punching use case
b.recvFuncs = recvFuncs
return recvFuncs, b.port, nil
}
// makeReceiveIPv4 creates a receive function for IPv4 packets
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
pc := b.ipv4PC
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
// Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
}
// Fallback to simple read for other platforms
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
}
}
// receiveIPv4Batch uses batch reading for better performance on Linux
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
// Create messages for batch reading
msgs := make([]ipv4.Message, len(bufs))
for i := range bufs {
msgs[i].Buffers = [][]byte{bufs[i]}
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
}
numMsgs, err := pc.ReadBatch(msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
sizes[i] = msgs[i].N
if sizes[i] == 0 {
continue
}
if msgs[i].Addr != nil {
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
addrPort := udpAddr.AddrPort()
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
}
}
return numMsgs, nil
}
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
n, addr, err := conn.ReadFromUDP(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
if addr != nil {
addrPort := addr.AddrPort()
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
return 1, nil
}
// Send implements the WireGuard Bind interface.
// It sends packets to the specified endpoint.
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
if b.closed.Load() {
return net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return net.ErrClosed
}
// Extract the destination address from the endpoint
var destAddr *net.UDPAddr
// Try to cast to StdNetEndpoint first
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
} else {
// Fallback: construct from DstIP and DstToBytes
dstBytes := ep.DstToBytes()
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
var addr netip.Addr
var port uint16
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
addr, _ = netip.AddrFromSlice(dstBytes[:16])
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
} else { // IPv4
addr, _ = netip.AddrFromSlice(dstBytes[:4])
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
}
if addr.IsValid() {
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
}
}
}
if destAddr == nil {
return fmt.Errorf("could not extract destination address from endpoint")
}
// Send all buffers to the destination
for _, buf := range bufs {
_, err := conn.WriteToUDP(buf, destAddr)
if err != nil {
return err
}
}
return nil
}
// SetMark implements the WireGuard Bind interface.
// It's a no-op for this implementation.
func (b *SharedBind) SetMark(mark uint32) error {
// Not implemented for this use case
return nil
}
// BatchSize returns the preferred batch size for sending packets.
func (b *SharedBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return wgConn.IdealBatchSize
}
return 1
}
// ParseEndpoint creates a new endpoint from a string address.
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
addrPort, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
}

424
bind/shared_bind_test.go Normal file
View File

@@ -0,0 +1,424 @@
//go:build !js
package bind
import (
"net"
"net/netip"
"sync"
"testing"
"time"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// TestSharedBindCreation tests basic creation and initialization
func TestSharedBindCreation(t *testing.T) {
// Create a UDP connection
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
defer udpConn.Close()
// Create SharedBind
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
if bind == nil {
t.Fatal("SharedBind is nil")
}
// Verify initial reference count
if bind.refCount.Load() != 1 {
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
}
// Clean up
if err := bind.Close(); err != nil {
t.Errorf("Failed to close SharedBind: %v", err)
}
}
// TestSharedBindReferenceCount tests reference counting
func TestSharedBindReferenceCount(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add references
bind.AddRef()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
}
bind.AddRef()
if bind.refCount.Load() != 3 {
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
}
// Release references
bind.Release()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
}
bind.Release()
bind.Release() // This should close the connection
if !bind.closed.Load() {
t.Error("Expected bind to be closed after all references released")
}
}
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
func TestSharedBindWriteToUDP(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Send data
testData := []byte("Hello, SharedBind!")
n, err := senderBind.WriteToUDP(testData, receiverAddr)
if err != nil {
t.Fatalf("WriteToUDP failed: %v", err)
}
if n != len(testData) {
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err = receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindConcurrentWrites tests thread-safety
func TestSharedBindConcurrentWrites(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Launch concurrent writes
numGoroutines := 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
data := []byte{byte(id)}
_, err := senderBind.WriteToUDP(data, receiverAddr)
if err != nil {
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
}
}(i)
}
wg.Wait()
}
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
func TestSharedBindWireGuardInterface(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
// Test Open
recvFuncs, port, err := bind.Open(0)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
if len(recvFuncs) == 0 {
t.Error("Expected at least one receive function")
}
if port == 0 {
t.Error("Expected non-zero port")
}
// Test SetMark (should be a no-op)
if err := bind.SetMark(0); err != nil {
t.Errorf("SetMark failed: %v", err)
}
// Test BatchSize
batchSize := bind.BatchSize()
if batchSize <= 0 {
t.Error("Expected positive batch size")
}
}
// TestSharedBindSend tests the Send method with WireGuard endpoints
func TestSharedBindSend(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Create an endpoint
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
// Send data
testData := []byte("WireGuard packet")
bufs := [][]byte{testData}
err = senderBind.Send(bufs, endpoint)
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err := receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
func TestSharedBindMultipleUsers(t *testing.T) {
// Create shared bind
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
sharedBind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add reference for hole punch sender
sharedBind.AddRef()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
var wg sync.WaitGroup
// Simulate WireGuard using the bind
wg.Add(1)
go func() {
defer wg.Done()
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
for i := 0; i < 10; i++ {
data := []byte("WireGuard packet")
bufs := [][]byte{data}
if err := sharedBind.Send(bufs, endpoint); err != nil {
t.Errorf("WireGuard Send failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
// Simulate hole punch sender using the bind
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
data := []byte("Hole punch packet")
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
t.Errorf("Hole punch WriteToUDP failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
wg.Wait()
// Release the hole punch reference
sharedBind.Release()
// Close WireGuard's reference (should close the connection)
sharedBind.Close()
if !sharedBind.closed.Load() {
t.Error("Expected bind to be closed after all users released it")
}
}
// TestEndpoint tests the Endpoint implementation
func TestEndpoint(t *testing.T) {
addr := netip.MustParseAddr("192.168.1.1")
addrPort := netip.AddrPortFrom(addr, 51820)
ep := &Endpoint{AddrPort: addrPort}
// Test DstIP
if ep.DstIP() != addr {
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
}
// Test DstToString
expected := "192.168.1.1:51820"
if ep.DstToString() != expected {
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
}
// Test DstToBytes
bytes := ep.DstToBytes()
if len(bytes) == 0 {
t.Error("Expected DstToBytes to return non-empty slice")
}
// Test SrcIP (should be zero)
if ep.SrcIP().IsValid() {
t.Error("Expected SrcIP to be invalid")
}
// Test ClearSrc (should not panic)
ep.ClearSrc()
}
// TestParseEndpoint tests the ParseEndpoint method
func TestParseEndpoint(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
tests := []struct {
name string
input string
wantErr bool
checkAddr func(*testing.T, wgConn.Endpoint)
}{
{
name: "valid IPv4",
input: "192.168.1.1:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "192.168.1.1:51820" {
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
}
},
},
{
name: "valid IPv6",
input: "[::1]:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "[::1]:51820" {
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
}
},
},
{
name: "invalid - missing port",
input: "192.168.1.1",
wantErr: true,
},
{
name: "invalid - bad format",
input: "not-an-address",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ep, err := bind.ParseEndpoint(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && tt.checkAddr != nil {
tt.checkAddr(t, ep)
}
})
}
}

562
config.go Normal file
View File

@@ -0,0 +1,562 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"time"
)
// OlmConfig holds all configuration options for the Olm client
type OlmConfig struct {
// Connection settings
Endpoint string `json:"endpoint"`
ID string `json:"id"`
Secret string `json:"secret"`
OrgID string `json:"org"`
UserToken string `json:"userToken"`
// Network settings
MTU int `json:"mtu"`
DNS string `json:"dns"`
InterfaceName string `json:"interface"`
// Logging
LogLevel string `json:"logLevel"`
// HTTP server
EnableAPI bool `json:"enableApi"`
HTTPAddr string `json:"httpAddr"`
SocketPath string `json:"socketPath"`
// Ping settings
PingInterval string `json:"pingInterval"`
PingTimeout string `json:"pingTimeout"`
// Advanced
Holepunch bool `json:"holepunch"`
TlsClientCert string `json:"tlsClientCert"`
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
// Parsed values (not in JSON)
PingIntervalDuration time.Duration `json:"-"`
PingTimeoutDuration time.Duration `json:"-"`
// Source tracking (not in JSON)
sources map[string]string `json:"-"`
Version string
}
// ConfigSource tracks where each config value came from
type ConfigSource string
const (
SourceDefault ConfigSource = "default"
SourceFile ConfigSource = "file"
SourceEnv ConfigSource = "environment"
SourceCLI ConfigSource = "cli"
)
// DefaultConfig returns a config with default values
func DefaultConfig() *OlmConfig {
// Set OS-specific socket path
var socketPath string
switch runtime.GOOS {
case "windows":
socketPath = "olm"
default: // darwin, linux, and others
socketPath = "/var/run/olm.sock"
}
config := &OlmConfig{
MTU: 1280,
DNS: "8.8.8.8",
LogLevel: "INFO",
InterfaceName: "olm",
EnableAPI: false,
SocketPath: socketPath,
PingInterval: "3s",
PingTimeout: "5s",
Holepunch: false,
// DoNotCreateNewClient: false,
sources: make(map[string]string),
}
// Track default sources
config.sources["mtu"] = string(SourceDefault)
config.sources["dns"] = string(SourceDefault)
config.sources["logLevel"] = string(SourceDefault)
config.sources["interface"] = string(SourceDefault)
config.sources["enableApi"] = string(SourceDefault)
config.sources["httpAddr"] = string(SourceDefault)
config.sources["socketPath"] = string(SourceDefault)
config.sources["pingInterval"] = string(SourceDefault)
config.sources["pingTimeout"] = string(SourceDefault)
config.sources["holepunch"] = string(SourceDefault)
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
return config
}
// getOlmConfigPath returns the path to the olm config file
func getOlmConfigPath() string {
configFile := os.Getenv("CONFIG_FILE")
if configFile != "" {
return configFile
}
var configDir string
switch runtime.GOOS {
case "darwin":
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
case "windows":
configDir = filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "olm-client")
default: // linux and others
configDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
}
if err := os.MkdirAll(configDir, 0755); err != nil {
fmt.Printf("Warning: Failed to create config directory: %v\n", err)
}
return filepath.Join(configDir, "config.json")
}
// LoadConfig loads configuration from file, env vars, and CLI args
// Priority: CLI args > Env vars > Config file > Defaults
// Returns: (config, showVersion, showConfig, error)
func LoadConfig(args []string) (*OlmConfig, bool, bool, error) {
// Start with defaults
config := DefaultConfig()
// Load from config file (if exists)
fileConfig, err := loadConfigFromFile()
if err != nil {
return nil, false, false, fmt.Errorf("failed to load config file: %w", err)
}
if fileConfig != nil {
mergeConfigs(config, fileConfig)
}
// Override with environment variables
loadConfigFromEnv(config)
// Override with CLI arguments
showVersion, showConfig, err := loadConfigFromCLI(config, args)
if err != nil {
return nil, false, false, err
}
// Parse duration strings
if err := config.parseDurations(); err != nil {
return nil, false, false, err
}
return config, showVersion, showConfig, nil
}
// loadConfigFromFile loads configuration from the JSON config file
func loadConfigFromFile() (*OlmConfig, error) {
configPath := getOlmConfigPath()
data, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // File doesn't exist, not an error
}
return nil, err
}
var config OlmConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
return &config, nil
}
// loadConfigFromEnv loads configuration from environment variables
func loadConfigFromEnv(config *OlmConfig) {
if val := os.Getenv("PANGOLIN_ENDPOINT"); val != "" {
config.Endpoint = val
config.sources["endpoint"] = string(SourceEnv)
}
if val := os.Getenv("OLM_ID"); val != "" {
config.ID = val
config.sources["id"] = string(SourceEnv)
}
if val := os.Getenv("OLM_SECRET"); val != "" {
config.Secret = val
config.sources["secret"] = string(SourceEnv)
}
if val := os.Getenv("ORG"); val != "" {
config.OrgID = val
config.sources["org"] = string(SourceEnv)
}
if val := os.Getenv("USER_TOKEN"); val != "" {
config.UserToken = val
config.sources["userToken"] = string(SourceEnv)
}
if val := os.Getenv("MTU"); val != "" {
if mtu, err := strconv.Atoi(val); err == nil {
config.MTU = mtu
config.sources["mtu"] = string(SourceEnv)
} else {
fmt.Printf("Invalid MTU value: %s, keeping current value\n", val)
}
}
if val := os.Getenv("DNS"); val != "" {
config.DNS = val
config.sources["dns"] = string(SourceEnv)
}
if val := os.Getenv("LOG_LEVEL"); val != "" {
config.LogLevel = val
config.sources["logLevel"] = string(SourceEnv)
}
if val := os.Getenv("INTERFACE"); val != "" {
config.InterfaceName = val
config.sources["interface"] = string(SourceEnv)
}
if val := os.Getenv("HTTP_ADDR"); val != "" {
config.HTTPAddr = val
config.sources["httpAddr"] = string(SourceEnv)
}
if val := os.Getenv("PING_INTERVAL"); val != "" {
config.PingInterval = val
config.sources["pingInterval"] = string(SourceEnv)
}
if val := os.Getenv("PING_TIMEOUT"); val != "" {
config.PingTimeout = val
config.sources["pingTimeout"] = string(SourceEnv)
}
if val := os.Getenv("ENABLE_API"); val == "true" {
config.EnableAPI = true
config.sources["enableApi"] = string(SourceEnv)
}
if val := os.Getenv("SOCKET_PATH"); val != "" {
config.SocketPath = val
config.sources["socketPath"] = string(SourceEnv)
}
if val := os.Getenv("HOLEPUNCH"); val == "true" {
config.Holepunch = true
config.sources["holepunch"] = string(SourceEnv)
}
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
// config.DoNotCreateNewClient = true
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
// }
}
// loadConfigFromCLI loads configuration from command-line arguments
func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
// Store original values to detect changes
origValues := map[string]interface{}{
"endpoint": config.Endpoint,
"id": config.ID,
"secret": config.Secret,
"org": config.OrgID,
"userToken": config.UserToken,
"mtu": config.MTU,
"dns": config.DNS,
"logLevel": config.LogLevel,
"interface": config.InterfaceName,
"httpAddr": config.HTTPAddr,
"socketPath": config.SocketPath,
"pingInterval": config.PingInterval,
"pingTimeout": config.PingTimeout,
"enableApi": config.EnableAPI,
"holepunch": config.Holepunch,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}
// Define flags
serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server")
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID")
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
version := serviceFlags.Bool("version", false, "Print the version")
showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit")
// Parse the arguments
if err := serviceFlags.Parse(args); err != nil {
return false, false, err
}
// Track which values were changed by CLI args
if config.Endpoint != origValues["endpoint"].(string) {
config.sources["endpoint"] = string(SourceCLI)
}
if config.ID != origValues["id"].(string) {
config.sources["id"] = string(SourceCLI)
}
if config.Secret != origValues["secret"].(string) {
config.sources["secret"] = string(SourceCLI)
}
if config.OrgID != origValues["org"].(string) {
config.sources["org"] = string(SourceCLI)
}
if config.UserToken != origValues["userToken"].(string) {
config.sources["userToken"] = string(SourceCLI)
}
if config.MTU != origValues["mtu"].(int) {
config.sources["mtu"] = string(SourceCLI)
}
if config.DNS != origValues["dns"].(string) {
config.sources["dns"] = string(SourceCLI)
}
if config.LogLevel != origValues["logLevel"].(string) {
config.sources["logLevel"] = string(SourceCLI)
}
if config.InterfaceName != origValues["interface"].(string) {
config.sources["interface"] = string(SourceCLI)
}
if config.HTTPAddr != origValues["httpAddr"].(string) {
config.sources["httpAddr"] = string(SourceCLI)
}
if config.SocketPath != origValues["socketPath"].(string) {
config.sources["socketPath"] = string(SourceCLI)
}
if config.PingInterval != origValues["pingInterval"].(string) {
config.sources["pingInterval"] = string(SourceCLI)
}
if config.PingTimeout != origValues["pingTimeout"].(string) {
config.sources["pingTimeout"] = string(SourceCLI)
}
if config.EnableAPI != origValues["enableApi"].(bool) {
config.sources["enableApi"] = string(SourceCLI)
}
if config.Holepunch != origValues["holepunch"].(bool) {
config.sources["holepunch"] = string(SourceCLI)
}
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
// }
return *version, *showConfig, nil
}
// parseDurations parses the duration strings into time.Duration
func (c *OlmConfig) parseDurations() error {
var err error
// Parse ping interval
if c.PingInterval != "" {
c.PingIntervalDuration, err = time.ParseDuration(c.PingInterval)
if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", c.PingInterval)
c.PingIntervalDuration = 3 * time.Second
c.PingInterval = "3s"
}
} else {
c.PingIntervalDuration = 3 * time.Second
c.PingInterval = "3s"
}
// Parse ping timeout
if c.PingTimeout != "" {
c.PingTimeoutDuration, err = time.ParseDuration(c.PingTimeout)
if err != nil {
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", c.PingTimeout)
c.PingTimeoutDuration = 5 * time.Second
c.PingTimeout = "5s"
}
} else {
c.PingTimeoutDuration = 5 * time.Second
c.PingTimeout = "5s"
}
return nil
}
// mergeConfigs merges source config into destination (only non-empty values)
// Also tracks that these values came from a file
func mergeConfigs(dest, src *OlmConfig) {
if src.Endpoint != "" {
dest.Endpoint = src.Endpoint
dest.sources["endpoint"] = string(SourceFile)
}
if src.ID != "" {
dest.ID = src.ID
dest.sources["id"] = string(SourceFile)
}
if src.Secret != "" {
dest.Secret = src.Secret
dest.sources["secret"] = string(SourceFile)
}
if src.OrgID != "" {
dest.OrgID = src.OrgID
dest.sources["org"] = string(SourceFile)
}
if src.UserToken != "" {
dest.UserToken = src.UserToken
dest.sources["userToken"] = string(SourceFile)
}
if src.MTU != 0 && src.MTU != 1280 {
dest.MTU = src.MTU
dest.sources["mtu"] = string(SourceFile)
}
if src.DNS != "" && src.DNS != "8.8.8.8" {
dest.DNS = src.DNS
dest.sources["dns"] = string(SourceFile)
}
if src.LogLevel != "" && src.LogLevel != "INFO" {
dest.LogLevel = src.LogLevel
dest.sources["logLevel"] = string(SourceFile)
}
if src.InterfaceName != "" && src.InterfaceName != "olm" {
dest.InterfaceName = src.InterfaceName
dest.sources["interface"] = string(SourceFile)
}
if src.HTTPAddr != "" && src.HTTPAddr != ":9452" {
dest.HTTPAddr = src.HTTPAddr
dest.sources["httpAddr"] = string(SourceFile)
}
if src.SocketPath != "" {
// Check if it's not the default for any OS
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
if !isDefault {
dest.SocketPath = src.SocketPath
dest.sources["socketPath"] = string(SourceFile)
}
}
if src.PingInterval != "" && src.PingInterval != "3s" {
dest.PingInterval = src.PingInterval
dest.sources["pingInterval"] = string(SourceFile)
}
if src.PingTimeout != "" && src.PingTimeout != "5s" {
dest.PingTimeout = src.PingTimeout
dest.sources["pingTimeout"] = string(SourceFile)
}
if src.TlsClientCert != "" {
dest.TlsClientCert = src.TlsClientCert
dest.sources["tlsClientCert"] = string(SourceFile)
}
// For booleans, we always take the source value if explicitly set
if src.EnableAPI {
dest.EnableAPI = src.EnableAPI
dest.sources["enableApi"] = string(SourceFile)
}
if src.Holepunch {
dest.Holepunch = src.Holepunch
dest.sources["holepunch"] = string(SourceFile)
}
// if src.DoNotCreateNewClient {
// dest.DoNotCreateNewClient = src.DoNotCreateNewClient
// dest.sources["doNotCreateNewClient"] = string(SourceFile)
// }
}
// SaveConfig saves the current configuration to the config file
func SaveConfig(config *OlmConfig) error {
configPath := getOlmConfigPath()
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
return os.WriteFile(configPath, data, 0644)
}
// ShowConfig prints the configuration and the source of each value
func (c *OlmConfig) ShowConfig() {
configPath := getOlmConfigPath()
fmt.Println("\n=== Olm Configuration ===\n")
fmt.Printf("Config File: %s\n", configPath)
// Check if config file exists
if _, err := os.Stat(configPath); err == nil {
fmt.Printf("Config File Status: ✓ exists\n")
} else {
fmt.Printf("Config File Status: ✗ not found\n")
}
fmt.Println("\n--- Configuration Values ---")
fmt.Println("(Format: Setting = Value [source])\n")
// Helper to get source or default
getSource := func(key string) string {
if source, ok := c.sources[key]; ok {
return source
}
return string(SourceDefault)
}
// Helper to format value (mask secrets)
formatValue := func(key, value string) string {
if key == "secret" && value != "" {
if len(value) > 8 {
return value[:4] + "****" + value[len(value)-4:]
}
return "****"
}
if value == "" {
return "(not set)"
}
return value
}
// Connection settings
fmt.Println("Connection:")
fmt.Printf(" endpoint = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint"))
fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org"))
fmt.Printf(" user-token = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken"))
// Network settings
fmt.Println("\nNetwork:")
fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu"))
fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns"))
fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface"))
// Logging
fmt.Println("\nLogging:")
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
// API server
fmt.Println("\nAPI Server:")
fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
// Timing
fmt.Println("\nTiming:")
fmt.Printf(" ping-interval = %s [%s]\n", c.PingInterval, getSource("pingInterval"))
fmt.Printf(" ping-timeout = %s [%s]\n", c.PingTimeout, getSource("pingTimeout"))
// Advanced
fmt.Println("\nAdvanced:")
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
if c.TlsClientCert != "" {
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
}
// Source legend
fmt.Println("\n--- Source Legend ---")
fmt.Println(" default = Built-in default value")
fmt.Println(" file = Loaded from config file")
fmt.Println(" environment = Set via environment variable")
fmt.Println(" cli = Provided as command-line argument")
fmt.Println("\nPriority: cli > environment > file > default")
fmt.Println()
}

523
diff Normal file
View File

@@ -0,0 +1,523 @@
diff --git a/api/api.go b/api/api.go
index dd07751..0d2e4ef 100644
--- a/api/api.go
+++ b/api/api.go
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
Endpoint string `json:"endpoint"`
}
+// SwitchOrgRequest defines the structure for switching organizations
+type SwitchOrgRequest struct {
+ OrgID string `json:"orgId"`
+}
+
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
@@ -35,6 +40,7 @@ type StatusResponse struct {
Registered bool `json:"registered"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
+ OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
@@ -46,6 +52,7 @@ type API struct {
server *http.Server
connectionChan chan ConnectionRequest
shutdownChan chan struct{}
+ switchOrgChan chan SwitchOrgRequest
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
@@ -53,6 +60,7 @@ type API struct {
isRegistered bool
tunnelIP string
version string
+ orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -85,6 +95,7 @@ func (s *API) Start() error {
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/exit", s.handleExit)
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
s.server = &http.Server{
Handler: mux,
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
+ return s.switchOrgChan
+}
+
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
s.version = version
}
+// SetOrgID sets the org ID
+func (s *API) SetOrgID(orgID string) {
+ s.statusMu.Lock()
+ defer s.statusMu.Unlock()
+ s.orgID = orgID
+}
+
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
Registered: s.isRegistered,
TunnelIP: s.tunnelIP,
Version: s.version,
+ OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
}
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
"status": "shutdown initiated",
})
}
+
+// handleSwitchOrg handles the /switch-org endpoint
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var req SwitchOrgRequest
+ decoder := json.NewDecoder(r.Body)
+ if err := decoder.Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Validate required fields
+ if req.OrgID == "" {
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
+ return
+ }
+
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
+
+ // Send the request to the main goroutine
+ select {
+ case s.switchOrgChan <- req:
+ // Signal sent successfully
+ default:
+ // Channel already has a signal, don't block
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
+ return
+ }
+
+ // Return a success response
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusAccepted)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "org switch initiated",
+ "orgId": req.OrgID,
+ })
+}
diff --git a/olm/olm.go b/olm/olm.go
index 78080c4..5e292d6 100644
--- a/olm/olm.go
+++ b/olm/olm.go
@@ -58,6 +58,58 @@ type Config struct {
OrgID string
}
+// tunnelState holds all the active tunnel resources that need cleanup
+type tunnelState struct {
+ dev *device.Device
+ tdev tun.Device
+ uapiListener net.Listener
+ peerMonitor *peermonitor.PeerMonitor
+ stopRegister func()
+ connected bool
+}
+
+// teardownTunnel cleans up all tunnel resources
+func teardownTunnel(state *tunnelState) {
+ if state == nil {
+ return
+ }
+
+ logger.Info("Tearing down tunnel...")
+
+ // Stop registration messages
+ if state.stopRegister != nil {
+ state.stopRegister()
+ state.stopRegister = nil
+ }
+
+ // Stop peer monitor
+ if state.peerMonitor != nil {
+ state.peerMonitor.Stop()
+ state.peerMonitor = nil
+ }
+
+ // Close UAPI listener
+ if state.uapiListener != nil {
+ state.uapiListener.Close()
+ state.uapiListener = nil
+ }
+
+ // Close WireGuard device
+ if state.dev != nil {
+ state.dev.Close()
+ state.dev = nil
+ }
+
+ // Close TUN device
+ if state.tdev != nil {
+ state.tdev.Close()
+ state.tdev = nil
+ }
+
+ state.connected = false
+ logger.Info("Tunnel teardown complete")
+}
+
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key
- connected bool
- dev *device.Device
wgData WgData
holePunchData HolePunchData
- uapiListener net.Listener
- tdev tun.Device
+ orgID = config.OrgID
)
+ // Tunnel state that can be torn down and recreated
+ tunnel := &tunnelState{}
+
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
}
apiServer.SetVersion(config.Version)
+ apiServer.SetOrgID(orgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
- if connected {
+ if tunnel.connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
close(stopHolepunch)
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
- if dev != nil {
+ if tunnel.dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
- dev.Close()
+ tunnel.dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
return
}
- tdev, err = func() (tun.Device, error) {
+ tunnel.tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
return
}
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
return
}
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
go func() {
for {
- conn, err := uapiListener.Accept()
+ conn, err := tunnel.uapiListener.Accept()
if err != nil {
return
}
- go dev.IpcHandle(conn)
+ go tunnel.dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
- if err = dev.Up(); err != nil {
+ if err = tunnel.dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetTunnelIP(wgData.TunnelIP)
}
- peerMonitor = peermonitor.NewPeerMonitor(
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if apiServer != nil {
// Find the site config to get endpoint information
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
},
fixKey(privateKey.String()),
olm,
- dev,
+ tunnel.dev,
doHolepunch,
)
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
logger.Info("Configured peer %s", site.PublicKey)
}
- peerMonitor.Start()
+ tunnel.peerMonitor.Start()
if apiServer != nil {
apiServer.SetRegistered(true)
}
- connected = true
+ tunnel.connected = true
logger.Info("WireGuard device created.")
})
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
}
// Update the peer in WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
}
// Add the peer to WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
}
// Remove the peer from WireGuard
- if dev != nil {
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
+ if tunnel.dev != nil {
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetConnectionStatus(true)
}
- if connected {
+ if tunnel.connected {
logger.Debug("Already connected, skipping registration")
return nil
}
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": config.Version,
- "orgId": config.OrgID,
+ "orgId": orgID,
}, 1*time.Second)
go keepSendingPing(olm)
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
}
defer olm.Close()
+ // Listen for org switch requests from the API (after olm is created)
+ if apiServer != nil {
+ go func() {
+ for req := range apiServer.GetSwitchOrgChannel() {
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
+
+ // Update the orgId
+ orgID = req.OrgID
+
+ // Teardown existing tunnel
+ teardownTunnel(tunnel)
+
+ // Reset tunnel state
+ tunnel = &tunnelState{}
+
+ // Stop holepunch
+ select {
+ case <-stopHolepunch:
+ // Channel already closed
+ default:
+ close(stopHolepunch)
+ }
+ stopHolepunch = make(chan struct{})
+
+ // Clear API server state
+ apiServer.SetRegistered(false)
+ apiServer.SetTunnelIP("")
+ apiServer.SetOrgID(orgID)
+
+ // Send new registration message with updated orgId
+ publicKey := privateKey.PublicKey()
+ logger.Info("Sending registration message with new orgId: %s", orgID)
+
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ "publicKey": publicKey.String(),
+ "relay": !doHolepunch,
+ "olmVersion": config.Version,
+ "orgId": orgID,
+ }, 1*time.Second)
+ }
+ }()
+ }
+
select {
case <-ctx.Done():
logger.Info("Context cancelled")
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
close(stopHolepunch)
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
select {
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
close(stopPing)
}
- if peerMonitor != nil {
- peerMonitor.Stop()
- }
-
- if uapiListener != nil {
- uapiListener.Close()
- }
- if dev != nil {
- dev.Close()
- }
+ // Use teardownTunnel to clean up all tunnel resources
+ teardownTunnel(tunnel)
if apiServer != nil {
apiServer.Stop()

View File

@@ -1,9 +1,15 @@
services: services:
newt: olm:
image: fosrl/newt:latest image: fosrl/olm:latest
container_name: newt container_name: olm
restart: unless-stopped restart: unless-stopped
environment: environment:
- PANGOLIN_ENDPOINT=https://example.com - PANGOLIN_ENDPOINT=https://example.com
- NEWT_ID=2ix2t8xk22ubpfy - OLM_ID=vdqnz8rwgb95cnp
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2 - OLM_SECRET=1sw05qv1tkfdb1k81zpw05nahnnjvmhxjvf746umwagddmdg
cap_add:
- NET_ADMIN
- SYS_MODULE
devices:
- /dev/net/tun:/dev/net/tun
network_mode: host

View File

@@ -4,7 +4,7 @@ set -e
# first arg is `-f` or `--some-option` # first arg is `-f` or `--some-option`
if [ "${1#-}" != "$1" ]; then if [ "${1#-}" != "$1" ]; then
set -- newt "$@" set -- olm "$@"
fi fi
exec "$@" exec "$@"

279
get-olm.sh Normal file
View File

@@ -0,0 +1,279 @@
#!/bin/bash
# Get Olm - Cross-platform installation script
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/olm/refs/heads/main/get-olm.sh | bash
set -e
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# GitHub repository info
REPO="fosrl/olm"
GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
# Function to print colored output
print_status() {
echo -e "${GREEN}[INFO]${NC} $1"
}
print_warning() {
echo -e "${YELLOW}[WARN]${NC} $1"
}
print_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
# Function to get latest version from GitHub API
get_latest_version() {
local latest_info
if command -v curl >/dev/null 2>&1; then
latest_info=$(curl -fsSL "$GITHUB_API_URL" 2>/dev/null)
elif command -v wget >/dev/null 2>&1; then
latest_info=$(wget -qO- "$GITHUB_API_URL" 2>/dev/null)
else
print_error "Neither curl nor wget is available. Please install one of them." >&2
exit 1
fi
if [ -z "$latest_info" ]; then
print_error "Failed to fetch latest version information" >&2
exit 1
fi
# Extract version from JSON response (works without jq)
local version=$(echo "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/')
if [ -z "$version" ]; then
print_error "Could not parse version from GitHub API response" >&2
exit 1
fi
# Remove 'v' prefix if present
version=$(echo "$version" | sed 's/^v//')
echo "$version"
}
# Detect OS and architecture
detect_platform() {
local os arch
# Detect OS
case "$(uname -s)" in
Linux*) os="linux" ;;
Darwin*) os="darwin" ;;
MINGW*|MSYS*|CYGWIN*) os="windows" ;;
FreeBSD*) os="freebsd" ;;
*)
print_error "Unsupported operating system: $(uname -s)"
exit 1
;;
esac
# Detect architecture
case "$(uname -m)" in
x86_64|amd64) arch="amd64" ;;
arm64|aarch64) arch="arm64" ;;
armv7l|armv6l)
if [ "$os" = "linux" ]; then
if [ "$(uname -m)" = "armv6l" ]; then
arch="arm32v6"
else
arch="arm32"
fi
else
arch="arm64" # Default for non-Linux ARM
fi
;;
riscv64)
if [ "$os" = "linux" ]; then
arch="riscv64"
else
print_error "RISC-V architecture only supported on Linux"
exit 1
fi
;;
*)
print_error "Unsupported architecture: $(uname -m)"
exit 1
;;
esac
echo "${os}_${arch}"
}
# Get installation directory
get_install_dir() {
local platform="$1"
if [[ "$platform" == *"windows"* ]]; then
echo "$HOME/bin"
else
# For Unix-like systems, prioritize system-wide directories for sudo access
# Check in order of preference: /usr/local/bin, /usr/bin, ~/.local/bin
if [ -d "/usr/local/bin" ]; then
echo "/usr/local/bin"
elif [ -d "/usr/bin" ]; then
echo "/usr/bin"
else
# Fallback to user directory if system directories don't exist
echo "$HOME/.local/bin"
fi
fi
}
# Check if we need sudo for installation
need_sudo() {
local install_dir="$1"
# If installing to system directory and we don't have write permission, need sudo
if [[ "$install_dir" == "/usr/local/bin" || "$install_dir" == "/usr/bin" ]]; then
if [ ! -w "$install_dir" ] 2>/dev/null; then
return 0 # Need sudo
fi
fi
return 1 # Don't need sudo
}
# Download and install olm
install_olm() {
local platform="$1"
local install_dir="$2"
local binary_name="olm_${platform}"
local exe_suffix=""
# Add .exe suffix for Windows
if [[ "$platform" == *"windows"* ]]; then
binary_name="${binary_name}.exe"
exe_suffix=".exe"
fi
local download_url="${BASE_URL}/${binary_name}"
local temp_file="/tmp/olm${exe_suffix}"
local final_path="${install_dir}/olm${exe_suffix}"
print_status "Downloading olm from ${download_url}"
# Download the binary
if command -v curl >/dev/null 2>&1; then
curl -fsSL "$download_url" -o "$temp_file"
elif command -v wget >/dev/null 2>&1; then
wget -q "$download_url" -O "$temp_file"
else
print_error "Neither curl nor wget is available. Please install one of them."
exit 1
fi
# Check if we need sudo for installation
local use_sudo=""
if need_sudo "$install_dir"; then
print_status "Administrator privileges required for system-wide installation"
if command -v sudo >/dev/null 2>&1; then
use_sudo="sudo"
else
print_error "sudo is required for system-wide installation but not available"
exit 1
fi
fi
# Create install directory if it doesn't exist
if [ -n "$use_sudo" ]; then
$use_sudo mkdir -p "$install_dir"
else
mkdir -p "$install_dir"
fi
# Move binary to install directory
if [ -n "$use_sudo" ]; then
$use_sudo mv "$temp_file" "$final_path"
$use_sudo chmod +x "$final_path"
else
mv "$temp_file" "$final_path"
chmod +x "$final_path"
fi
print_status "olm installed to ${final_path}"
# Check if install directory is in PATH (only warn for non-system directories)
if [[ "$install_dir" != "/usr/local/bin" && "$install_dir" != "/usr/bin" ]]; then
if ! echo "$PATH" | grep -q "$install_dir"; then
print_warning "Install directory ${install_dir} is not in your PATH."
print_warning "Add it to your PATH by adding this line to your shell profile:"
print_warning " export PATH=\"${install_dir}:\$PATH\""
fi
fi
}
# Verify installation
verify_installation() {
local install_dir="$1"
local exe_suffix=""
if [[ "$PLATFORM" == *"windows"* ]]; then
exe_suffix=".exe"
fi
local olm_path="${install_dir}/olm${exe_suffix}"
if [ -f "$olm_path" ] && [ -x "$olm_path" ]; then
print_status "Installation successful!"
print_status "olm version: $("$olm_path" --version 2>/dev/null || echo "unknown")"
return 0
else
print_error "Installation failed. Binary not found or not executable."
return 1
fi
}
# Main installation process
main() {
print_status "Installing latest version of olm..."
# Get latest version
print_status "Fetching latest version from GitHub..."
VERSION=$(get_latest_version)
print_status "Latest version: v${VERSION}"
# Set base URL with the fetched version
BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"
# Detect platform
PLATFORM=$(detect_platform)
print_status "Detected platform: ${PLATFORM}"
# Get install directory
INSTALL_DIR=$(get_install_dir "$PLATFORM")
print_status "Install directory: ${INSTALL_DIR}"
# Inform user about system-wide installation
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
print_status "Installing system-wide for sudo access"
fi
# Install olm
install_olm "$PLATFORM" "$INSTALL_DIR"
# Verify installation
if verify_installation "$INSTALL_DIR"; then
print_status "olm is ready to use!"
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
print_status "olm is installed system-wide and accessible via sudo"
fi
if [[ "$PLATFORM" == *"windows"* ]]; then
print_status "Run 'olm --help' to get started"
else
print_status "Run 'olm --help' or 'sudo olm --help' to get started"
fi
else
exit 1
fi
}
# Run main function
main "$@"

33
go.mod
View File

@@ -1,19 +1,22 @@
module github.com/fosrl/newt module github.com/fosrl/olm
go 1.23.1 go 1.25
toolchain go1.23.2
require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
require ( require (
github.com/google/btree v1.1.2 // indirect github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7
github.com/gorilla/websocket v1.5.3 // indirect github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.28.0 // indirect golang.org/x/crypto v0.43.0
golang.org/x/net v0.30.0 // indirect golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
golang.org/x/sys v0.26.0 // indirect golang.org/x/sys v0.37.0
golang.org/x/time v0.7.0 // indirect golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect )
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
require (
github.com/gorilla/websocket v1.5.3 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/net v0.45.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect
) )

46
go.sum
View File

@@ -1,20 +1,34 @@
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs=
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

351
holepunch/holepunch.go Normal file
View File

@@ -0,0 +1,351 @@
package holepunch
import (
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/bind"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DomainResolver is a function type for resolving domains to IP addresses
type DomainResolver func(string) (string, error)
// ExitNode represents a WireGuard exit node for hole punching
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
// Manager handles UDP hole punching operations
type Manager struct {
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
olmID string
token string
domainResolver DomainResolver
}
// NewManager creates a new hole punch manager
func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager {
return &Manager{
sharedBind: sharedBind,
olmID: olmID,
domainResolver: domainResolver,
}
}
// SetToken updates the authentication token used for hole punching
func (m *Manager) SetToken(token string) {
m.mu.Lock()
defer m.mu.Unlock()
m.token = token
}
// IsRunning returns whether hole punching is currently active
func (m *Manager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.running
}
// Stop stops any ongoing hole punch operations
func (m *Manager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
if m.stopChan != nil {
close(m.stopChan)
m.stopChan = nil
}
m.running = false
logger.Info("Hole punch manager stopped")
}
// StartMultipleExitNodes starts hole punching to multiple exit nodes
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
if len(exitNodes) == 0 {
m.mu.Unlock()
logger.Warn("No exit nodes provided for hole punching")
return fmt.Errorf("no exit nodes provided")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
go m.runMultipleExitNodes(exitNodes)
return nil
}
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
go m.runSingleEndpoint(endpoint, serverPubKey)
return nil
}
// runMultipleExitNodes performs hole punching to multiple exit nodes
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for all exit nodes")
}()
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := m.domainResolver(exitNode.Endpoint)
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
// runSingleEndpoint performs hole punching to a single endpoint
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
}()
host, err := m.domainResolver(endpoint)
if err != nil {
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
return
}
// Execute once immediately before starting the loop
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Warn("Failed to send initial hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Debug("Failed to send hole punch: %v", err)
}
}
}
}
// sendHolePunch sends an encrypted hole punch packet using the shared bind
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
m.mu.Lock()
token := m.token
olmID := m.olmID
m.mu.Unlock()
if serverPubKey == "" || token == "" {
return fmt.Errorf("server public key or OLM token is empty")
}
payload := struct {
OlmID string `json:"olmId"`
Token string `json:"token"`
}{
OlmID: olmID,
Token: token,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %w", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
}
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to write to UDP: %w", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}

View File

@@ -1,27 +0,0 @@
package logger
type LogLevel int
const (
DEBUG LogLevel = iota
INFO
WARN
ERROR
FATAL
)
var levelStrings = map[LogLevel]string{
DEBUG: "DEBUG",
INFO: "INFO",
WARN: "WARN",
ERROR: "ERROR",
FATAL: "FATAL",
}
// String returns the string representation of the log level
func (l LogLevel) String() string {
if s, ok := levelStrings[l]; ok {
return s
}
return "UNKNOWN"
}

View File

@@ -1,106 +0,0 @@
package logger
import (
"fmt"
"log"
"os"
"sync"
"time"
)
// Logger struct holds the logger instance
type Logger struct {
logger *log.Logger
level LogLevel
}
var (
defaultLogger *Logger
once sync.Once
)
// NewLogger creates a new logger instance
func NewLogger() *Logger {
return &Logger{
logger: log.New(os.Stdout, "", 0),
level: DEBUG,
}
}
// Init initializes the default logger
func Init() *Logger {
once.Do(func() {
defaultLogger = NewLogger()
})
return defaultLogger
}
// GetLogger returns the default logger instance
func GetLogger() *Logger {
if defaultLogger == nil {
Init()
}
return defaultLogger
}
// SetLevel sets the minimum logging level
func (l *Logger) SetLevel(level LogLevel) {
l.level = level
}
// log handles the actual logging
func (l *Logger) log(level LogLevel, format string, args ...interface{}) {
if level < l.level {
return
}
timestamp := time.Now().Format("2006/01/02 15:04:05")
message := fmt.Sprintf(format, args...)
l.logger.Printf("%s: %s %s", level.String(), timestamp, message)
}
// Debug logs debug level messages
func (l *Logger) Debug(format string, args ...interface{}) {
l.log(DEBUG, format, args...)
}
// Info logs info level messages
func (l *Logger) Info(format string, args ...interface{}) {
l.log(INFO, format, args...)
}
// Warn logs warning level messages
func (l *Logger) Warn(format string, args ...interface{}) {
l.log(WARN, format, args...)
}
// Error logs error level messages
func (l *Logger) Error(format string, args ...interface{}) {
l.log(ERROR, format, args...)
}
// Fatal logs fatal level messages and exits
func (l *Logger) Fatal(format string, args ...interface{}) {
l.log(FATAL, format, args...)
os.Exit(1)
}
// Global helper functions
func Debug(format string, args ...interface{}) {
GetLogger().Debug(format, args...)
}
func Info(format string, args ...interface{}) {
GetLogger().Info(format, args...)
}
func Warn(format string, args ...interface{}) {
GetLogger().Warn(format, args...)
}
func Error(format string, args ...interface{}) {
GetLogger().Error(format, args...)
}
func Fatal(format string, args ...interface{}) {
GetLogger().Fatal(format, args...)
}

765
main.go
View File

@@ -1,603 +1,220 @@
package main package main
import ( import (
"bytes" "context"
"encoding/base64"
"encoding/hex"
"encoding/json"
"flag"
"fmt" "fmt"
"math/rand"
"net"
"net/netip"
"os" "os"
"os/signal" "os/signal"
"strings" "runtime"
"syscall" "syscall"
"time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy" "github.com/fosrl/olm/olm"
"github.com/fosrl/newt/websocket"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type WgData struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
TunnelIP string `json:"tunnelIP"`
Targets TargetsByType `json:"targets"`
}
type TargetsByType struct {
UDP []string `json:"udp"`
TCP []string `json:"tcp"`
}
type TargetData struct {
Targets []string `json:"targets"`
}
func fixKey(key string) string {
// Remove any whitespace
key = strings.TrimSpace(key)
// Decode from base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
logger.Fatal("Error decoding base64:", err)
}
// Convert to hex
return hex.EncodeToString(decoded)
}
func ping(tnet *netstack.Net, dst string) error {
logger.Info("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst)
if err != nil {
return fmt.Errorf("failed to create ICMP socket: %w", err)
}
defer socket.Close()
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"),
}
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
if err != nil {
return fmt.Errorf("failed to marshal ICMP message: %w", err)
}
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
return fmt.Errorf("failed to set read deadline: %w", err)
}
start := time.Now()
_, err = socket.Write(icmpBytes)
if err != nil {
return fmt.Errorf("failed to write ICMP packet: %w", err)
}
n, err := socket.Read(icmpBytes[:])
if err != nil {
return fmt.Errorf("failed to read ICMP packet: %w", err)
}
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
if err != nil {
return fmt.Errorf("failed to parse ICMP packet: %w", err)
}
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
}
logger.Info("Ping latency: %v", time.Since(start))
return nil
}
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
go func() {
for {
select {
case <-ticker.C:
err := ping(tnet, serverIP)
if err != nil {
logger.Warn("Periodic ping failed: %v", err)
}
case <-stopChan:
logger.Info("Stopping ping check")
return
}
}
}()
}
func pingWithRetry(tnet *netstack.Net, dst string) error {
const (
maxAttempts = 5
retryDelay = 2 * time.Second
)
var lastErr error
for attempt := 1; attempt <= maxAttempts; attempt++ {
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
if err := ping(tnet, dst); err != nil {
lastErr = err
logger.Warn("Ping attempt %d failed: %v", attempt, err)
if attempt < maxAttempts {
time.Sleep(retryDelay)
continue
}
return fmt.Errorf("all ping attempts failed after %d tries, last error: %w",
maxAttempts, lastErr)
}
// Successful ping
return nil
}
// This shouldn't be reached due to the return in the loop, but added for completeness
return fmt.Errorf("unexpected error: all ping attempts failed")
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
return logger.DEBUG
case "INFO":
return logger.INFO
case "WARN":
return logger.WARN
case "ERROR":
return logger.ERROR
case "FATAL":
return logger.FATAL
default:
return logger.INFO // default to INFO if invalid level provided
}
}
func mapToWireGuardLogLevel(level logger.LogLevel) int {
switch level {
case logger.DEBUG:
return device.LogLevelVerbose
// case logger.INFO:
// return device.LogLevel
case logger.WARN:
return device.LogLevelError
case logger.ERROR, logger.FATAL:
return device.LogLevelSilent
default:
return device.LogLevelSilent
}
}
func resolveDomain(domain string) (string, error) {
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Remove any protocol prefix if present
if strings.HasPrefix(host, "http://") {
host = strings.TrimPrefix(host, "http://")
} else if strings.HasPrefix(host, "https://") {
host = strings.TrimPrefix(host, "https://")
}
// Lookup IP addresses
ips, err := net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
func main() { func main() {
var ( // Check if we're running as a Windows service
endpoint string if isWindowsService() {
id string runService("OlmWireguardService", false, os.Args[1:])
secret string fmt.Println("Running as Windows service")
dns string return
privateKey wgtypes.Key
err error
logLevel string
)
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
id = os.Getenv("NEWT_ID")
secret = os.Getenv("NEWT_SECRET")
dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
}
if id == "" {
flag.StringVar(&id, "id", "", "Newt ID")
}
if secret == "" {
flag.StringVar(&secret, "secret", "", "Newt secret")
}
if dns == "" {
flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
}
if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
flag.Parse()
logger.Init()
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
// Validate required fields
if endpoint == "" || id == "" || secret == "" {
logger.Fatal("endpoint, id, and secret are required either via CLI flags or environment variables")
} }
privateKey, err = wgtypes.GeneratePrivateKey() // Handle service management commands on Windows
if err != nil { if runtime.GOOS == "windows" {
logger.Fatal("Failed to generate private key: %v", err) var command string
} if len(os.Args) > 1 {
command = os.Args[1]
// Create a new client } else {
client, err := websocket.NewClient( command = "default"
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
)
if err != nil {
logger.Fatal("Failed to create client: %v", err)
}
// Create TUN device and network stack
var tun tun.Device
var tnet *netstack.Net
var dev *device.Device
var pm *proxy.ProxyManager
var connected bool
var wgData WgData
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
if pm != nil {
pm.Stop()
} }
if dev != nil {
dev.Close()
}
client.Close()
})
pingStopChan := make(chan struct{}) switch command {
defer close(pingStopChan) case "install":
err := installService()
// Register handlers for different message types
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
logger.Info("Received registration message")
if connected {
logger.Info("Already connected! But I will send a ping anyway...")
// ping(tnet, wgData.ServerIP)
err = pingWithRetry(tnet, wgData.ServerIP)
if err != nil { if err != nil {
// Handle complete failure after all retries fmt.Printf("Failed to install service: %v\n", err)
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) os.Exit(1)
} }
fmt.Println("Service installed successfully")
return return
} case "remove", "uninstall":
err := removeService()
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
logger.Info("Received: %+v", msg)
tun, tnet, err = netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
[]netip.Addr{netip.MustParseAddr(dns)},
1420)
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
}
// Create WireGuard device
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
mapToWireGuardLogLevel(loggerLevel),
"wireguard: ",
))
endpoint, err := resolveDomain(wgData.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
return
}
// Configure WireGuard
config := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s
persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
err = dev.IpcSet(config)
if err != nil {
logger.Error("Failed to configure WireGuard device: %v", err)
}
// Bring up the device
err = dev.Up()
if err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
logger.Info("WireGuard device created. Lets ping the server now...")
// Ping to bring the tunnel up on the server side quickly
// ping(tnet, wgData.ServerIP)
err = pingWithRetry(tnet, wgData.ServerIP)
if err != nil {
// Handle complete failure after all retries
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
}
if !connected {
logger.Info("Starting ping check")
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
}
// Create proxy manager
pm = proxy.NewProxyManager(tnet)
connected = true
// add the targets if there are any
if len(wgData.Targets.TCP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
}
if len(wgData.Targets.UDP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
}
err = pm.Start()
if err != nil {
logger.Error("Failed to start proxy manager: %v", err)
}
})
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
}
err = pm.Start()
if err != nil {
logger.Error("Failed to start proxy manager: %v", err)
}
})
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
}
err = pm.Start()
if err != nil {
logger.Error("Failed to start proxy manager: %v", err)
}
})
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
}
})
client.OnConnect(func() error {
publicKey := privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey)
err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", publicKey),
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent registration message")
return nil
})
// Connect to the WebSocket server
if err := client.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)
}
defer client.Close()
// Wait for interrupt signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
// Cleanup
dev.Close()
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return targetData, err
}
if err := json.Unmarshal(jsonData, &targetData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
logger.Info("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1] + ":" + parts[2]
// Only remove the specific target if it exists
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil { if err != nil {
// Ignore "target not found" errors as this is expected for new targets fmt.Printf("Failed to remove service: %v\n", err)
if !strings.Contains(err.Error(), "target not found") { os.Exit(1)
logger.Error("Failed to remove existing target: %v", err) }
fmt.Println("Service removed successfully")
return
case "start":
// Pass the remaining arguments (after "start") to the service
serviceArgs := os.Args[2:]
err := startService(serviceArgs)
if err != nil {
fmt.Printf("Failed to start service: %v\n", err)
os.Exit(1)
}
fmt.Println("Service started successfully")
return
case "stop":
err := stopService()
if err != nil {
fmt.Printf("Failed to stop service: %v\n", err)
os.Exit(1)
}
fmt.Println("Service stopped successfully")
return
case "status":
status, err := getServiceStatus()
if err != nil {
fmt.Printf("Failed to get service status: %v\n", err)
os.Exit(1)
}
fmt.Printf("Service status: %s\n", status)
return
case "debug":
// get the status and if it is Not Installed then install it first
status, err := getServiceStatus()
if err != nil {
fmt.Printf("Failed to get service status: %v\n", err)
os.Exit(1)
}
if status == "Not Installed" {
err := installService()
if err != nil {
fmt.Printf("Failed to install service: %v\n", err)
os.Exit(1)
} }
fmt.Println("Service installed successfully, now running in debug mode")
} }
// Add the new target // Pass the remaining arguments (after "debug") to the service
pm.AddTarget(proto, tunnelIP, port, target) serviceArgs := os.Args[2:]
err = debugService(serviceArgs)
} else if action == "remove" {
logger.Info("Removing target with port %d", port)
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil { if err != nil {
logger.Error("Failed to remove target: %v", err) fmt.Printf("Failed to debug service: %v\n", err)
return err os.Exit(1)
} }
return
case "logs":
err := watchLogFile(false)
if err != nil {
fmt.Printf("Failed to watch log file: %v\n", err)
os.Exit(1)
}
return
case "config":
if runtime.GOOS == "windows" {
showServiceConfig()
} else {
fmt.Println("Service configuration is only available on Windows")
}
return
case "help", "--help", "-h":
fmt.Println("Olm WireGuard VPN Client")
fmt.Println("\nWindows Service Management:")
fmt.Println(" install Install the service")
fmt.Println(" remove Remove the service")
fmt.Println(" start [args] Start the service with optional arguments")
fmt.Println(" stop Stop the service")
fmt.Println(" status Show service status")
fmt.Println(" debug [args] Run service in debug mode with optional arguments")
fmt.Println(" logs Tail the service log file")
fmt.Println(" config Show current service configuration")
fmt.Println("\nExamples:")
fmt.Println(" olm start --enable-http --http-addr :9452")
fmt.Println(" olm debug --endpoint https://example.com --id myid --secret mysecret")
fmt.Println("\nFor console mode, run without arguments or with standard flags.")
return
default:
// get the status and if it is Not Installed then install it first
status, err := getServiceStatus()
if err != nil {
fmt.Printf("Failed to get service status: %v\n", err)
os.Exit(1)
}
if status == "Not Installed" {
err := installService()
if err != nil {
fmt.Printf("Failed to install service: %v\n", err)
os.Exit(1)
}
fmt.Println("Service installed successfully, now running")
}
// Pass the remaining arguments (after "debug") to the service
serviceArgs := os.Args[1:]
err = debugService(serviceArgs)
if err != nil {
fmt.Printf("Failed to debug service: %v\n", err)
os.Exit(1)
}
return
} }
} }
return nil // Setup Windows event logging if on Windows
if runtime.GOOS != "windows" {
setupWindowsEventLog()
} else {
// Initialize logger for non-Windows platforms
logger.Init()
}
// Load configuration from file, env vars, and CLI args
// Priority: CLI args > Env vars > Config file > Defaults
config, showVersion, showConfig, err := LoadConfig(os.Args[1:])
if err != nil {
fmt.Printf("Failed to load configuration: %v\n", err)
return
}
// Handle --show-config flag
if showConfig {
config.ShowConfig()
os.Exit(0)
}
olmVersion := "version_replaceme"
if showVersion {
fmt.Println("Olm version " + olmVersion)
os.Exit(0)
}
logger.Info("Olm version " + olmVersion)
config.Version = olmVersion
if err := SaveConfig(config); err != nil {
logger.Error("Failed to save full olm config: %v", err)
} else {
logger.Debug("Saved full olm config with all options")
}
// Create a new olm.Config struct and copy values from the main config
olmConfig := olm.Config{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,
UserToken: config.UserToken,
MTU: config.MTU,
DNS: config.DNS,
InterfaceName: config.InterfaceName,
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
SocketPath: config.SocketPath,
Holepunch: config.Holepunch,
TlsClientCert: config.TlsClientCert,
PingIntervalDuration: config.PingIntervalDuration,
PingTimeoutDuration: config.PingTimeoutDuration,
Version: config.Version,
OrgID: config.OrgID,
// DoNotCreateNewClient: config.DoNotCreateNewClient,
}
// Create a context that will be cancelled on interrupt signals
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
olm.Run(ctx, olmConfig)
} }

View File

@@ -0,0 +1 @@
573df1772c00fcb34ec68e575e973c460dc27ba8

1
olm-test.REMOVED.git-id Normal file
View File

@@ -0,0 +1 @@
ba2c118fd96937229ef54dcd0b82fe5d53d94a87

88
olm.iss Normal file
View File

@@ -0,0 +1,88 @@
; Script generated by the Inno Setup Script Wizard.
; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES!
#define MyAppName "olm"
#define MyAppVersion "1.0.0"
#define MyAppPublisher "Fossorial Inc."
#define MyAppURL "https://pangolin.net"
#define MyAppExeName "olm.exe"
[Setup]
; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications.
; (To generate a new GUID, click Tools | Generate GUID inside the IDE.)
AppId={{44A24E4C-B616-476F-ADE7-8D56B930959E}
AppName={#MyAppName}
AppVersion={#MyAppVersion}
;AppVerName={#MyAppName} {#MyAppVersion}
AppPublisher={#MyAppPublisher}
AppPublisherURL={#MyAppURL}
AppSupportURL={#MyAppURL}
AppUpdatesURL={#MyAppURL}
DefaultDirName={autopf}\{#MyAppName}
UninstallDisplayIcon={app}\{#MyAppExeName}
; "ArchitecturesAllowed=x64compatible" specifies that Setup cannot run
; on anything but x64 and Windows 11 on Arm.
ArchitecturesAllowed=x64compatible
; "ArchitecturesInstallIn64BitMode=x64compatible" requests that the
; install be done in "64-bit mode" on x64 or Windows 11 on Arm,
; meaning it should use the native 64-bit Program Files directory and
; the 64-bit view of the registry.
ArchitecturesInstallIn64BitMode=x64compatible
DefaultGroupName={#MyAppName}
DisableProgramGroupPage=yes
; Uncomment the following line to run in non administrative install mode (install for current user only).
;PrivilegesRequired=lowest
OutputBaseFilename=mysetup
SolidCompression=yes
WizardStyle=modern
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
RestartIfNeededByRun=no
ChangesEnvironment=true
[Languages]
Name: "english"; MessagesFile: "compiler:Default.isl"
[Files]
; The 'DestName' flag ensures that 'olm_windows_amd64.exe' is installed as 'olm.exe'
Source: "C:\Users\Administrator\Downloads\olm_windows_amd64.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}"; Flags: ignoreversion
Source: "C:\Users\Administrator\Downloads\wintun.dll"; DestDir: "{app}"; Flags: ignoreversion
; NOTE: Don't use "Flags: ignoreversion" on any shared system files
[Icons]
Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"
[Registry]
; Add the application's installation directory to the system PATH environment variable.
; HKLM (HKEY_LOCAL_MACHINE) is used for system-wide changes.
; The 'Path' variable is located under 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'.
; ValueType: expandsz allows for environment variables (like %ProgramFiles%) in the path.
; ValueData: "{olddata};{app}" appends the current application directory to the existing PATH.
; Flags: uninsdeletevalue ensures the entry is removed upon uninstallation.
; Check: IsWin64 ensures this is applied on 64-bit systems, which matches ArchitecturesAllowed.
[Registry]
; Add the application's installation directory to the system PATH.
Root: HKLM; Subkey: "SYSTEM\CurrentControlSet\Control\Session Manager\Environment"; \
ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \
Flags: uninsdeletevalue; Check: NeedsAddPath(ExpandConstant('{app}'))
[Code]
function NeedsAddPath(Path: string): boolean;
var
OrigPath: string;
begin
if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
'Path', OrigPath)
then begin
// Path variable doesn't exist at all, so we definitely need to add it.
Result := True;
exit;
end;
// Perform a case-insensitive check to see if the path is already present.
// We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
Result := False
else
Result := True;
end;

885
olm/common.go Normal file
View File

@@ -0,0 +1,885 @@
package olm
import (
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"os/exec"
"regexp"
"runtime"
"strconv"
"strings"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"github.com/vishvananda/netlink"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WgData struct {
Sites []SiteConfig `json:"sites"`
TunnelIP string `json:"tunnelIP"`
}
type SiteConfig struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
ServerPort uint16 `json:"serverPort"`
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
}
type TargetsByType struct {
UDP []string `json:"udp"`
TCP []string `json:"tcp"`
}
type TargetData struct {
Targets []string `json:"targets"`
}
type HolePunchMessage struct {
NewtID string `json:"newtId"`
}
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
type HolePunchData struct {
ExitNodes []ExitNode `json:"exitNodes"`
}
type EncryptedHolePunchMessage struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}
var (
peerMonitor *peermonitor.PeerMonitor
stopHolepunch chan struct{}
stopRegister func()
stopPing chan struct{}
olmToken string
holePunchRunning bool
)
const (
ENV_WG_TUN_FD = "WG_TUN_FD"
ENV_WG_UAPI_FD = "WG_UAPI_FD"
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
)
// PeerAction represents a request to add, update, or remove a peer
type PeerAction struct {
Action string `json:"action"` // "add", "update", or "remove"
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
}
// UpdatePeerData represents the data needed to update a peer
type UpdatePeerData struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
ServerPort uint16 `json:"serverPort"`
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
}
// AddPeerData represents the data needed to add a peer
type AddPeerData struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
ServerPort uint16 `json:"serverPort"`
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
}
// RemovePeerData represents the data needed to remove a peer
type RemovePeerData struct {
SiteId int `json:"siteId"`
}
type RelayPeerData struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
return ""
}
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint // Already valid, no change needed
}
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
lastColon := strings.LastIndex(endpoint, ":")
if lastColon > 0 { // Ensure there is a colon and it's not the first character
hostPart := endpoint[:lastColon]
// Check if the host part is a literal IPv6 address
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
// It is! Reformat it with brackets.
portPart := endpoint[lastColon+1:]
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
}
}
// If it's not the specific malformed case, return it as is.
return endpoint
}
func fixKey(key string) string {
// Remove any whitespace
key = strings.TrimSpace(key)
// Decode from base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
logger.Fatal("Error decoding base64")
}
// Convert to hex
return hex.EncodeToString(decoded)
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
return logger.DEBUG
case "INFO":
return logger.INFO
case "WARN":
return logger.WARN
case "ERROR":
return logger.ERROR
case "FATAL":
return logger.FATAL
default:
return logger.INFO // default to INFO if invalid level provided
}
}
func mapToWireGuardLogLevel(level logger.LogLevel) int {
switch level {
case logger.DEBUG:
return device.LogLevelVerbose
// case logger.INFO:
// return device.LogLevel
case logger.WARN:
return device.LogLevelError
case logger.ERROR, logger.FATAL:
return device.LogLevelSilent
default:
return device.LogLevelSilent
}
}
func ResolveDomain(domain string) (string, error) {
// First handle any protocol prefix
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
// if there are any trailing slashes, remove them
domain = strings.TrimSuffix(domain, "/")
// Now split host and port
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Lookup IP addresses
ips, err := net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
}
// Create a slice of all ports in the range
portRange := make([]uint16, maxPort-minPort+1)
for i := range portRange {
portRange[i] = minPort + uint16(i)
}
// Fisher-Yates shuffle to randomize the port order
rand.Seed(uint64(time.Now().UnixNano()))
for i := len(portRange) - 1; i > 0; i-- {
j := rand.Intn(i + 1)
portRange[i], portRange[j] = portRange[j], portRange[i]
}
// Try each port in the randomized order
for _, port := range portRange {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: int(port),
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
continue // Port is in use or there was an error, try next port
}
_ = conn.SetDeadline(time.Now())
conn.Close()
return port, nil
}
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
}
func sendPing(olm *websocket.Client) error {
err := olm.SendMessage("olm/ping", map[string]interface{}{
"timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
return err
}
logger.Debug("Sent ping message")
return nil
}
func keepSendingPing(olm *websocket.Client) {
// Send ping immediately on startup
if err := sendPing(olm); err != nil {
logger.Error("Failed to send initial ping: %v", err)
} else {
logger.Info("Sent initial ping message")
}
// Set up ticker for one minute intervals
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-stopPing:
logger.Info("Stopping ping messages")
return
case <-ticker.C:
if err := sendPing(olm); err != nil {
logger.Error("Failed to send periodic ping: %v", err)
}
}
}
}
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
siteHost, err := ResolveDomain(siteConfig.Endpoint)
if err != nil {
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
}
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
allowedIp := strings.Split(siteConfig.ServerIP, "/")
if len(allowedIp) > 1 {
allowedIp[1] = "32"
} else {
allowedIp = append(allowedIp, "32")
}
allowedIpStr := strings.Join(allowedIp, "/")
// Collect all allowed IPs in a slice
var allowedIPs []string
allowedIPs = append(allowedIPs, allowedIpStr)
// If we have anything in remoteSubnets, add those as well
if siteConfig.RemoteSubnets != "" {
// Split remote subnets by comma and add each one
remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",")
for _, subnet := range remoteSubnets {
subnet = strings.TrimSpace(subnet)
if subnet != "" {
allowedIPs = append(allowedIPs, subnet)
}
}
}
// Construct WireGuard config for this peer
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String())))
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey)))
// Add each allowed IP separately
for _, allowedIP := range allowedIPs {
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
}
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
configBuilder.WriteString("persistent_keepalive_interval=1\n")
config := configBuilder.String()
logger.Debug("Configuring peer with config: %s", config)
err = dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
}
// Set up peer monitoring
if peerMonitor != nil {
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
wgConfig := &peermonitor.WireGuardConfig{
SiteID: siteConfig.SiteId,
PublicKey: fixKey(siteConfig.PublicKey),
ServerIP: strings.Split(siteConfig.ServerIP, "/")[0],
Endpoint: siteConfig.Endpoint,
PrimaryRelay: primaryRelay,
}
err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig)
if err != nil {
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
} else {
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
}
}
return nil
}
// RemovePeer removes a peer from the WireGuard device
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
// Construct WireGuard config to remove the peer
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(publicKey)))
configBuilder.WriteString("remove=true\n")
config := configBuilder.String()
logger.Debug("Removing peer with config: %s", config)
err := dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
}
// Stop monitoring this peer
if peerMonitor != nil {
peerMonitor.RemovePeer(siteId)
logger.Info("Stopped monitoring for site %d", siteId)
}
return nil
}
// ConfigureInterface configures a network interface with an IP address and brings it up
func ConfigureInterface(interfaceName string, wgData WgData) error {
var ipAddr string = wgData.TunnelIP
// Parse the IP address and network
ip, ipNet, err := net.ParseCIDR(ipAddr)
if err != nil {
return fmt.Errorf("invalid IP address: %v", err)
}
switch runtime.GOOS {
case "linux":
return configureLinux(interfaceName, ip, ipNet)
case "darwin":
return configureDarwin(interfaceName, ip, ipNet)
case "windows":
return configureWindows(interfaceName, ip, ipNet)
default:
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
}
}
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
logger.Info("Configuring Windows interface: %s", interfaceName)
// Calculate mask string (e.g., 255.255.255.0)
maskBits, _ := ipNet.Mask.Size()
mask := net.CIDRMask(maskBits, 32)
maskIP := net.IP(mask)
// Set the IP address using netsh
cmd := exec.Command("netsh", "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", interfaceName),
"source=static",
fmt.Sprintf("addr=%s", ip.String()),
fmt.Sprintf("mask=%s", maskIP.String()))
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("netsh command failed: %v, output: %s", err, out)
}
// Bring up the interface if needed (in Windows, setting the IP usually brings it up)
// But we'll explicitly enable it to be sure
cmd = exec.Command("netsh", "interface", "set", "interface",
interfaceName,
"admin=enable")
logger.Info("Running command: %v", cmd)
out, err = cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out)
}
// delay 2 seconds
time.Sleep(8 * time.Second)
// Wait for the interface to be up and have the correct IP
err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
if err != nil {
return fmt.Errorf("interface did not come up within timeout: %v", err)
}
return nil
}
// waitForInterfaceUp polls the network interface until it's up or times out
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
deadline := time.Now().Add(timeout)
pollInterval := 500 * time.Millisecond
for time.Now().Before(deadline) {
// Check if interface exists and is up
iface, err := net.InterfaceByName(interfaceName)
if err == nil {
// Check if interface is up
if iface.Flags&net.FlagUp != 0 {
// Check if it has the expected IP
addrs, err := iface.Addrs()
if err == nil {
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if ok && ipNet.IP.Equal(expectedIP) {
logger.Info("Interface %s is up with correct IP", interfaceName)
return nil // Interface is up with correct IP
}
}
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
}
} else {
logger.Info("Interface %s exists but is not up yet", interfaceName)
}
} else {
logger.Info("Interface %s not found yet: %v", interfaceName, err)
}
// Wait before next check
time.Sleep(pollInterval)
}
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
}
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
if runtime.GOOS != "windows" {
return nil
}
var cmd *exec.Cmd
// Parse destination to get the IP and subnet
ip, ipNet, err := net.ParseCIDR(destination)
if err != nil {
return fmt.Errorf("invalid destination address: %v", err)
}
// Calculate the subnet mask
maskBits, _ := ipNet.Mask.Size()
mask := net.CIDRMask(maskBits, 32)
maskIP := net.IP(mask)
if gateway != "" {
// Route with specific gateway
cmd = exec.Command("route", "add",
ip.String(),
"mask", maskIP.String(),
gateway,
"metric", "1")
} else if interfaceName != "" {
// First, get the interface index
indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces")
output, err := indexCmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to get interface index: %v, output: %s", err, output)
}
// Parse the output to find the interface index
lines := strings.Split(string(output), "\n")
var ifIndex string
for _, line := range lines {
if strings.Contains(line, interfaceName) {
fields := strings.Fields(line)
if len(fields) > 0 {
ifIndex = fields[0]
break
}
}
}
if ifIndex == "" {
return fmt.Errorf("could not find index for interface %s", interfaceName)
}
// Convert to integer to validate
idx, err := strconv.Atoi(ifIndex)
if err != nil {
return fmt.Errorf("invalid interface index: %v", err)
}
// Route via interface using the index
cmd = exec.Command("route", "add",
ip.String(),
"mask", maskIP.String(),
"0.0.0.0",
"if", strconv.Itoa(idx))
} else {
return fmt.Errorf("either gateway or interface must be specified")
}
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("route command failed: %v, output: %s", err, out)
}
return nil
}
func WindowsRemoveRoute(destination string) error {
// Parse destination to get the IP
ip, ipNet, err := net.ParseCIDR(destination)
if err != nil {
return fmt.Errorf("invalid destination address: %v", err)
}
// Calculate the subnet mask
maskBits, _ := ipNet.Mask.Size()
mask := net.CIDRMask(maskBits, 32)
maskIP := net.IP(mask)
cmd := exec.Command("route", "delete",
ip.String(),
"mask", maskIP.String())
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
}
return nil
}
func findUnusedUTUN() (string, error) {
ifaces, err := net.Interfaces()
if err != nil {
return "", fmt.Errorf("failed to list interfaces: %v", err)
}
used := make(map[int]bool)
re := regexp.MustCompile(`^utun(\d+)$`)
for _, iface := range ifaces {
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
if num, err := strconv.Atoi(matches[1]); err == nil {
used[num] = true
}
}
}
// Try utun0 up to utun255.
for i := 0; i < 256; i++ {
if !used[i] {
return fmt.Sprintf("utun%d", i), nil
}
}
return "", fmt.Errorf("no unused utun interface found")
}
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
logger.Info("Configuring darwin interface: %s", interfaceName)
prefix, _ := ipNet.Mask.Size()
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
}
// Bring up the interface
cmd = exec.Command("ifconfig", interfaceName, "up")
logger.Info("Running command: %v", cmd)
out, err = cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
}
return nil
}
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
// Get the interface
link, err := netlink.LinkByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
}
// Create the IP address attributes
addr := &netlink.Addr{
IPNet: &net.IPNet{
IP: ip,
Mask: ipNet.Mask,
},
}
// Add the IP address to the interface
if err := netlink.AddrAdd(link, addr); err != nil {
return fmt.Errorf("failed to add IP address: %v", err)
}
// Bring up the interface
if err := netlink.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up interface: %v", err)
}
return nil
}
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
if runtime.GOOS != "darwin" {
return nil
}
var cmd *exec.Cmd
if gateway != "" {
// Route with specific gateway
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
} else if interfaceName != "" {
// Route via interface
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
} else {
return fmt.Errorf("either gateway or interface must be specified")
}
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("route command failed: %v, output: %s", err, out)
}
return nil
}
func DarwinRemoveRoute(destination string) error {
if runtime.GOOS != "darwin" {
return nil
}
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
}
return nil
}
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
if runtime.GOOS != "linux" {
return nil
}
var cmd *exec.Cmd
if gateway != "" {
// Route with specific gateway
cmd = exec.Command("ip", "route", "add", destination, "via", gateway)
} else if interfaceName != "" {
// Route via interface
cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName)
} else {
return fmt.Errorf("either gateway or interface must be specified")
}
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("ip route command failed: %v, output: %s", err, out)
}
return nil
}
func LinuxRemoveRoute(destination string) error {
if runtime.GOOS != "linux" {
return nil
}
cmd := exec.Command("ip", "route", "del", destination)
logger.Info("Running command: %v", cmd)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out)
}
return nil
}
// addRouteForServerIP adds an OS-specific route for the server IP
func addRouteForServerIP(serverIP, interfaceName string) error {
if runtime.GOOS == "darwin" {
return DarwinAddRoute(serverIP, "", interfaceName)
}
// else if runtime.GOOS == "windows" {
// return WindowsAddRoute(serverIP, "", interfaceName)
// } else if runtime.GOOS == "linux" {
// return LinuxAddRoute(serverIP, "", interfaceName)
// }
return nil
}
// removeRouteForServerIP removes an OS-specific route for the server IP
func removeRouteForServerIP(serverIP string) error {
if runtime.GOOS == "darwin" {
return DarwinRemoveRoute(serverIP)
}
// else if runtime.GOOS == "windows" {
// return WindowsRemoveRoute(serverIP)
// } else if runtime.GOOS == "linux" {
// return LinuxRemoveRoute(serverIP)
// }
return nil
}
// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets
func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error {
if remoteSubnets == "" {
return nil
}
// Split remote subnets by comma and add routes for each one
subnets := strings.Split(remoteSubnets, ",")
for _, subnet := range subnets {
subnet = strings.TrimSpace(subnet)
if subnet == "" {
continue
}
// Add route based on operating system
if runtime.GOOS == "darwin" {
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
return err
}
} else if runtime.GOOS == "windows" {
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
return err
}
} else if runtime.GOOS == "linux" {
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
return err
}
}
logger.Info("Added route for remote subnet: %s", subnet)
}
return nil
}
// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets
func removeRoutesForRemoteSubnets(remoteSubnets string) error {
if remoteSubnets == "" {
return nil
}
// Split remote subnets by comma and remove routes for each one
subnets := strings.Split(remoteSubnets, ",")
for _, subnet := range subnets {
subnet = strings.TrimSpace(subnet)
if subnet == "" {
continue
}
// Remove route based on operating system
if runtime.GOOS == "darwin" {
if err := DarwinRemoveRoute(subnet); err != nil {
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
return err
}
} else if runtime.GOOS == "windows" {
if err := WindowsRemoveRoute(subnet); err != nil {
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
return err
}
} else if runtime.GOOS == "linux" {
if err := LinuxRemoveRoute(subnet); err != nil {
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
return err
}
}
logger.Info("Removed route for remote subnet: %s", subnet)
}
return nil
}

894
olm/olm.go Normal file
View File

@@ -0,0 +1,894 @@
package olm
import (
"context"
"encoding/json"
"net"
"os"
"runtime"
"strconv"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/api"
"github.com/fosrl/olm/bind"
"github.com/fosrl/olm/holepunch"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Config struct {
// Connection settings
Endpoint string
ID string
Secret string
UserToken string
// Network settings
MTU int
DNS string
InterfaceName string
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
// Advanced
Holepunch bool
TlsClientCert string
// Parsed values (not in JSON)
PingIntervalDuration time.Duration
PingTimeoutDuration time.Duration
// Source tracking (not in JSON)
sources map[string]string
Version string
OrgID string
DoNotCreateNewClient bool
}
var (
privateKey wgtypes.Key
connected bool
dev *device.Device
wgData WgData
holePunchData HolePunchData
uapiListener net.Listener
tdev tun.Device
apiServer *api.API
olmClient *websocket.Client
tunnelCancel context.CancelFunc
tunnelRunning bool
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
)
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel))
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
logger.Debug("Failed to check for updates: %v", err)
}
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
if config.HTTPAddr != "" {
apiServer = api.NewAPI(config.HTTPAddr)
} else if config.SocketPath != "" {
apiServer = api.NewAPISocket(config.SocketPath)
}
apiServer.SetVersion(config.Version)
apiServer.SetOrgID(config.OrgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Listen for shutdown requests from the API
go func() {
<-apiServer.GetShutdownChannel()
logger.Info("Shutdown requested via API")
// Cancel the context to trigger graceful shutdown
cancel()
}()
var (
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
userToken = config.UserToken
)
// Main event loop that handles connect, disconnect, and reconnect
for {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
goto shutdown
case req := <-apiServer.GetConnectionChannel():
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Stop any existing tunnel before starting a new one
if olmClient != nil {
logger.Info("Stopping existing tunnel before starting new connection")
StopTunnel()
}
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
userToken := req.UserToken
// Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" {
logger.Info("Starting tunnel with new credentials")
tunnelRunning = true
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
}
case <-apiServer.GetDisconnectChannel():
logger.Info("Received disconnect request via API")
StopTunnel()
// Clear credentials so we wait for new connect call
id = ""
secret = ""
endpoint = ""
userToken = ""
default:
// If we have credentials and no tunnel is running, start it
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
logger.Info("Starting tunnel process with initial credentials")
tunnelRunning = true
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
} else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled
if !config.EnableAPI {
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
// exit the application because there is no way to provide the missing parameters
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
goto shutdown
}
}
// Sleep briefly to prevent tight loop
time.Sleep(100 * time.Millisecond)
}
}
shutdown:
Stop()
apiServer.Stop()
logger.Info("Olm service shutting down")
}
func TunnelProcess(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
// Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx)
tunnelCancel = cancel
defer func() {
tunnelCancel = nil
}()
// Recreate channels for this tunnel session
stopPing = make(chan struct{})
var (
interfaceName = config.InterfaceName
loggerLevel = parseLogLevel(config.LogLevel)
)
// Create a new olm client using the provided credentials
olm, err := websocket.NewClient(
id, // Use provided ID
secret, // Use provided secret
userToken, // Use provided user token OPTIONAL
endpoint, // Use provided endpoint
config.PingIntervalDuration,
config.PingTimeoutDuration,
)
if err != nil {
logger.Error("Failed to create olm: %v", err)
return
}
// Store the client reference globally
olmClient = olm
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Error("Failed to generate private key: %v", err)
return
}
// Create shared UDP socket for both holepunch and WireGuard
if sharedBind == nil {
sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
logger.Error("Error finding available port: %v", err)
return
}
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
udpConn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to create shared UDP socket: %v", err)
return
}
sharedBind, err = bind.New(udpConn)
if err != nil {
logger.Error("Failed to create shared bind: %v", err)
udpConn.Close()
return
}
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
sharedBind.AddRef()
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
}
// Create the holepunch manager
if holePunchManager == nil {
holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain)
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
for i, node := range holePunchData.ExitNodes {
exitNodes[i] = holepunch.ExitNode{
Endpoint: node.Endpoint,
PublicKey: node.PublicKey,
}
}
// Start hole punching using the manager
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
logger.Debug("Received message: %v", msg.Data)
type LegacyHolePunchData struct {
ServerPubKey string `json:"serverPubKey"`
Endpoint string `json:"endpoint"`
}
var legacyHolePunchData LegacyHolePunchData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Stop any existing hole punch operations
if holePunchManager != nil {
holePunchManager.Stop()
}
// Start hole punching for the exit node
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
if connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
// close(stopHolepunch)
// wait 10 milliseconds to ensure the previous connection is closed
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
if dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
return nil, err
}
return tun.CreateTUN(interfaceName, config.MTU)
}
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, config.MTU)
}
return tun.CreateTUN(interfaceName, config.MTU)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
fileUAPI, err := func() (*os.File, error) {
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}
return uapiOpen(interfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := uapiListener.Accept()
if err != nil {
return
}
go dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
if err = dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
apiServer.SetTunnelIP(wgData.TunnelIP)
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
// Find the site config to get endpoint information
var endpoint string
var isRelay bool
for _, site := range wgData.Sites {
if site.SiteId == siteID {
endpoint = site.Endpoint
// TODO: We'll need to track relay status separately
// For now, assume not using relay unless we get relay data
isRelay = !config.Holepunch
break
}
}
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else {
logger.Warn("Peer %d is disconnected", siteID)
}
},
fixKey(privateKey.String()),
olm,
dev,
config.Holepunch,
)
for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
peerMonitor.Start()
apiServer.SetRegistered(true)
connected = true
logger.Info("WireGuard device created.")
})
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData UpdatePeerData
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: updateData.SiteId,
Endpoint: updateData.Endpoint,
PublicKey: updateData.PublicKey,
ServerIP: updateData.ServerIP,
ServerPort: updateData.ServerPort,
RemoteSubnets: updateData.RemoteSubnets,
}
// Update the peer in WireGuard
if dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
for _, site := range wgData.Sites {
if site.SiteId == updateData.SiteId {
oldRemoteSubnets = site.RemoteSubnets
oldPublicKey = site.PublicKey
break
}
}
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
}
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// Remove old remote subnet routes if they changed
if oldRemoteSubnets != siteConfig.RemoteSubnets {
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
logger.Error("Failed to remove old remote subnet routes: %v", err)
// Continue anyway to add new routes
}
// Add new remote subnet routes
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add new remote subnet routes: %v", err)
return
}
}
// Update successful
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
for i := range wgData.Sites {
if wgData.Sites[i].SiteId == updateData.SiteId {
wgData.Sites[i] = siteConfig
break
}
}
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for adding a new peer
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addData AddPeerData
if err := json.Unmarshal(jsonData, &addData); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: addData.SiteId,
Endpoint: addData.Endpoint,
PublicKey: addData.PublicKey,
ServerIP: addData.ServerIP,
ServerPort: addData.ServerPort,
RemoteSubnets: addData.RemoteSubnets,
}
// Add the peer to WireGuard
if dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for new peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
// Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId)
// Update WgData with the new peer
wgData.Sites = append(wgData.Sites, siteConfig)
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for removing a peer
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData RemovePeerData
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
// Find the peer to remove
var peerToRemove *SiteConfig
var newSites []SiteConfig
for _, site := range wgData.Sites {
if site.SiteId == removeData.SiteId {
peerToRemove = &site
} else {
newSites = append(newSites, site)
}
}
if peerToRemove == nil {
logger.Error("Peer with site ID %d not found", removeData.SiteId)
return
}
// Remove the peer from WireGuard
if dev != nil {
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
}
// Remove route for the peer
err = removeRouteForServerIP(peerToRemove.ServerIP)
if err != nil {
logger.Error("Failed to remove route for peer: %v", err)
return
}
// Remove routes for remote subnets
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
logger.Error("Failed to remove routes for remote subnets: %v", err)
return
}
// Remove successful
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
// Update WgData to remove the peer
wgData.Sites = newSites
} else {
logger.Error("WireGuard device not initialized")
}
})
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
logger.Info("Received no-sites message - no sites available for connection")
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
logger.Info("No sites available - stopped registration and holepunch processes")
})
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
olm.Close()
})
olm.OnConnect(func() error {
logger.Info("Websocket Connected")
apiServer.SetConnectionStatus(true)
if connected {
logger.Debug("Already connected, skipping registration")
return nil
}
publicKey := privateKey.PublicKey()
if stopRegister == nil {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}, 1*time.Second)
}
go keepSendingPing(olm)
return nil
})
olm.OnTokenUpdate(func(token string) {
if holePunchManager != nil {
holePunchManager.SetToken(token)
}
})
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Error("Failed to connect to server: %v", err)
return
}
defer olm.Close()
// Listen for org switch requests from the API
go func() {
for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Update the config with the new orgId
config.OrgID = req.OrgID
// Mark as not connected to trigger re-registration
connected = false
Stop()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
// Wait for context cancellation
<-tunnelCtx.Done()
logger.Info("Tunnel process context cancelled, cleaning up")
}
func Stop() {
// Stop hole punch manager
if holePunchManager != nil {
holePunchManager.Stop()
}
if stopPing != nil {
select {
case <-stopPing:
// Channel already closed
default:
close(stopPing)
}
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
if peerMonitor != nil {
peerMonitor.Stop()
peerMonitor = nil
}
if uapiListener != nil {
uapiListener.Close()
uapiListener = nil
}
if dev != nil {
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
dev = nil
}
// Close TUN device
if tdev != nil {
tdev.Close()
tdev = nil
}
// Release the hole punch reference to the shared bind
if sharedBind != nil {
// Release hole punch reference (WireGuard already released its reference via dev.Close())
logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount())
sharedBind.Release()
sharedBind = nil
logger.Info("Released shared UDP bind")
}
logger.Info("Olm service stopped")
}
// StopTunnel stops just the tunnel process and websocket connection
// without shutting down the entire application
func StopTunnel() {
logger.Info("Stopping tunnel process")
// Cancel the tunnel context if it exists
if tunnelCancel != nil {
tunnelCancel()
// Give it a moment to clean up
time.Sleep(200 * time.Millisecond)
}
// Close the websocket connection
if olmClient != nil {
olmClient.Close()
olmClient = nil
}
Stop()
// Reset the connected state
connected = false
tunnelRunning = false
// Update API server status
apiServer.SetConnectionStatus(false)
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
logger.Info("Tunnel process stopped")
}

35
olm/unix.go Normal file
View File

@@ -0,0 +1,35 @@
//go:build !windows
package olm
import (
"net"
"os"
"strconv"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
if err != nil {
return nil, err
}
err = unix.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "")
return tun.CreateTUNFromFile(file, mtuInt)
}
func uapiOpen(interfaceName string) (*os.File, error) {
return ipc.UAPIOpen(interfaceName)
}
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, fileUAPI)
}

25
olm/windows.go Normal file
View File

@@ -0,0 +1,25 @@
//go:build windows
package olm
import (
"errors"
"net"
"os"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
return nil, errors.New("CreateTUNFromFile not supported on Windows")
}
func uapiOpen(interfaceName string) (*os.File, error) {
return nil, nil
}
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
// On Windows, UAPIListen only takes one parameter
return ipc.UAPIListen(interfaceName)
}

331
peermonitor/peermonitor.go Normal file
View File

@@ -0,0 +1,331 @@
package peermonitor
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/websocket"
"github.com/fosrl/olm/wgtester"
"golang.zx2c4.com/wireguard/device"
)
// PeerMonitorCallback is the function type for connection status change callbacks
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
// WireGuardConfig holds the WireGuard configuration for a peer
type WireGuardConfig struct {
SiteID int
PublicKey string
ServerIP string
Endpoint string
PrimaryRelay string // The primary relay endpoint
}
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
type PeerMonitor struct {
monitors map[int]*wgtester.Client
configs map[int]*WireGuardConfig
callback PeerMonitorCallback
mutex sync.Mutex
running bool
interval time.Duration
timeout time.Duration
maxAttempts int
privateKey string
wsClient *websocket.Client
device *device.Device
handleRelaySwitch bool // Whether to handle relay switching
}
// NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
return &PeerMonitor{
monitors: make(map[int]*wgtester.Client),
configs: make(map[int]*WireGuardConfig),
callback: callback,
interval: 1 * time.Second, // Default check interval
timeout: 2500 * time.Millisecond,
maxAttempts: 8,
privateKey: privateKey,
wsClient: wsClient,
device: device,
handleRelaySwitch: handleRelaySwitch,
}
}
// SetInterval changes how frequently peers are checked
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.interval = interval
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetPacketInterval(interval)
}
}
// SetTimeout changes the timeout for waiting for responses
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.timeout = timeout
// Update timeout for all existing monitors
for _, client := range pm.monitors {
client.SetTimeout(timeout)
}
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.maxAttempts = attempts
// Update max attempts for all existing monitors
for _, client := range pm.monitors {
client.SetMaxAttempts(attempts)
}
}
// AddPeer adds a new peer to monitor
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
// Check if we're already monitoring this peer
if _, exists := pm.monitors[siteID]; exists {
// Update the endpoint instead of creating a new monitor
pm.removePeerUnlocked(siteID)
}
client, err := wgtester.NewClient(endpoint)
if err != nil {
return err
}
// Configure the client with our settings
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
// Store the client and config
pm.monitors[siteID] = client
pm.configs[siteID] = wgConfig
// If monitor is already running, start monitoring this peer
if pm.running {
siteIDCopy := siteID // Create a copy for the closure
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
pm.handleConnectionStatusChange(siteIDCopy, status)
})
}
return err
}
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
// This function assumes the mutex is already held by the caller
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
client, exists := pm.monitors[siteID]
if !exists {
return
}
client.StopMonitor()
client.Close()
delete(pm.monitors, siteID)
delete(pm.configs, siteID)
}
// RemovePeer stops monitoring a peer and removes it from the monitor
func (pm *PeerMonitor) RemovePeer(siteID int) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.removePeerUnlocked(siteID)
}
// Start begins monitoring all peers
func (pm *PeerMonitor) Start() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
if pm.running {
return // Already running
}
pm.running = true
// Start monitoring all peers
for siteID, client := range pm.monitors {
siteIDCopy := siteID // Create a copy for the closure
err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
pm.handleConnectionStatusChange(siteIDCopy, status)
})
if err != nil {
logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
continue
}
logger.Info("Started monitoring peer %d\n", siteID)
}
}
// handleConnectionStatusChange is called when a peer's connection status changes
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) {
// Call the user-provided callback first
if pm.callback != nil {
pm.callback(siteID, status.Connected, status.RTT)
}
// If disconnected, handle failover
if !status.Connected {
// Send relay message to the server
if pm.wsClient != nil {
pm.sendRelay(siteID)
}
}
}
// handleFailover handles failover to the relay server when a peer is disconnected
func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
pm.mutex.Lock()
config, exists := pm.configs[siteID]
pm.mutex.Unlock()
if !exists {
return
}
// Check for IPv6 and format the endpoint correctly
formattedEndpoint := relayEndpoint
if strings.Contains(relayEndpoint, ":") {
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
}
// Configure WireGuard to use the relay
wgConfig := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s:21820
persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, formattedEndpoint)
err := pm.device.IpcSet(wgConfig)
if err != nil {
logger.Error("Failed to configure WireGuard device: %v\n", err)
return
}
logger.Info("Adjusted peer %d to point to relay!\n", siteID)
}
// sendRelay sends a relay message to the server
func (pm *PeerMonitor) sendRelay(siteID int) error {
if !pm.handleRelaySwitch {
return nil
}
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent relay message")
return nil
}
// Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
if !pm.running {
return
}
pm.running = false
// Stop all monitors
for _, client := range pm.monitors {
client.StopMonitor()
}
}
// Close stops monitoring and cleans up resources
func (pm *PeerMonitor) Close() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
// Stop and close all clients
for siteID, client := range pm.monitors {
client.StopMonitor()
client.Close()
delete(pm.monitors, siteID)
}
pm.running = false
}
// TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
pm.mutex.Lock()
client, exists := pm.monitors[siteID]
pm.mutex.Unlock()
if !exists {
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
}
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
defer cancel()
connected, rtt := client.TestConnection(ctx)
return connected, rtt, nil
}
// TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
Connected bool
RTT time.Duration
} {
pm.mutex.Lock()
peers := make(map[int]*wgtester.Client, len(pm.monitors))
for siteID, client := range pm.monitors {
peers[siteID] = client
}
pm.mutex.Unlock()
results := make(map[int]struct {
Connected bool
RTT time.Duration
})
for siteID, client := range peers {
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
connected, rtt := client.TestConnection(ctx)
cancel()
results[siteID] = struct {
Connected bool
RTT time.Duration
}{
Connected: connected,
RTT: rtt,
}
}
return results
}

View File

@@ -1,334 +0,0 @@
package proxy
import (
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun/netstack"
)
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
return &ProxyManager{
tnet: tnet,
}
}
func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target string) {
pm.Lock()
defer pm.Unlock()
logger.Info("Adding target: %s://%s:%d -> %s", protocol, listen, port, target)
newTarget := ProxyTarget{
Protocol: protocol,
Listen: listen,
Port: port,
Target: target,
cancel: make(chan struct{}),
done: make(chan struct{}),
}
pm.targets = append(pm.targets, newTarget)
}
func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
pm.Lock()
defer pm.Unlock()
protocol = strings.ToLower(protocol)
if protocol != "tcp" && protocol != "udp" {
return fmt.Errorf("unsupported protocol: %s", protocol)
}
for i, target := range pm.targets {
if target.Listen == listen &&
target.Port == port &&
strings.ToLower(target.Protocol) == protocol {
// Signal the serving goroutine to stop
select {
case <-target.cancel:
// Channel is already closed, no need to close it again
default:
close(target.cancel)
}
// Close the appropriate listener/connection based on protocol
target.Lock()
switch protocol {
case "tcp":
if target.listener != nil {
select {
case <-target.cancel:
// Listener was already closed by Stop()
default:
target.listener.Close()
}
}
case "udp":
if target.udpConn != nil {
select {
case <-target.cancel:
// Connection was already closed by Stop()
default:
target.udpConn.Close()
}
}
}
target.Unlock()
// Wait for the target to fully stop
<-target.done
// Remove the target from the slice
pm.targets = append(pm.targets[:i], pm.targets[i+1:]...)
return nil
}
}
return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port)
}
func (pm *ProxyManager) Start() error {
pm.RLock()
defer pm.RUnlock()
for i := range pm.targets {
target := &pm.targets[i]
target.Lock()
// If target is already running, skip it
if target.listener != nil || target.udpConn != nil {
target.Unlock()
continue
}
// Mark the target as starting by creating a nil listener/connection
// This prevents other goroutines from trying to start it
if strings.ToLower(target.Protocol) == "tcp" {
target.listener = nil
} else {
target.udpConn = nil
}
target.Unlock()
switch strings.ToLower(target.Protocol) {
case "tcp":
go pm.serveTCP(target)
case "udp":
go pm.serveUDP(target)
default:
return fmt.Errorf("unsupported protocol: %s", target.Protocol)
}
}
return nil
}
func (pm *ProxyManager) Stop() error {
pm.Lock()
defer pm.Unlock()
var wg sync.WaitGroup
for i := range pm.targets {
target := &pm.targets[i]
wg.Add(1)
go func(t *ProxyTarget) {
defer wg.Done()
close(t.cancel)
t.Lock()
if t.listener != nil {
t.listener.Close()
}
if t.udpConn != nil {
t.udpConn.Close()
}
t.Unlock()
// Wait for the target to fully stop
<-t.done
}(target)
}
wg.Wait()
return nil
}
func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
defer close(target.done) // Signal that this target is fully stopped
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{
IP: net.ParseIP(target.Listen),
Port: target.Port,
})
if err != nil {
logger.Info("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err)
return
}
target.Lock()
target.listener = listener
target.Unlock()
defer listener.Close()
logger.Info("TCP proxy listening on %s", listener.Addr())
var activeConns sync.WaitGroup
acceptDone := make(chan struct{})
// Goroutine to handle shutdown signal
go func() {
<-target.cancel
close(acceptDone)
listener.Close()
}()
for {
conn, err := listener.Accept()
if err != nil {
select {
case <-target.cancel:
// Wait for active connections to finish
activeConns.Wait()
return
default:
logger.Info("Failed to accept TCP connection: %v", err)
// Don't return here, try to accept new connections
time.Sleep(time.Second)
continue
}
}
activeConns.Add(1)
go func() {
defer activeConns.Done()
pm.handleTCPConnection(conn, target.Target, acceptDone)
}()
}
}
func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, done chan struct{}) {
defer clientConn.Close()
serverConn, err := net.Dial("tcp", target)
if err != nil {
logger.Info("Failed to connect to target %s: %v", target, err)
return
}
defer serverConn.Close()
var wg sync.WaitGroup
wg.Add(2)
// Client -> Server
go func() {
defer wg.Done()
select {
case <-done:
return
default:
io.Copy(serverConn, clientConn)
}
}()
// Server -> Client
go func() {
defer wg.Done()
select {
case <-done:
return
default:
io.Copy(clientConn, serverConn)
}
}()
wg.Wait()
}
func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
defer close(target.done) // Signal that this target is fully stopped
addr := &net.UDPAddr{
IP: net.ParseIP(target.Listen),
Port: target.Port,
}
conn, err := pm.tnet.ListenUDP(addr)
if err != nil {
logger.Info("Failed to start UDP listener for %s:%d: %v", target.Listen, target.Port, err)
return
}
target.Lock()
target.udpConn = conn
target.Unlock()
defer conn.Close()
logger.Info("UDP proxy listening on %s", conn.LocalAddr())
buffer := make([]byte, 65535)
var activeConns sync.WaitGroup
for {
select {
case <-target.cancel:
activeConns.Wait() // Wait for all active UDP handlers to complete
return
default:
n, remoteAddr, err := conn.ReadFrom(buffer)
if err != nil {
select {
case <-target.cancel:
activeConns.Wait()
return
default:
logger.Info("Failed to read UDP packet: %v", err)
continue
}
}
targetAddr, err := net.ResolveUDPAddr("udp", target.Target)
if err != nil {
logger.Info("Failed to resolve target address %s: %v", target.Target, err)
continue
}
activeConns.Add(1)
go func(data []byte, remote net.Addr) {
defer activeConns.Done()
targetConn, err := net.DialUDP("udp", nil, targetAddr)
if err != nil {
logger.Info("Failed to connect to target %s: %v", target.Target, err)
return
}
defer targetConn.Close()
select {
case <-target.cancel:
return
default:
_, err = targetConn.Write(data)
if err != nil {
logger.Info("Failed to write to target: %v", err)
return
}
response := make([]byte, 65535)
n, err := targetConn.Read(response)
if err != nil {
logger.Info("Failed to read response from target: %v", err)
return
}
_, err = conn.WriteTo(response[:n], remote)
if err != nil {
logger.Info("Failed to write response to client: %v", err)
}
}
}(buffer[:n], remoteAddr)
}
}
}

View File

@@ -1,28 +0,0 @@
package proxy
import (
"log"
"net"
"sync"
"golang.zx2c4.com/wireguard/tun/netstack"
)
type ProxyTarget struct {
Protocol string
Listen string
Port int
Target string
cancel chan struct{} // Channel to signal shutdown
done chan struct{} // Channel to signal completion
listener net.Listener // For TCP
udpConn net.PacketConn // For UDP
sync.Mutex // Protect access to connection
}
type ProxyManager struct {
targets []ProxyTarget
tnet *netstack.Net
log *log.Logger
sync.RWMutex // Protect access to targets slice
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 774 KiB

54
service_unix.go Normal file
View File

@@ -0,0 +1,54 @@
//go:build !windows
package main
import (
"fmt"
)
// Service management functions are not available on non-Windows platforms
func installService() error {
return fmt.Errorf("service management is only available on Windows")
}
func removeService() error {
return fmt.Errorf("service management is only available on Windows")
}
func startService(args []string) error {
_ = args // unused on Unix platforms
return fmt.Errorf("service management is only available on Windows")
}
func stopService() error {
return fmt.Errorf("service management is only available on Windows")
}
func getServiceStatus() (string, error) {
return "", fmt.Errorf("service management is only available on Windows")
}
func debugService(args []string) error {
_ = args // unused on Unix platforms
return fmt.Errorf("debug service is only available on Windows")
}
func isWindowsService() bool {
return false
}
func runService(name string, isDebug bool, args []string) {
// No-op on non-Windows platforms
}
func setupWindowsEventLog() {
// No-op on non-Windows platforms
}
func watchLogFile(end bool) error {
return fmt.Errorf("watching log file is only available on Windows")
}
func showServiceConfig() {
fmt.Println("Service configuration is only available on Windows")
}

631
service_windows.go Normal file
View File

@@ -0,0 +1,631 @@
//go:build windows
package main
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/fosrl/newt/logger"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/debug"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
)
const (
serviceName = "OlmWireguardService"
serviceDisplayName = "Olm WireGuard VPN Service"
serviceDescription = "Olm WireGuard VPN client service for secure network connectivity"
)
// Global variable to store service arguments
var serviceArgs []string
// getServiceArgsPath returns the path where service arguments are stored
func getServiceArgsPath() string {
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm")
return filepath.Join(logDir, "service_args.json")
}
// saveServiceArgs saves the service arguments to a file
func saveServiceArgs(args []string) error {
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm")
err := os.MkdirAll(logDir, 0755)
if err != nil {
return fmt.Errorf("failed to create config directory: %v", err)
}
argsPath := getServiceArgsPath()
data, err := json.Marshal(args)
if err != nil {
return fmt.Errorf("failed to marshal service args: %v", err)
}
err = os.WriteFile(argsPath, data, 0644)
if err != nil {
return fmt.Errorf("failed to write service args: %v", err)
}
return nil
}
// loadServiceArgs loads the service arguments from a file
func loadServiceArgs() ([]string, error) {
argsPath := getServiceArgsPath()
data, err := os.ReadFile(argsPath)
if err != nil {
if os.IsNotExist(err) {
return []string{}, nil // Return empty args if file doesn't exist
}
return nil, fmt.Errorf("failed to read service args: %v", err)
}
var args []string
err = json.Unmarshal(data, &args)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal service args: %v", err)
}
return args, nil
}
type olmService struct {
elog debug.Log
ctx context.Context
stop context.CancelFunc
args []string
}
func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
changes <- svc.Status{State: svc.StartPending}
s.elog.Info(1, fmt.Sprintf("Service Execute called with args: %v", args))
// Load saved service arguments
savedArgs, err := loadServiceArgs()
if err != nil {
s.elog.Error(1, fmt.Sprintf("Failed to load service args: %v", err))
// Continue with empty args if loading fails
savedArgs = []string{}
}
// Combine service start args with saved args, giving priority to service start args
finalArgs := []string{}
if len(args) > 0 {
// Skip the first arg which is typically the service name
if len(args) > 1 {
finalArgs = append(finalArgs, args[1:]...)
}
s.elog.Info(1, fmt.Sprintf("Using service start parameters: %v", finalArgs))
}
// If no service start parameters, use saved args
if len(finalArgs) == 0 && len(savedArgs) > 0 {
finalArgs = savedArgs
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
}
s.args = finalArgs
// Start the main olm functionality
olmDone := make(chan struct{})
go func() {
s.runOlm()
close(olmDone)
}()
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
s.elog.Info(1, "Service status set to Running")
for {
select {
case c := <-r:
switch c.Cmd {
case svc.Interrogate:
changes <- c.CurrentStatus
case svc.Stop, svc.Shutdown:
s.elog.Info(1, "Service stopping")
changes <- svc.Status{State: svc.StopPending}
if s.stop != nil {
s.stop()
}
// Wait for main logic to finish or timeout
select {
case <-olmDone:
s.elog.Info(1, "Main logic finished gracefully")
case <-time.After(10 * time.Second):
s.elog.Info(1, "Timeout waiting for main logic to finish")
}
return false, 0
default:
s.elog.Error(1, fmt.Sprintf("Unexpected control request #%d", c))
}
case <-olmDone:
s.elog.Info(1, "Main olm logic completed, stopping service")
changes <- svc.Status{State: svc.StopPending}
return false, 0
}
}
}
func (s *olmService) runOlm() {
// Create a context that can be cancelled when the service stops
s.ctx, s.stop = context.WithCancel(context.Background())
// Setup logging for service mode
s.elog.Info(1, "Starting Olm main logic")
// Run the main olm logic and wait for it to complete
done := make(chan struct{})
go func() {
defer func() {
if r := recover(); r != nil {
s.elog.Error(1, fmt.Sprintf("Olm panic: %v", r))
}
close(done)
}()
// Call the main olm function with stored arguments
runOlmMainWithArgs(s.ctx, s.args)
}()
// Wait for either context cancellation or main logic completion
select {
case <-s.ctx.Done():
s.elog.Info(1, "Olm service context cancelled")
case <-done:
s.elog.Info(1, "Olm main logic completed")
}
}
func runService(name string, isDebug bool, args []string) {
var err error
var elog debug.Log
if isDebug {
elog = debug.New(name)
fmt.Printf("Starting %s service in debug mode\n", name)
} else {
elog, err = eventlog.Open(name)
if err != nil {
fmt.Printf("Failed to open event log: %v\n", err)
return
}
}
defer elog.Close()
elog.Info(1, fmt.Sprintf("Starting %s service", name))
run := svc.Run
if isDebug {
run = debug.Run
}
service := &olmService{elog: elog, args: args}
err = run(name, service)
if err != nil {
elog.Error(1, fmt.Sprintf("%s service failed: %v", name, err))
if isDebug {
fmt.Printf("Service failed: %v\n", err)
}
return
}
elog.Info(1, fmt.Sprintf("%s service stopped", name))
if isDebug {
fmt.Printf("%s service stopped\n", name)
}
}
func installService() error {
exepath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %v", err)
}
m, err := mgr.Connect()
if err != nil {
return fmt.Errorf("failed to connect to service manager: %v", err)
}
defer m.Disconnect()
s, err := m.OpenService(serviceName)
if err == nil {
s.Close()
return fmt.Errorf("service %s already exists", serviceName)
}
config := mgr.Config{
ServiceType: 0x10, // SERVICE_WIN32_OWN_PROCESS
StartType: mgr.StartManual,
ErrorControl: mgr.ErrorNormal,
DisplayName: serviceDisplayName,
Description: serviceDescription,
BinaryPathName: exepath,
}
s, err = m.CreateService(serviceName, exepath, config)
if err != nil {
return fmt.Errorf("failed to create service: %v", err)
}
defer s.Close()
err = eventlog.InstallAsEventCreate(serviceName, eventlog.Error|eventlog.Warning|eventlog.Info)
if err != nil {
s.Delete()
return fmt.Errorf("failed to install event log: %v", err)
}
return nil
}
func removeService() error {
m, err := mgr.Connect()
if err != nil {
return fmt.Errorf("failed to connect to service manager: %v", err)
}
defer m.Disconnect()
s, err := m.OpenService(serviceName)
if err != nil {
return fmt.Errorf("service %s is not installed", serviceName)
}
defer s.Close()
// Stop the service if it's running
status, err := s.Query()
if err != nil {
return fmt.Errorf("failed to query service status: %v", err)
}
if status.State != svc.Stopped {
_, err = s.Control(svc.Stop)
if err != nil {
return fmt.Errorf("failed to stop service: %v", err)
}
// Wait for service to stop
timeout := time.Now().Add(30 * time.Second)
for status.State != svc.Stopped {
if timeout.Before(time.Now()) {
return fmt.Errorf("timeout waiting for service to stop")
}
time.Sleep(300 * time.Millisecond)
status, err = s.Query()
if err != nil {
return fmt.Errorf("failed to query service status: %v", err)
}
}
}
err = s.Delete()
if err != nil {
return fmt.Errorf("failed to delete service: %v", err)
}
err = eventlog.Remove(serviceName)
if err != nil {
return fmt.Errorf("failed to remove event log: %v", err)
}
return nil
}
func startService(args []string) error {
// Save the service arguments as backup
if len(args) > 0 {
err := saveServiceArgs(args)
if err != nil {
return fmt.Errorf("failed to save service args: %v", err)
}
}
m, err := mgr.Connect()
if err != nil {
return fmt.Errorf("failed to connect to service manager: %v", err)
}
defer m.Disconnect()
s, err := m.OpenService(serviceName)
if err != nil {
return fmt.Errorf("service %s is not installed", serviceName)
}
defer s.Close()
// Pass arguments directly to the service start call
err = s.Start(args...)
if err != nil {
return fmt.Errorf("failed to start service: %v", err)
}
return nil
}
func stopService() error {
m, err := mgr.Connect()
if err != nil {
return fmt.Errorf("failed to connect to service manager: %v", err)
}
defer m.Disconnect()
s, err := m.OpenService(serviceName)
if err != nil {
return fmt.Errorf("service %s is not installed", serviceName)
}
defer s.Close()
status, err := s.Control(svc.Stop)
if err != nil {
return fmt.Errorf("failed to stop service: %v", err)
}
timeout := time.Now().Add(30 * time.Second)
for status.State != svc.Stopped {
if timeout.Before(time.Now()) {
return fmt.Errorf("timeout waiting for service to stop")
}
time.Sleep(300 * time.Millisecond)
status, err = s.Query()
if err != nil {
return fmt.Errorf("failed to query service status: %v", err)
}
}
return nil
}
func debugService(args []string) error {
// Save the service arguments before starting
if len(args) > 0 {
err := saveServiceArgs(args)
if err != nil {
return fmt.Errorf("failed to save service args: %v", err)
}
}
// Start the service with the provided arguments
err := startService(args)
if err != nil {
return fmt.Errorf("failed to start service: %v", err)
}
// Watch the log file
return watchLogFile(true)
}
func watchLogFile(end bool) error {
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
logPath := filepath.Join(logDir, "olm.log")
// Ensure the log directory exists
err := os.MkdirAll(logDir, 0755)
if err != nil {
return fmt.Errorf("failed to create log directory: %v", err)
}
// Wait for the log file to be created if it doesn't exist
var file *os.File
for i := 0; i < 30; i++ { // Wait up to 15 seconds
file, err = os.Open(logPath)
if err == nil {
break
}
if i == 0 {
fmt.Printf("Waiting for log file to be created...\n")
}
time.Sleep(500 * time.Millisecond)
}
if err != nil {
return fmt.Errorf("failed to open log file after waiting: %v", err)
}
defer file.Close()
// Seek to the end of the file to only show new logs
_, err = file.Seek(0, 2)
if err != nil {
return fmt.Errorf("failed to seek to end of file: %v", err)
}
// Set up signal handling for graceful exit
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
// Create a ticker to check for new content
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
buffer := make([]byte, 4096)
for {
select {
case <-sigCh:
fmt.Printf("\n\nStopping log watch...\n")
// stop the service if needed
if end {
if err := stopService(); err != nil {
fmt.Printf("Failed to stop service: %v\n", err)
}
}
fmt.Printf("Log watch stopped.\n")
return nil
case <-ticker.C:
// Read new content
n, err := file.Read(buffer)
if err != nil && err != io.EOF {
// Try to reopen the file in case it was recreated
file.Close()
file, err = os.Open(logPath)
if err != nil {
return fmt.Errorf("error reopening log file: %v", err)
}
continue
}
if n > 0 {
// Print the new content
fmt.Print(string(buffer[:n]))
}
}
}
}
func getServiceStatus() (string, error) {
m, err := mgr.Connect()
if err != nil {
return "", fmt.Errorf("failed to connect to service manager: %v", err)
}
defer m.Disconnect()
s, err := m.OpenService(serviceName)
if err != nil {
return "Not Installed", nil
}
defer s.Close()
status, err := s.Query()
if err != nil {
return "", fmt.Errorf("failed to query service status: %v", err)
}
switch status.State {
case svc.Stopped:
return "Stopped", nil
case svc.StartPending:
return "Starting", nil
case svc.StopPending:
return "Stopping", nil
case svc.Running:
return "Running", nil
case svc.ContinuePending:
return "Continue Pending", nil
case svc.PausePending:
return "Pause Pending", nil
case svc.Paused:
return "Paused", nil
default:
return "Unknown", nil
}
}
// showServiceConfig displays current saved service configuration
func showServiceConfig() {
configPath := getServiceArgsPath()
fmt.Printf("Service configuration file: %s\n", configPath)
args, err := loadServiceArgs()
if err != nil {
fmt.Printf("No saved configuration found or error loading: %v\n", err)
return
}
if len(args) == 0 {
fmt.Println("No saved service arguments found")
} else {
fmt.Printf("Saved service arguments: %v\n", args)
}
}
func isWindowsService() bool {
isWindowsService, err := svc.IsWindowsService()
return err == nil && isWindowsService
}
// rotateLogFile handles daily log rotation
func rotateLogFile(logDir string, logFile string) error {
// Get current log file info
info, err := os.Stat(logFile)
if err != nil {
if os.IsNotExist(err) {
return nil // No current log file to rotate
}
return fmt.Errorf("failed to stat log file: %v", err)
}
// Check if log file is from today
now := time.Now()
fileTime := info.ModTime()
// If the log file is from today, no rotation needed
if now.Year() == fileTime.Year() && now.YearDay() == fileTime.YearDay() {
return nil
}
// Create rotated filename with date
rotatedName := fmt.Sprintf("olm-%s.log", fileTime.Format("2006-01-02"))
rotatedPath := filepath.Join(logDir, rotatedName)
// Rename current log file to dated filename
err = os.Rename(logFile, rotatedPath)
if err != nil {
return fmt.Errorf("failed to rotate log file: %v", err)
}
// Clean up old log files (keep last 30 days)
cleanupOldLogFiles(logDir, 30)
return nil
}
// cleanupOldLogFiles removes log files older than specified days
func cleanupOldLogFiles(logDir string, daysToKeep int) {
cutoff := time.Now().AddDate(0, 0, -daysToKeep)
files, err := os.ReadDir(logDir)
if err != nil {
return
}
for _, file := range files {
if !file.IsDir() && strings.HasPrefix(file.Name(), "olm-") && strings.HasSuffix(file.Name(), ".log") {
filePath := filepath.Join(logDir, file.Name())
info, err := file.Info()
if err != nil {
continue
}
if info.ModTime().Before(cutoff) {
os.Remove(filePath)
}
}
}
}
func setupWindowsEventLog() {
// Create log directory if it doesn't exist
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
err := os.MkdirAll(logDir, 0755)
if err != nil {
fmt.Printf("Failed to create log directory: %v\n", err)
return
}
logFile := filepath.Join(logDir, "olm.log")
// Rotate log file if needed
err = rotateLogFile(logDir, logFile)
if err != nil {
fmt.Printf("Failed to rotate log file: %v\n", err)
// Continue anyway to create new log file
}
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
fmt.Printf("Failed to open log file: %v\n", err)
return
}
// Set the custom logger output
logger.GetLogger().SetOutput(file)
log.Printf("Olm service logging initialized - log file: %s", logFile)
}

View File

@@ -2,38 +2,81 @@ package websocket
import ( import (
"bytes" "bytes"
"crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/fosrl/newt/logger" "software.sslmate.com/src/go-pkcs12"
"github.com/fosrl/newt/logger"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
type Client struct { type TokenResponse struct {
conn *websocket.Conn Data struct {
config *Config Token string `json:"token"`
baseURL string } `json:"data"`
handlers map[string]MessageHandler Success bool `json:"success"`
done chan struct{} Message string `json:"message"`
handlersMux sync.RWMutex }
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
}
// this is not json anymore
type Config struct {
ID string
Secret string
Endpoint string
TlsClientCert string // legacy PKCS12 file path
UserToken string // optional user token for websocket authentication
}
type Client struct {
config *Config
conn *websocket.Conn
baseURL string
handlers map[string]MessageHandler
done chan struct{}
handlersMux sync.RWMutex
reconnectInterval time.Duration reconnectInterval time.Duration
isConnected bool isConnected bool
reconnectMux sync.RWMutex reconnectMux sync.RWMutex
pingInterval time.Duration
onConnect func() error pingTimeout time.Duration
onConnect func() error
onTokenUpdate func(token string)
writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig
configNeedsSave bool // Flag to track if config needs to be saved
} }
type ClientOption func(*Client) type ClientOption func(*Client)
type MessageHandler func(message WSMessage) type MessageHandler func(message WSMessage)
// TLSConfig holds TLS configuration options
type TLSConfig struct {
// New separate certificate support
ClientCertFile string
ClientKeyFile string
CAFiles []string
// Existing PKCS12 support (deprecated)
PKCS12File string
}
// WithBaseURL sets the base URL for the client // WithBaseURL sets the base URL for the client
func WithBaseURL(url string) ClientOption { func WithBaseURL(url string) ClientOption {
return func(c *Client) { return func(c *Client) {
@@ -41,16 +84,32 @@ func WithBaseURL(url string) ClientOption {
} }
} }
// WithTLSConfig sets the TLS configuration for the client
func WithTLSConfig(config TLSConfig) ClientOption {
return func(c *Client) {
c.tlsConfig = config
// For backward compatibility, also set the legacy field
if config.PKCS12File != "" {
c.config.TlsClientCert = config.PKCS12File
}
}
}
func (c *Client) OnConnect(callback func() error) { func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback c.onConnect = callback
} }
// NewClient creates a new Newt client func (c *Client) OnTokenUpdate(callback func(token string)) {
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) { c.onTokenUpdate = callback
}
// NewClient creates a new websocket client
func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{ config := &Config{
NewtID: newtID, ID: ID,
Secret: secret, Secret: secret,
Endpoint: endpoint, Endpoint: endpoint,
UserToken: userToken,
} }
client := &Client{ client := &Client{
@@ -58,39 +117,59 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
baseURL: endpoint, // default value baseURL: endpoint, // default value
handlers: make(map[string]MessageHandler), handlers: make(map[string]MessageHandler),
done: make(chan struct{}), done: make(chan struct{}),
reconnectInterval: 10 * time.Second, reconnectInterval: 3 * time.Second,
isConnected: false, isConnected: false,
pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: "olm",
} }
// Apply options before loading config // Apply options before loading config
for _, opt := range opts { for _, opt := range opts {
if opt == nil {
continue
}
opt(client) opt(client)
} }
// Load existing config if available
if err := client.loadConfig(); err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
}
return client, nil return client, nil
} }
func (c *Client) GetConfig() *Config {
return c.config
}
// Connect establishes the WebSocket connection // Connect establishes the WebSocket connection
func (c *Client) Connect() error { func (c *Client) Connect() error {
go c.connectWithRetry() go c.connectWithRetry()
return nil return nil
} }
// Close closes the WebSocket connection // Close closes the WebSocket connection gracefully
func (c *Client) Close() error { func (c *Client) Close() error {
close(c.done) // Signal shutdown to all goroutines first
if c.conn != nil { select {
return c.conn.Close() case <-c.done:
// Already closed
return nil
default:
close(c.done)
} }
// stop the ping monitor // Set connection status to false
c.setConnected(false) c.setConnected(false)
// Close the WebSocket connection gracefully
if c.conn != nil {
// Send close message
c.writeMux.Lock()
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.writeMux.Unlock()
// Close the connection
return c.conn.Close()
}
return nil return nil
} }
@@ -105,9 +184,49 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
Data: data, Data: data,
} }
logger.Debug("Sending message: %s, data: %+v", messageType, data)
c.writeMux.Lock()
defer c.writeMux.Unlock()
return c.conn.WriteJSON(msg) return c.conn.WriteJSON(msg)
} }
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
stopChan := make(chan struct{})
go func() {
count := 0
maxAttempts := 10
err := c.SendMessage(messageType, data) // Send immediately
if err != nil {
logger.Error("Failed to send initial message: %v", err)
}
count++
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if count >= maxAttempts {
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
err = c.SendMessage(messageType, data)
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case <-stopChan:
return
}
}
}()
return func() {
close(stopChan)
}
}
// RegisterHandler registers a handler for a specific message type // RegisterHandler registers a handler for a specific message type
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlersMux.Lock() c.handlersMux.Lock()
@@ -115,30 +234,6 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlers[messageType] = handler c.handlers[messageType] = handler
} }
// readPump pumps messages from the WebSocket connection
func (c *Client) readPump() {
defer c.conn.Close()
for {
select {
case <-c.done:
return
default:
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
return
}
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
handler(msg)
}
c.handlersMux.RUnlock()
}
}
}
func (c *Client) getToken() (string, error) { func (c *Client) getToken() (string, error) {
// Parse the base URL to ensure we have the correct hostname // Parse the base URL to ensure we have the correct hostname
baseURL, err := url.Parse(c.baseURL) baseURL, err := url.Parse(c.baseURL)
@@ -149,57 +244,33 @@ func (c *Client) getToken() (string, error) {
// Ensure we have the base URL without trailing slashes // Ensure we have the base URL without trailing slashes
baseEndpoint := strings.TrimRight(baseURL.String(), "/") baseEndpoint := strings.TrimRight(baseURL.String(), "/")
// If we already have a token, try to use it var tlsConfig *tls.Config = nil
if c.config.Token != "" {
tokenCheckData := map[string]interface{}{ // Use new TLS configuration method
"newtId": c.config.NewtID, if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
"secret": c.config.Secret, tlsConfig, err = c.setupTLS()
"token": c.config.Token,
}
jsonData, err := json.Marshal(tokenCheckData)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to marshal token check data: %w", err) return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
}
// Create a new request
req, err := http.NewRequest(
"POST",
baseEndpoint+"/api/v1/auth/newt/get-token",
bytes.NewBuffer(jsonData),
)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// Make the request
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to check token validity: %w", err)
}
defer resp.Body.Close()
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token check response: %w", err)
}
// If token is still valid, return it
if tokenResp.Success && tokenResp.Message == "Token session already valid" {
return c.config.Token, nil
} }
} }
// Get a new token // Check for environment variable to skip TLS verification
tokenData := map[string]interface{}{ if os.Getenv("SKIP_TLS_VERIFY") == "true" {
"newtId": c.config.NewtID, if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
tlsConfig.InsecureSkipVerify = true
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
var tokenData map[string]interface{}
tokenData = map[string]interface{}{
"olmId": c.config.ID,
"secret": c.config.Secret, "secret": c.config.Secret,
} }
jsonData, err := json.Marshal(tokenData) jsonData, err := json.Marshal(tokenData)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to marshal token request data: %w", err) return "", fmt.Errorf("failed to marshal token request data: %w", err)
} }
@@ -207,7 +278,7 @@ func (c *Client) getToken() (string, error) {
// Create a new request // Create a new request
req, err := http.NewRequest( req, err := http.NewRequest(
"POST", "POST",
baseEndpoint+"/api/v1/auth/newt/get-token", baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
bytes.NewBuffer(jsonData), bytes.NewBuffer(jsonData),
) )
if err != nil { if err != nil {
@@ -220,14 +291,26 @@ func (c *Client) getToken() (string, error) {
// Make the request // Make the request
client := &http.Client{} client := &http.Client{}
if tlsConfig != nil {
client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to request new token: %w", err) return "", fmt.Errorf("failed to request new token: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
}
var tokenResp TokenResponse var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.")
return "", fmt.Errorf("failed to decode token response: %w", err) return "", fmt.Errorf("failed to decode token response: %w", err)
} }
@@ -239,6 +322,8 @@ func (c *Client) getToken() (string, error) {
return "", fmt.Errorf("received empty token from server") return "", fmt.Errorf("received empty token from server")
} }
logger.Debug("Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, nil return tokenResp.Data.Token, nil
} }
@@ -266,6 +351,10 @@ func (c *Client) establishConnection() error {
return fmt.Errorf("failed to get token: %w", err) return fmt.Errorf("failed to get token: %w", err)
} }
if c.onTokenUpdate != nil {
c.onTokenUpdate(token)
}
// Parse the base URL to determine protocol and hostname // Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL) baseURL, err := url.Parse(c.baseURL)
if err != nil { if err != nil {
@@ -288,10 +377,35 @@ func (c *Client) establishConnection() error {
// Add token to query parameters // Add token to query parameters
q := u.Query() q := u.Query()
q.Set("token", token) q.Set("token", token)
q.Set("clientType", c.clientType)
if c.config.UserToken != "" {
q.Set("userToken", c.config.UserToken)
}
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
// Connect to WebSocket // Connect to WebSocket
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) dialer := websocket.DefaultDialer
// Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
logger.Info("Setting up TLS configuration for WebSocket connection")
tlsConfig, err := c.setupTLS()
if err != nil {
return fmt.Errorf("failed to setup TLS configuration: %w", err)
}
dialer.TLSClientConfig = tlsConfig
}
// Check for environment variable to skip TLS verification for WebSocket connection
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
if dialer.TLSClientConfig == nil {
dialer.TLSClientConfig = &tls.Config{}
}
dialer.TLSClientConfig.InsecureSkipVerify = true
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
conn, _, err := dialer.Dial(u.String(), nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err) return fmt.Errorf("failed to connect to WebSocket: %w", err)
} }
@@ -301,8 +415,8 @@ func (c *Client) establishConnection() error {
// Start the ping monitor // Start the ping monitor
go c.pingMonitor() go c.pingMonitor()
// Start the read pump // Start the read pump with disconnect detection
go c.readPump() go c.readPumpWithDisconnectDetection()
if c.onConnect != nil { if c.onConnect != nil {
if err := c.onConnect(); err != nil { if err := c.onConnect(); err != nil {
@@ -313,8 +427,72 @@ func (c *Client) establishConnection() error {
return nil return nil
} }
// setupTLS configures TLS based on the TLS configuration
func (c *Client) setupTLS() (*tls.Config, error) {
tlsConfig := &tls.Config{}
// Handle new separate certificate configuration
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
logger.Info("Loading separate certificate files for mTLS")
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
// Load client certificate and key
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate pair: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
// Load CA certificates for remote validation if specified
if len(c.tlsConfig.CAFiles) > 0 {
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
caCertPool := x509.NewCertPool()
for _, caFile := range c.tlsConfig.CAFiles {
caCert, err := os.ReadFile(caFile)
if err != nil {
return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err)
}
// Try to parse as PEM first, then DER
if !caCertPool.AppendCertsFromPEM(caCert) {
// If PEM parsing failed, try DER
cert, err := x509.ParseCertificate(caCert)
if err != nil {
return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err)
}
caCertPool.AddCert(cert)
}
}
tlsConfig.RootCAs = caCertPool
}
return tlsConfig, nil
}
// Fallback to existing PKCS12 implementation for backward compatibility
if c.tlsConfig.PKCS12File != "" {
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
return c.setupPKCS12TLS()
}
// Legacy fallback using config.TlsClientCert
if c.config.TlsClientCert != "" {
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
return loadClientCertificate(c.config.TlsClientCert)
}
return nil, nil
}
// setupPKCS12TLS loads TLS configuration from PKCS12 file
func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
return loadClientCertificate(c.tlsConfig.PKCS12File)
}
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) pingMonitor() { func (c *Client) pingMonitor() {
ticker := time.NewTicker(30 * time.Second) ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop() defer ticker.Stop()
for { for {
@@ -322,11 +500,74 @@ func (c *Client) pingMonitor() {
case <-c.done: case <-c.done:
return return
case <-ticker.C: case <-ticker.C:
if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { if c.conn == nil {
logger.Error("Ping failed: %v", err)
c.reconnect()
return return
} }
c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
c.reconnect()
return
}
}
}
}
}
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
func (c *Client) readPumpWithDisconnectDetection() {
defer func() {
if c.conn != nil {
c.conn.Close()
}
// Only attempt reconnect if we're not shutting down
select {
case <-c.done:
// Shutting down, don't reconnect
return
default:
c.reconnect()
}
}()
for {
select {
case <-c.done:
return
default:
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
// Check if we're shutting down before logging error
select {
case <-c.done:
// Expected during shutdown, don't log as error
logger.Debug("WebSocket connection closed during shutdown")
return
default:
// Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("WebSocket read error: %v", err)
} else {
logger.Debug("WebSocket connection closed: %v", err)
}
return // triggers reconnect via defer
}
}
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
handler(msg)
}
c.handlersMux.RUnlock()
} }
} }
} }
@@ -335,9 +576,16 @@ func (c *Client) reconnect() {
c.setConnected(false) c.setConnected(false)
if c.conn != nil { if c.conn != nil {
c.conn.Close() c.conn.Close()
c.conn = nil
} }
go c.connectWithRetry() // Only reconnect if we're not shutting down
select {
case <-c.done:
return
default:
go c.connectWithRetry()
}
} }
func (c *Client) setConnected(status bool) { func (c *Client) setConnected(status bool) {
@@ -345,3 +593,42 @@ func (c *Client) setConnected(status bool) {
defer c.reconnectMux.Unlock() defer c.reconnectMux.Unlock()
c.isConnected = status c.isConnected = status
} }
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path)
if err != nil {
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
}
// Parse PKCS12 with empty password for non-encrypted files
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
if err != nil {
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
}
// Create certificate
cert := tls.Certificate{
Certificate: [][]byte{certificate.Raw},
PrivateKey: privateKey,
}
// Optional: Add CA certificates if present
rootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
}
if len(caCerts) > 0 {
for _, caCert := range caCerts {
rootCAs.AddCert(caCert)
}
}
// Create TLS configuration
return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: rootCAs,
}, nil
}

View File

@@ -1,72 +0,0 @@
package websocket
import (
"encoding/json"
"log"
"os"
"path/filepath"
"runtime"
)
func getConfigPath() string {
var configDir string
switch runtime.GOOS {
case "darwin":
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "newt-client")
case "windows":
configDir = filepath.Join(os.Getenv("APPDATA"), "newt-client")
default: // linux and others
configDir = filepath.Join(os.Getenv("HOME"), ".config", "newt-client")
}
if err := os.MkdirAll(configDir, 0755); err != nil {
log.Printf("Failed to create config directory: %v", err)
}
return filepath.Join(configDir, "config.json")
}
func (c *Client) loadConfig() error {
if c.config.NewtID != "" && c.config.Secret != "" && c.config.Endpoint != "" {
return nil
}
configPath := getConfigPath()
data, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var config Config
if err := json.Unmarshal(data, &config); err != nil {
return err
}
if c.config.NewtID == "" {
c.config.NewtID = config.NewtID
}
if c.config.Token == "" {
c.config.Token = config.Token
}
if c.config.Secret == "" {
c.config.Secret = config.Secret
}
if c.config.Endpoint == "" {
c.config.Endpoint = config.Endpoint
c.baseURL = config.Endpoint
}
return nil
}
func (c *Client) saveConfig() error {
configPath := getConfigPath()
data, err := json.MarshalIndent(c.config, "", " ")
if err != nil {
return err
}
return os.WriteFile(configPath, data, 0644)
}

View File

@@ -1,21 +0,0 @@
package websocket
type Config struct {
NewtID string `json:"newtId"`
Secret string `json:"secret"`
Token string `json:"token"`
Endpoint string `json:"endpoint"`
}
type TokenResponse struct {
Data struct {
Token string `json:"token"`
} `json:"data"`
Success bool `json:"success"`
Message string `json:"message"`
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
}

260
wgtester/wgtester.go Normal file
View File

@@ -0,0 +1,260 @@
package wgtester
import (
"context"
"encoding/binary"
"net"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
const (
// Magic bytes to identify our packets
magicHeader uint32 = 0xDEADBEEF
// Request packet type
packetTypeRequest uint8 = 1
// Response packet type
packetTypeResponse uint8 = 2
// Packet format:
// - 4 bytes: magic header (0xDEADBEEF)
// - 1 byte: packet type (1 = request, 2 = response)
// - 8 bytes: timestamp (for round-trip timing)
packetSize = 13
)
// Client handles checking connectivity to a server
type Client struct {
conn *net.UDPConn
serverAddr string
monitorRunning bool
monitorLock sync.Mutex
connLock sync.Mutex // Protects connection operations
shutdownCh chan struct{}
packetInterval time.Duration
timeout time.Duration
maxAttempts int
}
// ConnectionStatus represents the current connection state
type ConnectionStatus struct {
Connected bool
RTT time.Duration
}
// NewClient creates a new connection test client
func NewClient(serverAddr string) (*Client, error) {
return &Client{
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
packetInterval: 2 * time.Second,
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
}, nil
}
// SetPacketInterval changes how frequently packets are sent in monitor mode
func (c *Client) SetPacketInterval(interval time.Duration) {
c.packetInterval = interval
}
// SetTimeout changes the timeout for waiting for responses
func (c *Client) SetTimeout(timeout time.Duration) {
c.timeout = timeout
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (c *Client) SetMaxAttempts(attempts int) {
c.maxAttempts = attempts
}
// Close cleans up client resources
func (c *Client) Close() {
c.StopMonitor()
c.connLock.Lock()
defer c.connLock.Unlock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
}
// ensureConnection makes sure we have an active UDP connection
func (c *Client) ensureConnection() error {
c.connLock.Lock()
defer c.connLock.Unlock()
if c.conn != nil {
return nil
}
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
if err != nil {
return err
}
c.conn, err = net.DialUDP("udp", nil, serverAddr)
if err != nil {
return err
}
return nil
}
// TestConnection checks if the connection to the server is working
// Returns true if connected, false otherwise
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
if err := c.ensureConnection(); err != nil {
logger.Warn("Failed to ensure connection: %v", err)
return false, 0
}
// Prepare packet buffer
packet := make([]byte, packetSize)
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
packet[4] = packetTypeRequest
// Send multiple attempts as specified
for attempt := 0; attempt < c.maxAttempts; attempt++ {
select {
case <-ctx.Done():
return false, 0
default:
// Add current timestamp to packet
timestamp := time.Now().UnixNano()
binary.BigEndian.PutUint64(packet[5:13], uint64(timestamp))
// Lock the connection for the entire send/receive operation
c.connLock.Lock()
// Check if connection is still valid after acquiring lock
if c.conn == nil {
c.connLock.Unlock()
return false, 0
}
logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
_, err := c.conn.Write(packet)
if err != nil {
c.connLock.Unlock()
logger.Info("Error sending packet: %v", err)
continue
}
logger.Debug("Successfully sent monitor packet")
// Set read deadline
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
// Wait for response
responseBuffer := make([]byte, packetSize)
n, err := c.conn.Read(responseBuffer)
c.connLock.Unlock()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout, try next attempt
time.Sleep(100 * time.Millisecond) // Brief pause between attempts
continue
}
logger.Error("Error reading response: %v", err)
continue
}
if n != packetSize {
continue // Malformed packet
}
// Verify response
magic := binary.BigEndian.Uint32(responseBuffer[0:4])
packetType := responseBuffer[4]
if magic != magicHeader || packetType != packetTypeResponse {
continue // Not our response
}
// Extract the original timestamp and calculate RTT
sentTimestamp := int64(binary.BigEndian.Uint64(responseBuffer[5:13]))
rtt := time.Duration(time.Now().UnixNano() - sentTimestamp)
return true, rtt
}
}
return false, 0
}
// TestConnectionWithTimeout tries to test connection with a timeout
// Returns true if connected, false otherwise
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return c.TestConnection(ctx)
}
// MonitorCallback is the function type for connection status change callbacks
type MonitorCallback func(status ConnectionStatus)
// StartMonitor begins monitoring the connection and calls the callback
// when the connection status changes
func (c *Client) StartMonitor(callback MonitorCallback) error {
c.monitorLock.Lock()
defer c.monitorLock.Unlock()
if c.monitorRunning {
logger.Info("Monitor already running")
return nil // Already running
}
if err := c.ensureConnection(); err != nil {
return err
}
c.monitorRunning = true
c.shutdownCh = make(chan struct{})
go func() {
var lastConnected bool
firstRun := true
ticker := time.NewTicker(c.packetInterval)
defer ticker.Stop()
for {
select {
case <-c.shutdownCh:
return
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
connected, rtt := c.TestConnection(ctx)
cancel()
// Callback if status changed or it's the first check
if connected != lastConnected || firstRun {
callback(ConnectionStatus{
Connected: connected,
RTT: rtt,
})
lastConnected = connected
firstRun = false
}
}
}
}()
return nil
}
// StopMonitor stops the connection monitoring
func (c *Client) StopMonitor() {
c.monitorLock.Lock()
defer c.monitorLock.Unlock()
if !c.monitorRunning {
return
}
close(c.shutdownCh)
c.monitorRunning = false
}