Compare commits

..

43 Commits

Author SHA1 Message Date
Owen
7fc09f8ed1 Fix windows build
Former-commit-id: 6af69cdcd6
2025-11-08 20:39:36 -08:00
Owen
079843602c Dont close and comment out dont create
Former-commit-id: 9b74bcfb81
2025-11-08 17:42:19 -08:00
Owen
70bf22c354 Remove binaries
Former-commit-id: 3398d2ab7e
2025-11-08 17:38:46 -08:00
Owen
3d891cfa97 Remove do not create client for now
Because its always created when the user joins the org


Former-commit-id: 8ebc678edb
2025-11-08 16:54:35 -08:00
Owen
78e3bb374a Split out hp
Former-commit-id: 29ed4fefbf
2025-11-07 21:59:07 -08:00
Owen
a61c7ca1ee Custom bind?
Former-commit-id: 6d8e298ebc
2025-11-07 21:39:28 -08:00
Owen
7696ba2e36 Add DoNotCreateNewClient
Former-commit-id: aedebb5579
2025-11-07 15:26:45 -08:00
Owen
235877c379 Add optional user token to validate
Former-commit-id: 5734684a21
2025-11-07 14:52:54 -08:00
Owen
befab0f8d1 Fix passing original arguments
Former-commit-id: 7e5b740514
2025-11-07 14:33:52 -08:00
Owen
914d080a57 Connecting disconnecting working
Former-commit-id: 553010f2ea
2025-11-07 14:31:13 -08:00
Owen
a274b4b38f Starting and stopping working
Former-commit-id: f23f2fb9aa
2025-11-07 14:20:36 -08:00
Owen
ce3c585514 Allow connecting and disconnecting
Former-commit-id: 596c4aa0da
2025-11-07 14:07:44 -08:00
Owen
963d8abad5 Add org id in the status
Former-commit-id: da1e4911bd
2025-11-03 20:54:55 -08:00
Owen
38eb56381f Update switching orgs
Former-commit-id: 690b133c7b
2025-11-03 20:33:06 -08:00
Owen
43b3822090 Allow pasing orgId to select org to connect
Former-commit-id: 46a4847cee
2025-11-03 16:54:38 -08:00
Owen
b0fb370c4d Remove status
Former-commit-id: 352ac8def6
2025-11-03 15:29:18 -08:00
Owen
99328ee76f Add registered to api
Former-commit-id: 9c496f7ca7
2025-11-03 15:16:12 -08:00
Owen
36fc3ea253 Add exit call
Former-commit-id: 4a89915826
2025-11-03 14:15:16 -08:00
Owen
a7979259f3 Make api availble over socket
Former-commit-id: e464af5302
2025-11-02 18:56:09 -08:00
Owen
ea6fa72bc0 Copy in config
Former-commit-id: 3505549331
2025-11-02 12:09:39 -08:00
Owen
f9adde6b1d Rename to run
Former-commit-id: 6f7e866e93
2025-11-01 18:39:53 -07:00
Owen
ba25586646 Import submodule
Former-commit-id: eaf94e6855
2025-11-01 18:37:53 -07:00
Owen
952ab63e8d Package?
Former-commit-id: 218e4f88bc
2025-11-01 18:34:00 -07:00
Owen
5e84f802ed Merge branch 'dev' of github.com:fosrl/olm into dev
Former-commit-id: ea05ac9c71
2025-10-27 21:42:52 -07:00
Owen
f40b0ff820 Use local websocket without config or otel
Former-commit-id: 87e2cf8405
2025-10-27 21:42:41 -07:00
Owen
95a4840374 Use local websocket without config or otel
Former-commit-id: a55803f8f3
2025-10-27 21:39:54 -07:00
Owen
27424170e4 Merge branch 'main' into dev
Former-commit-id: 2489283470
2025-10-27 21:27:21 -07:00
Owen Schwartz
a8ace6f64a Merge pull request #47 from gk1/feature/config-file-persistence
Updated the olm client to process config vars from cli,env,file in or…

Former-commit-id: 1a8ae792d0
2025-10-25 17:31:25 -07:00
Owen
3fa1073f49 Treat mtu as int and dont overwrite from websocket
Former-commit-id: 228bddcf79
2025-10-25 17:15:25 -07:00
Owen Schwartz
76d86c10ff Merge pull request #41 from fosrl/dependabot/go_modules/prod-minor-updates-b47bc3c6cb
Bump the prod-minor-updates group with 2 updates

Former-commit-id: 0f25ba98ac
2025-10-25 16:33:02 -07:00
gk1
2d34c6c8b2 Updated the olm client to process config vars from cli,env,file in order of precedence and persist them to file
Former-commit-id: 555c9dc9f4
2025-10-24 17:04:48 -07:00
dependabot[bot]
a7f3477bdd Bump the prod-minor-updates group with 2 updates
Bumps the prod-minor-updates group with 2 updates: [golang.org/x/crypto](https://github.com/golang/crypto) and [golang.org/x/sys](https://github.com/golang/sys).


Updates `golang.org/x/crypto` from 0.42.0 to 0.43.0
- [Commits](https://github.com/golang/crypto/compare/v0.42.0...v0.43.0)

Updates `golang.org/x/sys` from 0.36.0 to 0.37.0
- [Commits](https://github.com/golang/sys/compare/v0.36.0...v0.37.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.43.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
- dependency-name: golang.org/x/sys
  dependency-version: 0.37.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: prod-minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: ef011f29f7
2025-10-20 20:36:13 +00:00
Owen Schwartz
af0a72d296 feat(actions): Sync Images from Docker to GHCR (#46)
* feat(actions): Sync Images from Docker to GHCR

* fix(actions): Moved mirror action to workflows

Former-commit-id: 2b181db148
2025-10-20 13:16:31 -07:00
Marc Schäfer
d1e836e760 fix(actions): Moved mirror action to workflows
Former-commit-id: c94dc6af69
2025-10-20 22:15:15 +02:00
Marc Schäfer
8dd45c4ca2 feat(actions): Sync Images from Docker to GHCR
Former-commit-id: 6275531187
2025-10-20 22:11:42 +02:00
Owen
9db009058b Merge branch 'main' into dev
Former-commit-id: b12e3bd895
2025-10-19 15:12:47 -07:00
Owen
29c01deb05 Update domains
Former-commit-id: 8629c40e2f
2025-10-19 15:12:37 -07:00
Owen
8afc28fdff GO update
Former-commit-id: c58f3ac92d
2025-10-08 17:46:40 -07:00
Owen
4ba2fb7b53 Merge branch 'main' into dev
Former-commit-id: d88bfd461e
2025-10-08 17:44:53 -07:00
Owen
2e6076923d Dont delete service args file
Former-commit-id: ff3b5c50fc
2025-10-08 17:35:14 -07:00
Owen
4c001dc751 Try to fix log rotation and service args
Former-commit-id: c5ece2f21f
2025-10-01 10:30:45 -07:00
Owen
dc9a547950 Add timeouts to hp
Former-commit-id: fa1d2b1f55
2025-09-29 14:55:19 -07:00
Owen
2be0933246 Add update checker
Former-commit-id: 2445ced83b
2025-09-29 14:37:07 -07:00
26 changed files with 4505 additions and 1507 deletions

View File

@@ -8,7 +8,7 @@ on:
jobs:
release:
name: Build and Release
runs-on: ubuntu-latest
runs-on: amd64-runner
steps:
- name: Checkout code

132
.github/workflows/mirror.yaml vendored Normal file
View File

@@ -0,0 +1,132 @@
name: Mirror & Sign (Docker Hub to GHCR)
on:
workflow_dispatch: {}
permissions:
contents: read
packages: write
id-token: write # for keyless OIDC
env:
SOURCE_IMAGE: docker.io/fosrl/olm
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
jobs:
mirror-and-dual-sign:
runs-on: amd64-runner
steps:
- name: Install skopeo + jq
run: |
sudo apt-get update -y
sudo apt-get install -y skopeo jq
skopeo --version
- name: Install cosign
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
- name: Input check
run: |
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
echo "Source : ${SOURCE_IMAGE}"
echo "Target : ${DEST_IMAGE}"
# Auth for skopeo (containers-auth)
- name: Skopeo login to GHCR
run: |
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
# Auth for cosign (docker-config)
- name: Docker login to GHCR (for cosign)
run: |
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
- name: List source tags
run: |
set -euo pipefail
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
| jq -r '.Tags[]' | sort -u > src-tags.txt
echo "Found source tags: $(wc -l < src-tags.txt)"
head -n 20 src-tags.txt || true
- name: List destination tags (skip existing)
run: |
set -euo pipefail
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
else
: > dst-tags.txt
fi
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
- name: Mirror, dual-sign, and verify
env:
# keyless
COSIGN_YES: "true"
# key-based
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
# verify
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
run: |
set -euo pipefail
copied=0; skipped=0; v_ok=0; errs=0
issuer="https://token.actions.githubusercontent.com"
id_regex="^https://github.com/${{ github.repository }}/.+"
while read -r tag; do
[ -z "$tag" ] && continue
if grep -Fxq "$tag" dst-tags.txt; then
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
skipped=$((skipped+1))
continue
fi
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
if ! skopeo copy --all --retry-times 3 \
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
errs=$((errs+1)); continue
fi
copied=$((copied+1))
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
ref="${DEST_IMAGE}@${digest}"
echo "==> cosign sign (keyless) --recursive ${ref}"
if ! cosign sign --recursive "${ref}"; then
echo "::warning title=Keyless sign failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign sign (key) --recursive ${ref}"
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
echo "::warning title=Key sign failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign verify (public key) ${ref}"
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
echo "::warning title=Verify(pubkey) failed::${ref}"
errs=$((errs+1))
fi
echo "==> cosign verify (keyless policy) ${ref}"
if ! cosign verify \
--certificate-oidc-issuer "${issuer}" \
--certificate-identity-regexp "${id_regex}" \
"${ref}" -o text; then
echo "::warning title=Verify(keyless) failed::${ref}"
errs=$((errs+1))
else
v_ok=$((v_ok+1))
fi
done < src-tags.txt
echo "---- Summary ----"
echo "Copied : $copied"
echo "Skipped : $skipped"
echo "Verified OK : $v_ok"
echo "Errors : $errs"

1
.gitignore vendored
View File

@@ -1,3 +1,2 @@
olm
.DS_Store
bin/

View File

@@ -4,11 +4,7 @@ Contributions are welcome!
Please see the contribution and local development guide on the docs page before getting started:
https://docs.fossorial.io/development
For ideas about what features to work on and our future plans, please see the roadmap:
https://docs.fossorial.io/roadmap
https://docs.pangolin.net/development/contributing
### Licensing Considerations

View File

@@ -10,7 +10,7 @@ docker-build-release:
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push .
local:
CGO_ENABLED=0 go build -o olm
CGO_ENABLED=0 go build -o bin/olm
build:
docker build -t fosrl/olm:latest .

View File

@@ -6,7 +6,7 @@ Olm is a [WireGuard](https://www.wireguard.com/) tunnel client designed to secur
Olm is used with Pangolin and Newt as part of the larger system. See documentation below:
- [Full Documentation](https://docs.fossorial.io)
- [Full Documentation](https://docs.pangolin.net)
## Key Functions
@@ -107,7 +107,7 @@ $ cat ~/.config/olm-client/config.json
{
"id": "spmzu8rbpzj1qq6",
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
"endpoint": "https://pangolin.fossorial.io",
"endpoint": "https://app.pangolin.net",
"tlsClientCert": ""
}
```

View File

@@ -3,7 +3,7 @@
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
- Description and location of the vulnerability.
- Potential impact of the vulnerability.

411
api/api.go Normal file
View File

@@ -0,0 +1,411 @@
package api
import (
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
// ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
}
// SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct {
OrgID string `json:"orgId"`
}
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
Connected bool `json:"connected"`
RTT time.Duration `json:"rtt"`
LastSeen time.Time `json:"lastSeen"`
Endpoint string `json:"endpoint,omitempty"`
IsRelay bool `json:"isRelay"`
PeerIP string `json:"peerAddress,omitempty"`
}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Connected bool `json:"connected"`
Registered bool `json:"registered"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
// API represents the HTTP server and its state
type API struct {
addr string
socketPath string
listener net.Listener
server *http.Server
connectionChan chan ConnectionRequest
switchOrgChan chan SwitchOrgRequest
shutdownChan chan struct{}
disconnectChan chan struct{}
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
isRegistered bool
tunnelIP string
version string
orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
func NewAPI(addr string) *API {
s := &API{
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
func NewAPISocket(socketPath string) *API {
s := &API{
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// Start starts the HTTP server
func (s *API) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit)
s.server = &http.Server{
Handler: mux,
}
var err error
if s.socketPath != "" {
// Use platform-specific socket listener
s.listener, err = createSocketListener(s.socketPath)
if err != nil {
return fmt.Errorf("failed to create socket listener: %w", err)
}
logger.Info("Starting HTTP server on socket %s", s.socketPath)
} else {
// Use TCP listener
s.listener, err = net.Listen("tcp", s.addr)
if err != nil {
return fmt.Errorf("failed to create TCP listener: %w", err)
}
logger.Info("Starting HTTP server on %s", s.addr)
}
go func() {
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server error: %v", err)
}
}()
return nil
}
// Stop stops the HTTP server
func (s *API) Stop() error {
logger.Info("Stopping api server")
// Close the server first, which will also close the listener gracefully
if s.server != nil {
s.server.Close()
}
// Clean up socket file if using Unix socket
if s.socketPath != "" {
cleanupSocket(s.socketPath)
}
return nil
}
// GetConnectionChannel returns the channel for receiving connection requests
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// GetSwitchOrgChannel returns the channel for receiving org switch requests
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
return s.switchOrgChan
}
// GetShutdownChannel returns the channel for receiving shutdown requests
func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
// GetDisconnectChannel returns the channel for receiving disconnect requests
func (s *API) GetDisconnectChannel() <-chan struct{} {
return s.disconnectChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Connected = connected
status.RTT = rtt
status.LastSeen = time.Now()
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// SetConnectionStatus sets the overall connection status
func (s *API) SetConnectionStatus(isConnected bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isConnected = isConnected
if isConnected {
s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
}
}
func (s *API) SetRegistered(registered bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isRegistered = registered
}
// SetTunnelIP sets the tunnel IP address
func (s *API) SetTunnelIP(tunnelIP string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.tunnelIP = tunnelIP
}
// SetVersion sets the olm version
func (s *API) SetVersion(version string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.version = version
}
// SetOrgID sets the organization ID
func (s *API) SetOrgID(orgID string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.orgID = orgID
}
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// handleConnect handles the /connect endpoint
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// if we are already connected, reject new connection requests
s.statusMu.RLock()
alreadyConnected := s.isConnected
s.statusMu.RUnlock()
if alreadyConnected {
http.Error(w, "Already connected to a server. Disconnect first before connecting again.", http.StatusConflict)
return
}
var req ConnectionRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
return
}
// Send the request to the main goroutine
s.connectionChan <- req
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
// handleStatus handles the /status endpoint
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
TunnelIP: s.tunnelIP,
Version: s.version,
OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// handleExit handles the /exit endpoint
func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received exit request via API")
// Send shutdown signal
select {
case s.shutdownChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "shutdown initiated",
})
}
// handleSwitchOrg handles the /switch-org endpoint
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req SwitchOrgRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.OrgID == "" {
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
return
}
logger.Info("Received org switch request to orgId: %s", req.OrgID)
// Send the request to the main goroutine
select {
case s.switchOrgChan <- req:
// Signal sent successfully
default:
// Channel already has a pending request
http.Error(w, "Org switch already in progress", http.StatusConflict)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "org switch request accepted",
})
}
// handleDisconnect handles the /disconnect endpoint
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// if we are already disconnected, reject new disconnect requests
s.statusMu.RLock()
alreadyDisconnected := !s.isConnected
s.statusMu.RUnlock()
if alreadyDisconnected {
http.Error(w, "Not currently connected to a server.", http.StatusConflict)
return
}
logger.Info("Received disconnect request via API")
// Send disconnect signal
select {
case s.disconnectChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}

50
api/api_unix.go Normal file
View File

@@ -0,0 +1,50 @@
//go:build !windows
// +build !windows
package api
import (
"fmt"
"net"
"os"
"path/filepath"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Unix domain socket listener
func createSocketListener(socketPath string) (net.Listener, error) {
// Ensure the directory exists
dir := filepath.Dir(socketPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create socket directory: %w", err)
}
// Remove existing socket file if it exists
if err := os.RemoveAll(socketPath); err != nil {
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
}
// Set socket permissions to allow access
if err := os.Chmod(socketPath, 0666); err != nil {
listener.Close()
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
}
logger.Debug("Created Unix socket at %s", socketPath)
return listener, nil
}
// cleanupSocket removes the Unix socket file
func cleanupSocket(socketPath string) {
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
} else {
logger.Debug("Removed Unix socket at %s", socketPath)
}
}

41
api/api_windows.go Normal file
View File

@@ -0,0 +1,41 @@
//go:build windows
// +build windows
package api
import (
"fmt"
"net"
"github.com/Microsoft/go-winio"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Windows named pipe listener
func createSocketListener(pipePath string) (net.Listener, error) {
// Ensure the pipe path has the correct format
if pipePath[0] != '\\' {
pipePath = `\\.\pipe\` + pipePath
}
// Create a pipe configuration that allows everyone to write
config := &winio.PipeConfig{
// Set security descriptor to allow everyone full access
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
}
// Create a named pipe listener using go-winio with the configuration
listener, err := winio.ListenPipe(pipePath, config)
if err != nil {
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
}
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
return listener, nil
}
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
func cleanupSocket(pipePath string) {
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
}

378
bind/shared_bind.go Normal file
View File

@@ -0,0 +1,378 @@
//go:build !js
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// Endpoint represents a network endpoint for the SharedBind
type Endpoint struct {
AddrPort netip.AddrPort
}
// ClearSrc implements the wgConn.Endpoint interface
func (e *Endpoint) ClearSrc() {}
// DstIP implements the wgConn.Endpoint interface
func (e *Endpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// SrcIP implements the wgConn.Endpoint interface
func (e *Endpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
// DstToBytes implements the wgConn.Endpoint interface
func (e *Endpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
// DstToString implements the wgConn.Endpoint interface
func (e *Endpoint) DstToString() string {
return e.AddrPort.String()
}
// SrcToString implements the wgConn.Endpoint interface
func (e *Endpoint) SrcToString() string {
return ""
}
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
// and hole punch senders. It wraps a single UDP connection and implements
// reference counting to prevent premature closure.
type SharedBind struct {
mu sync.RWMutex
// The underlying UDP connection
udpConn *net.UDPConn
// IPv4 and IPv6 packet connections for advanced features
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
// Reference counting to prevent closing while in use
refCount atomic.Int32
closed atomic.Bool
// Channels for receiving data
recvFuncs []wgConn.ReceiveFunc
// Port binding information
port uint16
}
// New creates a new SharedBind from an existing UDP connection.
// The SharedBind takes ownership of the connection and will close it
// when all references are released.
func New(udpConn *net.UDPConn) (*SharedBind, error) {
if udpConn == nil {
return nil, fmt.Errorf("udpConn cannot be nil")
}
bind := &SharedBind{
udpConn: udpConn,
}
// Initialize reference count to 1 (the creator holds the first reference)
bind.refCount.Store(1)
// Get the local port
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
bind.port = uint16(addr.Port)
}
return bind, nil
}
// AddRef increments the reference count. Call this when sharing
// the bind with another component.
func (b *SharedBind) AddRef() {
newCount := b.refCount.Add(1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
}
// Release decrements the reference count. When it reaches zero,
// the underlying UDP connection is closed.
func (b *SharedBind) Release() error {
newCount := b.refCount.Add(-1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
if newCount < 0 {
// This should never happen with proper usage
b.refCount.Store(0)
return fmt.Errorf("SharedBind reference count went negative")
}
if newCount == 0 {
return b.closeConnection()
}
return nil
}
// closeConnection actually closes the UDP connection
func (b *SharedBind) closeConnection() error {
if !b.closed.CompareAndSwap(false, true) {
// Already closed
return nil
}
b.mu.Lock()
defer b.mu.Unlock()
var err error
if b.udpConn != nil {
err = b.udpConn.Close()
b.udpConn = nil
}
b.ipv4PC = nil
b.ipv6PC = nil
return err
}
// GetUDPConn returns the underlying UDP connection.
// The caller must not close this connection directly.
func (b *SharedBind) GetUDPConn() *net.UDPConn {
b.mu.RLock()
defer b.mu.RUnlock()
return b.udpConn
}
// GetRefCount returns the current reference count (for debugging)
func (b *SharedBind) GetRefCount() int32 {
return b.refCount.Load()
}
// IsClosed returns whether the bind is closed
func (b *SharedBind) IsClosed() bool {
return b.closed.Load()
}
// WriteToUDP writes data to a specific UDP address.
// This is thread-safe and can be used by hole punch senders.
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
return conn.WriteToUDP(data, addr)
}
// Close implements the WireGuard Bind interface.
// It decrements the reference count and closes the connection if no references remain.
func (b *SharedBind) Close() error {
return b.Release()
}
// Open implements the WireGuard Bind interface.
// Since the connection is already open, this just sets up the receive functions.
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
if b.closed.Load() {
return nil, 0, net.ErrClosed
}
b.mu.Lock()
defer b.mu.Unlock()
if b.udpConn == nil {
return nil, 0, net.ErrClosed
}
// Set up IPv4 and IPv6 packet connections for advanced features
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
}
// Create receive functions
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
// Add IPv4 receive function
if b.ipv4PC != nil || runtime.GOOS != "linux" {
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
}
// Add IPv6 receive function if needed
// For now, we focus on IPv4 for hole punching use case
b.recvFuncs = recvFuncs
return recvFuncs, b.port, nil
}
// makeReceiveIPv4 creates a receive function for IPv4 packets
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
pc := b.ipv4PC
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
// Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
}
// Fallback to simple read for other platforms
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
}
}
// receiveIPv4Batch uses batch reading for better performance on Linux
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
// Create messages for batch reading
msgs := make([]ipv4.Message, len(bufs))
for i := range bufs {
msgs[i].Buffers = [][]byte{bufs[i]}
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
}
numMsgs, err := pc.ReadBatch(msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
sizes[i] = msgs[i].N
if sizes[i] == 0 {
continue
}
if msgs[i].Addr != nil {
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
addrPort := udpAddr.AddrPort()
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
}
}
return numMsgs, nil
}
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
n, addr, err := conn.ReadFromUDP(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
if addr != nil {
addrPort := addr.AddrPort()
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
return 1, nil
}
// Send implements the WireGuard Bind interface.
// It sends packets to the specified endpoint.
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
if b.closed.Load() {
return net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return net.ErrClosed
}
// Extract the destination address from the endpoint
var destAddr *net.UDPAddr
// Try to cast to StdNetEndpoint first
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
} else {
// Fallback: construct from DstIP and DstToBytes
dstBytes := ep.DstToBytes()
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
var addr netip.Addr
var port uint16
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
addr, _ = netip.AddrFromSlice(dstBytes[:16])
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
} else { // IPv4
addr, _ = netip.AddrFromSlice(dstBytes[:4])
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
}
if addr.IsValid() {
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
}
}
}
if destAddr == nil {
return fmt.Errorf("could not extract destination address from endpoint")
}
// Send all buffers to the destination
for _, buf := range bufs {
_, err := conn.WriteToUDP(buf, destAddr)
if err != nil {
return err
}
}
return nil
}
// SetMark implements the WireGuard Bind interface.
// It's a no-op for this implementation.
func (b *SharedBind) SetMark(mark uint32) error {
// Not implemented for this use case
return nil
}
// BatchSize returns the preferred batch size for sending packets.
func (b *SharedBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return wgConn.IdealBatchSize
}
return 1
}
// ParseEndpoint creates a new endpoint from a string address.
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
addrPort, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
}

424
bind/shared_bind_test.go Normal file
View File

@@ -0,0 +1,424 @@
//go:build !js
package bind
import (
"net"
"net/netip"
"sync"
"testing"
"time"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// TestSharedBindCreation tests basic creation and initialization
func TestSharedBindCreation(t *testing.T) {
// Create a UDP connection
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
defer udpConn.Close()
// Create SharedBind
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
if bind == nil {
t.Fatal("SharedBind is nil")
}
// Verify initial reference count
if bind.refCount.Load() != 1 {
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
}
// Clean up
if err := bind.Close(); err != nil {
t.Errorf("Failed to close SharedBind: %v", err)
}
}
// TestSharedBindReferenceCount tests reference counting
func TestSharedBindReferenceCount(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add references
bind.AddRef()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
}
bind.AddRef()
if bind.refCount.Load() != 3 {
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
}
// Release references
bind.Release()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
}
bind.Release()
bind.Release() // This should close the connection
if !bind.closed.Load() {
t.Error("Expected bind to be closed after all references released")
}
}
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
func TestSharedBindWriteToUDP(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Send data
testData := []byte("Hello, SharedBind!")
n, err := senderBind.WriteToUDP(testData, receiverAddr)
if err != nil {
t.Fatalf("WriteToUDP failed: %v", err)
}
if n != len(testData) {
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err = receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindConcurrentWrites tests thread-safety
func TestSharedBindConcurrentWrites(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Launch concurrent writes
numGoroutines := 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
data := []byte{byte(id)}
_, err := senderBind.WriteToUDP(data, receiverAddr)
if err != nil {
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
}
}(i)
}
wg.Wait()
}
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
func TestSharedBindWireGuardInterface(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
// Test Open
recvFuncs, port, err := bind.Open(0)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
if len(recvFuncs) == 0 {
t.Error("Expected at least one receive function")
}
if port == 0 {
t.Error("Expected non-zero port")
}
// Test SetMark (should be a no-op)
if err := bind.SetMark(0); err != nil {
t.Errorf("SetMark failed: %v", err)
}
// Test BatchSize
batchSize := bind.BatchSize()
if batchSize <= 0 {
t.Error("Expected positive batch size")
}
}
// TestSharedBindSend tests the Send method with WireGuard endpoints
func TestSharedBindSend(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Create an endpoint
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
// Send data
testData := []byte("WireGuard packet")
bufs := [][]byte{testData}
err = senderBind.Send(bufs, endpoint)
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err := receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
func TestSharedBindMultipleUsers(t *testing.T) {
// Create shared bind
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
sharedBind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add reference for hole punch sender
sharedBind.AddRef()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
var wg sync.WaitGroup
// Simulate WireGuard using the bind
wg.Add(1)
go func() {
defer wg.Done()
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
for i := 0; i < 10; i++ {
data := []byte("WireGuard packet")
bufs := [][]byte{data}
if err := sharedBind.Send(bufs, endpoint); err != nil {
t.Errorf("WireGuard Send failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
// Simulate hole punch sender using the bind
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
data := []byte("Hole punch packet")
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
t.Errorf("Hole punch WriteToUDP failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
wg.Wait()
// Release the hole punch reference
sharedBind.Release()
// Close WireGuard's reference (should close the connection)
sharedBind.Close()
if !sharedBind.closed.Load() {
t.Error("Expected bind to be closed after all users released it")
}
}
// TestEndpoint tests the Endpoint implementation
func TestEndpoint(t *testing.T) {
addr := netip.MustParseAddr("192.168.1.1")
addrPort := netip.AddrPortFrom(addr, 51820)
ep := &Endpoint{AddrPort: addrPort}
// Test DstIP
if ep.DstIP() != addr {
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
}
// Test DstToString
expected := "192.168.1.1:51820"
if ep.DstToString() != expected {
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
}
// Test DstToBytes
bytes := ep.DstToBytes()
if len(bytes) == 0 {
t.Error("Expected DstToBytes to return non-empty slice")
}
// Test SrcIP (should be zero)
if ep.SrcIP().IsValid() {
t.Error("Expected SrcIP to be invalid")
}
// Test ClearSrc (should not panic)
ep.ClearSrc()
}
// TestParseEndpoint tests the ParseEndpoint method
func TestParseEndpoint(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
tests := []struct {
name string
input string
wantErr bool
checkAddr func(*testing.T, wgConn.Endpoint)
}{
{
name: "valid IPv4",
input: "192.168.1.1:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "192.168.1.1:51820" {
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
}
},
},
{
name: "valid IPv6",
input: "[::1]:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "[::1]:51820" {
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
}
},
},
{
name: "invalid - missing port",
input: "192.168.1.1",
wantErr: true,
},
{
name: "invalid - bad format",
input: "not-an-address",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ep, err := bind.ParseEndpoint(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && tt.checkAddr != nil {
tt.checkAddr(t, ep)
}
})
}
}

562
config.go Normal file
View File

@@ -0,0 +1,562 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"time"
)
// OlmConfig holds all configuration options for the Olm client
type OlmConfig struct {
// Connection settings
Endpoint string `json:"endpoint"`
ID string `json:"id"`
Secret string `json:"secret"`
OrgID string `json:"org"`
UserToken string `json:"userToken"`
// Network settings
MTU int `json:"mtu"`
DNS string `json:"dns"`
InterfaceName string `json:"interface"`
// Logging
LogLevel string `json:"logLevel"`
// HTTP server
EnableAPI bool `json:"enableApi"`
HTTPAddr string `json:"httpAddr"`
SocketPath string `json:"socketPath"`
// Ping settings
PingInterval string `json:"pingInterval"`
PingTimeout string `json:"pingTimeout"`
// Advanced
Holepunch bool `json:"holepunch"`
TlsClientCert string `json:"tlsClientCert"`
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
// Parsed values (not in JSON)
PingIntervalDuration time.Duration `json:"-"`
PingTimeoutDuration time.Duration `json:"-"`
// Source tracking (not in JSON)
sources map[string]string `json:"-"`
Version string
}
// ConfigSource tracks where each config value came from
type ConfigSource string
const (
SourceDefault ConfigSource = "default"
SourceFile ConfigSource = "file"
SourceEnv ConfigSource = "environment"
SourceCLI ConfigSource = "cli"
)
// DefaultConfig returns a config with default values
func DefaultConfig() *OlmConfig {
// Set OS-specific socket path
var socketPath string
switch runtime.GOOS {
case "windows":
socketPath = "olm"
default: // darwin, linux, and others
socketPath = "/var/run/olm.sock"
}
config := &OlmConfig{
MTU: 1280,
DNS: "8.8.8.8",
LogLevel: "INFO",
InterfaceName: "olm",
EnableAPI: false,
SocketPath: socketPath,
PingInterval: "3s",
PingTimeout: "5s",
Holepunch: false,
// DoNotCreateNewClient: false,
sources: make(map[string]string),
}
// Track default sources
config.sources["mtu"] = string(SourceDefault)
config.sources["dns"] = string(SourceDefault)
config.sources["logLevel"] = string(SourceDefault)
config.sources["interface"] = string(SourceDefault)
config.sources["enableApi"] = string(SourceDefault)
config.sources["httpAddr"] = string(SourceDefault)
config.sources["socketPath"] = string(SourceDefault)
config.sources["pingInterval"] = string(SourceDefault)
config.sources["pingTimeout"] = string(SourceDefault)
config.sources["holepunch"] = string(SourceDefault)
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
return config
}
// getOlmConfigPath returns the path to the olm config file
func getOlmConfigPath() string {
configFile := os.Getenv("CONFIG_FILE")
if configFile != "" {
return configFile
}
var configDir string
switch runtime.GOOS {
case "darwin":
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
case "windows":
configDir = filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "olm-client")
default: // linux and others
configDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
}
if err := os.MkdirAll(configDir, 0755); err != nil {
fmt.Printf("Warning: Failed to create config directory: %v\n", err)
}
return filepath.Join(configDir, "config.json")
}
// LoadConfig loads configuration from file, env vars, and CLI args
// Priority: CLI args > Env vars > Config file > Defaults
// Returns: (config, showVersion, showConfig, error)
func LoadConfig(args []string) (*OlmConfig, bool, bool, error) {
// Start with defaults
config := DefaultConfig()
// Load from config file (if exists)
fileConfig, err := loadConfigFromFile()
if err != nil {
return nil, false, false, fmt.Errorf("failed to load config file: %w", err)
}
if fileConfig != nil {
mergeConfigs(config, fileConfig)
}
// Override with environment variables
loadConfigFromEnv(config)
// Override with CLI arguments
showVersion, showConfig, err := loadConfigFromCLI(config, args)
if err != nil {
return nil, false, false, err
}
// Parse duration strings
if err := config.parseDurations(); err != nil {
return nil, false, false, err
}
return config, showVersion, showConfig, nil
}
// loadConfigFromFile loads configuration from the JSON config file
func loadConfigFromFile() (*OlmConfig, error) {
configPath := getOlmConfigPath()
data, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // File doesn't exist, not an error
}
return nil, err
}
var config OlmConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
return &config, nil
}
// loadConfigFromEnv loads configuration from environment variables
func loadConfigFromEnv(config *OlmConfig) {
if val := os.Getenv("PANGOLIN_ENDPOINT"); val != "" {
config.Endpoint = val
config.sources["endpoint"] = string(SourceEnv)
}
if val := os.Getenv("OLM_ID"); val != "" {
config.ID = val
config.sources["id"] = string(SourceEnv)
}
if val := os.Getenv("OLM_SECRET"); val != "" {
config.Secret = val
config.sources["secret"] = string(SourceEnv)
}
if val := os.Getenv("ORG"); val != "" {
config.OrgID = val
config.sources["org"] = string(SourceEnv)
}
if val := os.Getenv("USER_TOKEN"); val != "" {
config.UserToken = val
config.sources["userToken"] = string(SourceEnv)
}
if val := os.Getenv("MTU"); val != "" {
if mtu, err := strconv.Atoi(val); err == nil {
config.MTU = mtu
config.sources["mtu"] = string(SourceEnv)
} else {
fmt.Printf("Invalid MTU value: %s, keeping current value\n", val)
}
}
if val := os.Getenv("DNS"); val != "" {
config.DNS = val
config.sources["dns"] = string(SourceEnv)
}
if val := os.Getenv("LOG_LEVEL"); val != "" {
config.LogLevel = val
config.sources["logLevel"] = string(SourceEnv)
}
if val := os.Getenv("INTERFACE"); val != "" {
config.InterfaceName = val
config.sources["interface"] = string(SourceEnv)
}
if val := os.Getenv("HTTP_ADDR"); val != "" {
config.HTTPAddr = val
config.sources["httpAddr"] = string(SourceEnv)
}
if val := os.Getenv("PING_INTERVAL"); val != "" {
config.PingInterval = val
config.sources["pingInterval"] = string(SourceEnv)
}
if val := os.Getenv("PING_TIMEOUT"); val != "" {
config.PingTimeout = val
config.sources["pingTimeout"] = string(SourceEnv)
}
if val := os.Getenv("ENABLE_API"); val == "true" {
config.EnableAPI = true
config.sources["enableApi"] = string(SourceEnv)
}
if val := os.Getenv("SOCKET_PATH"); val != "" {
config.SocketPath = val
config.sources["socketPath"] = string(SourceEnv)
}
if val := os.Getenv("HOLEPUNCH"); val == "true" {
config.Holepunch = true
config.sources["holepunch"] = string(SourceEnv)
}
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
// config.DoNotCreateNewClient = true
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
// }
}
// loadConfigFromCLI loads configuration from command-line arguments
func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
// Store original values to detect changes
origValues := map[string]interface{}{
"endpoint": config.Endpoint,
"id": config.ID,
"secret": config.Secret,
"org": config.OrgID,
"userToken": config.UserToken,
"mtu": config.MTU,
"dns": config.DNS,
"logLevel": config.LogLevel,
"interface": config.InterfaceName,
"httpAddr": config.HTTPAddr,
"socketPath": config.SocketPath,
"pingInterval": config.PingInterval,
"pingTimeout": config.PingTimeout,
"enableApi": config.EnableAPI,
"holepunch": config.Holepunch,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}
// Define flags
serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server")
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID")
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
version := serviceFlags.Bool("version", false, "Print the version")
showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit")
// Parse the arguments
if err := serviceFlags.Parse(args); err != nil {
return false, false, err
}
// Track which values were changed by CLI args
if config.Endpoint != origValues["endpoint"].(string) {
config.sources["endpoint"] = string(SourceCLI)
}
if config.ID != origValues["id"].(string) {
config.sources["id"] = string(SourceCLI)
}
if config.Secret != origValues["secret"].(string) {
config.sources["secret"] = string(SourceCLI)
}
if config.OrgID != origValues["org"].(string) {
config.sources["org"] = string(SourceCLI)
}
if config.UserToken != origValues["userToken"].(string) {
config.sources["userToken"] = string(SourceCLI)
}
if config.MTU != origValues["mtu"].(int) {
config.sources["mtu"] = string(SourceCLI)
}
if config.DNS != origValues["dns"].(string) {
config.sources["dns"] = string(SourceCLI)
}
if config.LogLevel != origValues["logLevel"].(string) {
config.sources["logLevel"] = string(SourceCLI)
}
if config.InterfaceName != origValues["interface"].(string) {
config.sources["interface"] = string(SourceCLI)
}
if config.HTTPAddr != origValues["httpAddr"].(string) {
config.sources["httpAddr"] = string(SourceCLI)
}
if config.SocketPath != origValues["socketPath"].(string) {
config.sources["socketPath"] = string(SourceCLI)
}
if config.PingInterval != origValues["pingInterval"].(string) {
config.sources["pingInterval"] = string(SourceCLI)
}
if config.PingTimeout != origValues["pingTimeout"].(string) {
config.sources["pingTimeout"] = string(SourceCLI)
}
if config.EnableAPI != origValues["enableApi"].(bool) {
config.sources["enableApi"] = string(SourceCLI)
}
if config.Holepunch != origValues["holepunch"].(bool) {
config.sources["holepunch"] = string(SourceCLI)
}
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
// }
return *version, *showConfig, nil
}
// parseDurations parses the duration strings into time.Duration
func (c *OlmConfig) parseDurations() error {
var err error
// Parse ping interval
if c.PingInterval != "" {
c.PingIntervalDuration, err = time.ParseDuration(c.PingInterval)
if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", c.PingInterval)
c.PingIntervalDuration = 3 * time.Second
c.PingInterval = "3s"
}
} else {
c.PingIntervalDuration = 3 * time.Second
c.PingInterval = "3s"
}
// Parse ping timeout
if c.PingTimeout != "" {
c.PingTimeoutDuration, err = time.ParseDuration(c.PingTimeout)
if err != nil {
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", c.PingTimeout)
c.PingTimeoutDuration = 5 * time.Second
c.PingTimeout = "5s"
}
} else {
c.PingTimeoutDuration = 5 * time.Second
c.PingTimeout = "5s"
}
return nil
}
// mergeConfigs merges source config into destination (only non-empty values)
// Also tracks that these values came from a file
func mergeConfigs(dest, src *OlmConfig) {
if src.Endpoint != "" {
dest.Endpoint = src.Endpoint
dest.sources["endpoint"] = string(SourceFile)
}
if src.ID != "" {
dest.ID = src.ID
dest.sources["id"] = string(SourceFile)
}
if src.Secret != "" {
dest.Secret = src.Secret
dest.sources["secret"] = string(SourceFile)
}
if src.OrgID != "" {
dest.OrgID = src.OrgID
dest.sources["org"] = string(SourceFile)
}
if src.UserToken != "" {
dest.UserToken = src.UserToken
dest.sources["userToken"] = string(SourceFile)
}
if src.MTU != 0 && src.MTU != 1280 {
dest.MTU = src.MTU
dest.sources["mtu"] = string(SourceFile)
}
if src.DNS != "" && src.DNS != "8.8.8.8" {
dest.DNS = src.DNS
dest.sources["dns"] = string(SourceFile)
}
if src.LogLevel != "" && src.LogLevel != "INFO" {
dest.LogLevel = src.LogLevel
dest.sources["logLevel"] = string(SourceFile)
}
if src.InterfaceName != "" && src.InterfaceName != "olm" {
dest.InterfaceName = src.InterfaceName
dest.sources["interface"] = string(SourceFile)
}
if src.HTTPAddr != "" && src.HTTPAddr != ":9452" {
dest.HTTPAddr = src.HTTPAddr
dest.sources["httpAddr"] = string(SourceFile)
}
if src.SocketPath != "" {
// Check if it's not the default for any OS
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
if !isDefault {
dest.SocketPath = src.SocketPath
dest.sources["socketPath"] = string(SourceFile)
}
}
if src.PingInterval != "" && src.PingInterval != "3s" {
dest.PingInterval = src.PingInterval
dest.sources["pingInterval"] = string(SourceFile)
}
if src.PingTimeout != "" && src.PingTimeout != "5s" {
dest.PingTimeout = src.PingTimeout
dest.sources["pingTimeout"] = string(SourceFile)
}
if src.TlsClientCert != "" {
dest.TlsClientCert = src.TlsClientCert
dest.sources["tlsClientCert"] = string(SourceFile)
}
// For booleans, we always take the source value if explicitly set
if src.EnableAPI {
dest.EnableAPI = src.EnableAPI
dest.sources["enableApi"] = string(SourceFile)
}
if src.Holepunch {
dest.Holepunch = src.Holepunch
dest.sources["holepunch"] = string(SourceFile)
}
// if src.DoNotCreateNewClient {
// dest.DoNotCreateNewClient = src.DoNotCreateNewClient
// dest.sources["doNotCreateNewClient"] = string(SourceFile)
// }
}
// SaveConfig saves the current configuration to the config file
func SaveConfig(config *OlmConfig) error {
configPath := getOlmConfigPath()
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
return os.WriteFile(configPath, data, 0644)
}
// ShowConfig prints the configuration and the source of each value
func (c *OlmConfig) ShowConfig() {
configPath := getOlmConfigPath()
fmt.Println("\n=== Olm Configuration ===\n")
fmt.Printf("Config File: %s\n", configPath)
// Check if config file exists
if _, err := os.Stat(configPath); err == nil {
fmt.Printf("Config File Status: ✓ exists\n")
} else {
fmt.Printf("Config File Status: ✗ not found\n")
}
fmt.Println("\n--- Configuration Values ---")
fmt.Println("(Format: Setting = Value [source])\n")
// Helper to get source or default
getSource := func(key string) string {
if source, ok := c.sources[key]; ok {
return source
}
return string(SourceDefault)
}
// Helper to format value (mask secrets)
formatValue := func(key, value string) string {
if key == "secret" && value != "" {
if len(value) > 8 {
return value[:4] + "****" + value[len(value)-4:]
}
return "****"
}
if value == "" {
return "(not set)"
}
return value
}
// Connection settings
fmt.Println("Connection:")
fmt.Printf(" endpoint = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint"))
fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org"))
fmt.Printf(" user-token = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken"))
// Network settings
fmt.Println("\nNetwork:")
fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu"))
fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns"))
fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface"))
// Logging
fmt.Println("\nLogging:")
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
// API server
fmt.Println("\nAPI Server:")
fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
// Timing
fmt.Println("\nTiming:")
fmt.Printf(" ping-interval = %s [%s]\n", c.PingInterval, getSource("pingInterval"))
fmt.Printf(" ping-timeout = %s [%s]\n", c.PingTimeout, getSource("pingTimeout"))
// Advanced
fmt.Println("\nAdvanced:")
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
if c.TlsClientCert != "" {
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
}
// Source legend
fmt.Println("\n--- Source Legend ---")
fmt.Println(" default = Built-in default value")
fmt.Println(" file = Loaded from config file")
fmt.Println(" environment = Set via environment variable")
fmt.Println(" cli = Provided as command-line argument")
fmt.Println("\nPriority: cli > environment > file > default")
fmt.Println()
}

523
diff Normal file
View File

@@ -0,0 +1,523 @@
diff --git a/api/api.go b/api/api.go
index dd07751..0d2e4ef 100644
--- a/api/api.go
+++ b/api/api.go
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
Endpoint string `json:"endpoint"`
}
+// SwitchOrgRequest defines the structure for switching organizations
+type SwitchOrgRequest struct {
+ OrgID string `json:"orgId"`
+}
+
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
@@ -35,6 +40,7 @@ type StatusResponse struct {
Registered bool `json:"registered"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
+ OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
@@ -46,6 +52,7 @@ type API struct {
server *http.Server
connectionChan chan ConnectionRequest
shutdownChan chan struct{}
+ switchOrgChan chan SwitchOrgRequest
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
@@ -53,6 +60,7 @@ type API struct {
isRegistered bool
tunnelIP string
version string
+ orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -85,6 +95,7 @@ func (s *API) Start() error {
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/exit", s.handleExit)
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
s.server = &http.Server{
Handler: mux,
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
+ return s.switchOrgChan
+}
+
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
s.version = version
}
+// SetOrgID sets the org ID
+func (s *API) SetOrgID(orgID string) {
+ s.statusMu.Lock()
+ defer s.statusMu.Unlock()
+ s.orgID = orgID
+}
+
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
Registered: s.isRegistered,
TunnelIP: s.tunnelIP,
Version: s.version,
+ OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
}
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
"status": "shutdown initiated",
})
}
+
+// handleSwitchOrg handles the /switch-org endpoint
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var req SwitchOrgRequest
+ decoder := json.NewDecoder(r.Body)
+ if err := decoder.Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Validate required fields
+ if req.OrgID == "" {
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
+ return
+ }
+
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
+
+ // Send the request to the main goroutine
+ select {
+ case s.switchOrgChan <- req:
+ // Signal sent successfully
+ default:
+ // Channel already has a signal, don't block
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
+ return
+ }
+
+ // Return a success response
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusAccepted)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "org switch initiated",
+ "orgId": req.OrgID,
+ })
+}
diff --git a/olm/olm.go b/olm/olm.go
index 78080c4..5e292d6 100644
--- a/olm/olm.go
+++ b/olm/olm.go
@@ -58,6 +58,58 @@ type Config struct {
OrgID string
}
+// tunnelState holds all the active tunnel resources that need cleanup
+type tunnelState struct {
+ dev *device.Device
+ tdev tun.Device
+ uapiListener net.Listener
+ peerMonitor *peermonitor.PeerMonitor
+ stopRegister func()
+ connected bool
+}
+
+// teardownTunnel cleans up all tunnel resources
+func teardownTunnel(state *tunnelState) {
+ if state == nil {
+ return
+ }
+
+ logger.Info("Tearing down tunnel...")
+
+ // Stop registration messages
+ if state.stopRegister != nil {
+ state.stopRegister()
+ state.stopRegister = nil
+ }
+
+ // Stop peer monitor
+ if state.peerMonitor != nil {
+ state.peerMonitor.Stop()
+ state.peerMonitor = nil
+ }
+
+ // Close UAPI listener
+ if state.uapiListener != nil {
+ state.uapiListener.Close()
+ state.uapiListener = nil
+ }
+
+ // Close WireGuard device
+ if state.dev != nil {
+ state.dev.Close()
+ state.dev = nil
+ }
+
+ // Close TUN device
+ if state.tdev != nil {
+ state.tdev.Close()
+ state.tdev = nil
+ }
+
+ state.connected = false
+ logger.Info("Tunnel teardown complete")
+}
+
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key
- connected bool
- dev *device.Device
wgData WgData
holePunchData HolePunchData
- uapiListener net.Listener
- tdev tun.Device
+ orgID = config.OrgID
)
+ // Tunnel state that can be torn down and recreated
+ tunnel := &tunnelState{}
+
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
}
apiServer.SetVersion(config.Version)
+ apiServer.SetOrgID(orgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
- if connected {
+ if tunnel.connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
close(stopHolepunch)
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
- if dev != nil {
+ if tunnel.dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
- dev.Close()
+ tunnel.dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
return
}
- tdev, err = func() (tun.Device, error) {
+ tunnel.tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
return
}
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
return
}
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
go func() {
for {
- conn, err := uapiListener.Accept()
+ conn, err := tunnel.uapiListener.Accept()
if err != nil {
return
}
- go dev.IpcHandle(conn)
+ go tunnel.dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
- if err = dev.Up(); err != nil {
+ if err = tunnel.dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetTunnelIP(wgData.TunnelIP)
}
- peerMonitor = peermonitor.NewPeerMonitor(
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if apiServer != nil {
// Find the site config to get endpoint information
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
},
fixKey(privateKey.String()),
olm,
- dev,
+ tunnel.dev,
doHolepunch,
)
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
logger.Info("Configured peer %s", site.PublicKey)
}
- peerMonitor.Start()
+ tunnel.peerMonitor.Start()
if apiServer != nil {
apiServer.SetRegistered(true)
}
- connected = true
+ tunnel.connected = true
logger.Info("WireGuard device created.")
})
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
}
// Update the peer in WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
}
// Add the peer to WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
}
// Remove the peer from WireGuard
- if dev != nil {
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
+ if tunnel.dev != nil {
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetConnectionStatus(true)
}
- if connected {
+ if tunnel.connected {
logger.Debug("Already connected, skipping registration")
return nil
}
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": config.Version,
- "orgId": config.OrgID,
+ "orgId": orgID,
}, 1*time.Second)
go keepSendingPing(olm)
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
}
defer olm.Close()
+ // Listen for org switch requests from the API (after olm is created)
+ if apiServer != nil {
+ go func() {
+ for req := range apiServer.GetSwitchOrgChannel() {
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
+
+ // Update the orgId
+ orgID = req.OrgID
+
+ // Teardown existing tunnel
+ teardownTunnel(tunnel)
+
+ // Reset tunnel state
+ tunnel = &tunnelState{}
+
+ // Stop holepunch
+ select {
+ case <-stopHolepunch:
+ // Channel already closed
+ default:
+ close(stopHolepunch)
+ }
+ stopHolepunch = make(chan struct{})
+
+ // Clear API server state
+ apiServer.SetRegistered(false)
+ apiServer.SetTunnelIP("")
+ apiServer.SetOrgID(orgID)
+
+ // Send new registration message with updated orgId
+ publicKey := privateKey.PublicKey()
+ logger.Info("Sending registration message with new orgId: %s", orgID)
+
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ "publicKey": publicKey.String(),
+ "relay": !doHolepunch,
+ "olmVersion": config.Version,
+ "orgId": orgID,
+ }, 1*time.Second)
+ }
+ }()
+ }
+
select {
case <-ctx.Done():
logger.Info("Context cancelled")
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
close(stopHolepunch)
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
select {
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
close(stopPing)
}
- if peerMonitor != nil {
- peerMonitor.Stop()
- }
-
- if uapiListener != nil {
- uapiListener.Close()
- }
- if dev != nil {
- dev.Close()
- }
+ // Use teardownTunnel to clean up all tunnel resources
+ teardownTunnel(tunnel)
if apiServer != nil {
apiServer.Stop()

40
go.mod
View File

@@ -3,49 +3,21 @@ module github.com/fosrl/olm
go 1.25
require (
github.com/Microsoft/go-winio v0.6.2
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7
github.com/gorilla/websocket v1.5.3
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.42.0
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
golang.org/x/sys v0.36.0
golang.org/x/net v0.45.0
golang.org/x/sys v0.37.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
software.sslmate.com/src/go-pkcs12 v0.6.0
)
require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.4.0+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/gopacket v1.1.19 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
golang.org/x/net v0.44.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect
)

80
go.sum
View File

@@ -1,103 +1,35 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.4.0+incompatible h1:KVC7bz5zJY/4AZe/78BIvCnPsLaC9T/zh72xnlrTTOk=
github.com/docker/docker v28.4.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a h1:bUGN4piHlcqgfdRLrwqiLZZxgcitzBzNDQS1+CHSmJI=
github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a/go.mod h1:PbiPYp1hbL07awrmbqTSTz7lTenieTHN6cIkUVCGD3I=
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o=
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c=
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs=
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=

351
holepunch/holepunch.go Normal file
View File

@@ -0,0 +1,351 @@
package holepunch
import (
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/bind"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DomainResolver is a function type for resolving domains to IP addresses
type DomainResolver func(string) (string, error)
// ExitNode represents a WireGuard exit node for hole punching
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
// Manager handles UDP hole punching operations
type Manager struct {
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
olmID string
token string
domainResolver DomainResolver
}
// NewManager creates a new hole punch manager
func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager {
return &Manager{
sharedBind: sharedBind,
olmID: olmID,
domainResolver: domainResolver,
}
}
// SetToken updates the authentication token used for hole punching
func (m *Manager) SetToken(token string) {
m.mu.Lock()
defer m.mu.Unlock()
m.token = token
}
// IsRunning returns whether hole punching is currently active
func (m *Manager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.running
}
// Stop stops any ongoing hole punch operations
func (m *Manager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
if m.stopChan != nil {
close(m.stopChan)
m.stopChan = nil
}
m.running = false
logger.Info("Hole punch manager stopped")
}
// StartMultipleExitNodes starts hole punching to multiple exit nodes
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
if len(exitNodes) == 0 {
m.mu.Unlock()
logger.Warn("No exit nodes provided for hole punching")
return fmt.Errorf("no exit nodes provided")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
go m.runMultipleExitNodes(exitNodes)
return nil
}
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
go m.runSingleEndpoint(endpoint, serverPubKey)
return nil
}
// runMultipleExitNodes performs hole punching to multiple exit nodes
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for all exit nodes")
}()
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := m.domainResolver(exitNode.Endpoint)
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
// runSingleEndpoint performs hole punching to a single endpoint
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
}()
host, err := m.domainResolver(endpoint)
if err != nil {
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
return
}
// Execute once immediately before starting the loop
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Warn("Failed to send initial hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Debug("Failed to send hole punch: %v", err)
}
}
}
}
// sendHolePunch sends an encrypted hole punch packet using the shared bind
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
m.mu.Lock()
token := m.token
olmID := m.olmID
m.mu.Unlock()
if serverPubKey == "" || token == "" {
return fmt.Errorf("server public key or OLM token is empty")
}
payload := struct {
OlmID string `json:"olmId"`
Token string `json:"token"`
}{
OlmID: olmID,
Token: token,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %w", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
}
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to write to UDP: %w", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}

View File

@@ -1,217 +0,0 @@
package httpserver
import (
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
// ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
}
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
Connected bool `json:"connected"`
RTT time.Duration `json:"rtt"`
LastSeen time.Time `json:"lastSeen"`
Endpoint string `json:"endpoint,omitempty"`
IsRelay bool `json:"isRelay"`
}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Status string `json:"status"`
Connected bool `json:"connected"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
// HTTPServer represents the HTTP server and its state
type HTTPServer struct {
addr string
server *http.Server
connectionChan chan ConnectionRequest
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
tunnelIP string
version string
}
// NewHTTPServer creates a new HTTP server
func NewHTTPServer(addr string) *HTTPServer {
s := &HTTPServer{
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// Start starts the HTTP server
func (s *HTTPServer) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
s.server = &http.Server{
Addr: s.addr,
Handler: mux,
}
logger.Info("Starting HTTP server on %s", s.addr)
go func() {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server error: %v", err)
}
}()
return nil
}
// Stop stops the HTTP server
func (s *HTTPServer) Stop() error {
logger.Info("Stopping HTTP server")
return s.server.Close()
}
// GetConnectionChannel returns the channel for receiving connection requests
func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Connected = connected
status.RTT = rtt
status.LastSeen = time.Now()
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// SetConnectionStatus sets the overall connection status
func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isConnected = isConnected
if isConnected {
s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
}
}
// SetTunnelIP sets the tunnel IP address
func (s *HTTPServer) SetTunnelIP(tunnelIP string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.tunnelIP = tunnelIP
}
// SetVersion sets the olm version
func (s *HTTPServer) SetVersion(version string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.version = version
}
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// handleConnect handles the /connect endpoint
func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req ConnectionRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
return
}
// Send the request to the main goroutine
s.connectionChan <- req
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
// handleStatus handles the /status endpoint
func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{
Connected: s.isConnected,
TunnelIP: s.tunnelIP,
Version: s.version,
PeerStatuses: s.peerStatuses,
}
if s.isConnected {
resp.Status = "connected"
} else {
resp.Status = "disconnected"
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}

922
main.go
View File

@@ -2,58 +2,16 @@ package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"net"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/olm/httpserver"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/wgtester"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/fosrl/olm/olm"
)
// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
return ""
}
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint // Already valid, no change needed
}
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
lastColon := strings.LastIndex(endpoint, ":")
if lastColon > 0 { // Ensure there is a colon and it's not the first character
hostPart := endpoint[:lastColon]
// Check if the host part is a literal IPv6 address
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
// It is! Reformat it with brackets.
portPart := endpoint[lastColon+1:]
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
}
}
// If it's not the specific malformed case, return it as is.
return endpoint
}
func main() {
// Check if we're running as a Windows service
if isWindowsService() {
@@ -195,853 +153,73 @@ func main() {
}
}
// Run in console mode
runOlmMain(context.Background())
}
// Create a context that will be cancelled on interrupt signals
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
func runOlmMain(ctx context.Context) {
// Run in console mode
runOlmMainWithArgs(ctx, os.Args[1:])
}
func runOlmMainWithArgs(ctx context.Context, args []string) {
// Log that we've entered the main function
// fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
// Create a new FlagSet for parsing service arguments
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
var (
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
interfaceName string
enableHTTP bool
httpAddr string
testMode bool // Add this var for the test flag
testTarget string // Add this var for test target
pingInterval time.Duration
pingTimeout time.Duration
doHolepunch bool
connected bool
)
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
id = os.Getenv("OLM_ID")
secret = os.Getenv("OLM_SECRET")
mtu = os.Getenv("MTU")
dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL")
interfaceName = os.Getenv("INTERFACE")
httpAddr = os.Getenv("HTTP_ADDR")
pingIntervalStr := os.Getenv("PING_INTERVAL")
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
enableHTTPEnv := os.Getenv("ENABLE_HTTP")
holepunchEnv := os.Getenv("HOLEPUNCH")
enableHTTP = enableHTTPEnv == "true"
doHolepunch = holepunchEnv == "true"
if endpoint == "" {
serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
}
if id == "" {
serviceFlags.StringVar(&id, "id", "", "Olm ID")
}
if secret == "" {
serviceFlags.StringVar(&secret, "secret", "", "Olm secret")
}
if mtu == "" {
serviceFlags.StringVar(&mtu, "mtu", "1280", "MTU to use")
}
if dns == "" {
serviceFlags.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
}
if logLevel == "" {
serviceFlags.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if interfaceName == "" {
serviceFlags.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface")
}
if httpAddr == "" {
serviceFlags.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')")
}
if pingIntervalStr == "" {
serviceFlags.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
}
if pingTimeoutStr == "" {
serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)")
}
if enableHTTPEnv == "" {
serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests")
}
if holepunchEnv == "" {
serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)")
}
version := serviceFlags.Bool("version", false, "Print the version")
// Parse the service arguments
if err := serviceFlags.Parse(args); err != nil {
fmt.Printf("Error parsing service arguments: %v\n", err)
return
}
// Debug: Print final values after flag parsing
// fmt.Printf("After flag parsing: endpoint='%s', id='%s', secret='%s'\n", endpoint, id, secret)
// Parse ping intervals
if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
pingInterval = 3 * time.Second
}
} else {
pingInterval = 3 * time.Second
}
if pingTimeoutStr != "" {
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
if err != nil {
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
pingTimeout = 5 * time.Second
}
} else {
pingTimeout = 5 * time.Second
}
// Setup Windows event logging if on Windows
if runtime.GOOS == "windows" {
if runtime.GOOS != "windows" {
setupWindowsEventLog()
} else {
// Initialize logger for non-Windows platforms
logger.Init()
}
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
// Load configuration from file, env vars, and CLI args
// Priority: CLI args > Env vars > Config file > Defaults
config, showVersion, showConfig, err := LoadConfig(os.Args[1:])
if err != nil {
fmt.Printf("Failed to load configuration: %v\n", err)
return
}
// Handle --show-config flag
if showConfig {
config.ShowConfig()
os.Exit(0)
}
olmVersion := "version_replaceme"
if *version {
if showVersion {
fmt.Println("Olm version " + olmVersion)
os.Exit(0)
}
logger.Info("Olm version " + olmVersion)
config.Version = olmVersion
if err := SaveConfig(config); err != nil {
logger.Error("Failed to save full olm config: %v", err)
} else {
logger.Info("Olm version " + olmVersion)
logger.Debug("Saved full olm config with all options")
}
if err := updates.CheckForUpdate("fosrl", "olm", olmVersion); err != nil {
logger.Debug("Failed to check for updates: %v", err)
// Create a new olm.Config struct and copy values from the main config
olmConfig := olm.Config{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,
UserToken: config.UserToken,
MTU: config.MTU,
DNS: config.DNS,
InterfaceName: config.InterfaceName,
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
SocketPath: config.SocketPath,
Holepunch: config.Holepunch,
TlsClientCert: config.TlsClientCert,
PingIntervalDuration: config.PingIntervalDuration,
PingTimeoutDuration: config.PingTimeoutDuration,
Version: config.Version,
OrgID: config.OrgID,
// DoNotCreateNewClient: config.DoNotCreateNewClient,
}
// Log startup information
logger.Debug("Olm service starting...")
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr)
if doHolepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
// Handle test mode
if testMode {
if testTarget == "" {
logger.Fatal("Test mode requires -test-target to be set to a server:port")
}
logger.Info("Running in test mode, connecting to %s", testTarget)
// Create a new tester client
tester, err := wgtester.NewClient(testTarget)
if err != nil {
logger.Fatal("Failed to create tester client: %v", err)
}
defer tester.Close()
// Test connection with a 2-second timeout
connected, rtt := tester.TestConnectionWithTimeout(2 * time.Second)
if connected {
logger.Info("Connection test successful! RTT: %v", rtt)
fmt.Printf("Connection test successful! RTT: %v\n", rtt)
os.Exit(0)
} else {
logger.Error("Connection test failed - no response received")
fmt.Println("Connection test failed - no response received")
os.Exit(1)
}
}
var httpServer *httpserver.HTTPServer
if enableHTTP {
httpServer = httpserver.NewHTTPServer(httpAddr)
httpServer.SetVersion(olmVersion)
if err := httpServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Use a goroutine to handle connection requests
go func() {
for req := range httpServer.GetConnectionChannel() {
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
}
}()
}
// // Check if required parameters are missing and provide helpful guidance
// missingParams := []string{}
// if id == "" {
// missingParams = append(missingParams, "id (use -id flag or OLM_ID env var)")
// }
// if secret == "" {
// missingParams = append(missingParams, "secret (use -secret flag or OLM_SECRET env var)")
// }
// if endpoint == "" {
// missingParams = append(missingParams, "endpoint (use -endpoint flag or PANGOLIN_ENDPOINT env var)")
// }
// if len(missingParams) > 0 {
// logger.Error("Missing required parameters: %v", missingParams)
// logger.Error("Either provide them as command line flags or set as environment variables")
// fmt.Printf("ERROR: Missing required parameters: %v\n", missingParams)
// fmt.Printf("Please provide them as command line flags or set as environment variables\n")
// if !enableHTTP {
// logger.Error("HTTP server is disabled, cannot receive parameters via API")
// fmt.Printf("HTTP server is disabled, cannot receive parameters via API\n")
// return
// }
// }
// Create a new olm
olm, err := websocket.NewClient(
"olm",
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
pingInterval,
pingTimeout,
)
if err != nil {
logger.Fatal("Failed to create olm: %v", err)
}
endpoint = olm.GetConfig().Endpoint // Update endpoint from config
id = olm.GetConfig().ID // Update ID from config
secret = olm.GetConfig().Secret // Update secret from config
// wait until we have a client id and secret and endpoint
waitCount := 0
for id == "" || secret == "" || endpoint == "" {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
return
default:
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
waitCount++
if waitCount%10 == 1 { // Log every 10 seconds instead of every second
logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount)
}
time.Sleep(1 * time.Second)
}
}
// parse the mtu string into an int
mtuInt, err = strconv.Atoi(mtu)
if err != nil {
logger.Fatal("Failed to parse MTU: %v", err)
}
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// Create TUN device and network stack
var dev *device.Device
var wgData WgData
var holePunchData HolePunchData
var uapiListener net.Listener
var tdev tun.Device
sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
fmt.Printf("Error finding available port: %v\n", err)
os.Exit(1)
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})
// Start a single hole punch goroutine for all exit nodes
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort)
})
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
logger.Debug("Received message: %v", msg.Data)
type LegacyHolePunchData struct {
ServerPubKey string `json:"serverPubKey"`
Endpoint string `json:"endpoint"`
}
var legacyHolePunchData LegacyHolePunchData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Stop any existing hole punch goroutines by closing the current channel
select {
case <-stopHolepunch:
// Channel already closed
default:
close(stopHolepunch)
}
// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})
// Start hole punching for each exit node
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
})
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
if connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
close(stopHolepunch)
// wait 10 milliseconds to ensure the previous connection is closed
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
if dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
return nil, err
}
return tun.CreateTUN(interfaceName, mtuInt)
}
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, mtuInt)
}
return tun.CreateTUN(interfaceName, mtuInt)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
fileUAPI, err := func() (*os.File, error) {
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}
return uapiOpen(interfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := uapiListener.Accept()
if err != nil {
return
}
go dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
if err = dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
if httpServer != nil {
httpServer.SetTunnelIP(wgData.TunnelIP)
}
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if httpServer != nil {
// Find the site config to get endpoint information
var endpoint string
var isRelay bool
for _, site := range wgData.Sites {
if site.SiteId == siteID {
endpoint = site.Endpoint
// TODO: We'll need to track relay status separately
// For now, assume not using relay unless we get relay data
isRelay = !doHolepunch
break
}
}
httpServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
}
if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else {
logger.Warn("Peer %d is disconnected", siteID)
}
},
fixKey(privateKey.String()),
olm,
dev,
doHolepunch,
)
for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if httpServer != nil {
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
}
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
peerMonitor.Start()
connected = true
logger.Info("WireGuard device created.")
})
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData UpdatePeerData
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: updateData.SiteId,
Endpoint: updateData.Endpoint,
PublicKey: updateData.PublicKey,
ServerIP: updateData.ServerIP,
ServerPort: updateData.ServerPort,
RemoteSubnets: updateData.RemoteSubnets,
}
// Update the peer in WireGuard
if dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
for _, site := range wgData.Sites {
if site.SiteId == updateData.SiteId {
oldRemoteSubnets = site.RemoteSubnets
oldPublicKey = site.PublicKey
break
}
}
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
}
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// Remove old remote subnet routes if they changed
if oldRemoteSubnets != siteConfig.RemoteSubnets {
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
logger.Error("Failed to remove old remote subnet routes: %v", err)
// Continue anyway to add new routes
}
// Add new remote subnet routes
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add new remote subnet routes: %v", err)
return
}
}
// Update successful
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
for i := range wgData.Sites {
if wgData.Sites[i].SiteId == updateData.SiteId {
wgData.Sites[i] = siteConfig
break
}
}
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for adding a new peer
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addData AddPeerData
if err := json.Unmarshal(jsonData, &addData); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: addData.SiteId,
Endpoint: addData.Endpoint,
PublicKey: addData.PublicKey,
ServerIP: addData.ServerIP,
ServerPort: addData.ServerPort,
RemoteSubnets: addData.RemoteSubnets,
}
// Add the peer to WireGuard
if dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for new peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
// Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId)
// Update WgData with the new peer
wgData.Sites = append(wgData.Sites, siteConfig)
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for removing a peer
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData RemovePeerData
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
// Find the peer to remove
var peerToRemove *SiteConfig
var newSites []SiteConfig
for _, site := range wgData.Sites {
if site.SiteId == removeData.SiteId {
peerToRemove = &site
} else {
newSites = append(newSites, site)
}
}
if peerToRemove == nil {
logger.Error("Peer with site ID %d not found", removeData.SiteId)
return
}
// Remove the peer from WireGuard
if dev != nil {
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
}
// Remove route for the peer
err = removeRouteForServerIP(peerToRemove.ServerIP)
if err != nil {
logger.Error("Failed to remove route for peer: %v", err)
return
}
// Remove routes for remote subnets
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
logger.Error("Failed to remove routes for remote subnets: %v", err)
return
}
// Remove successful
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
// Update WgData to remove the peer
wgData.Sites = newSites
} else {
logger.Error("WireGuard device not initialized")
}
})
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := resolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
if httpServer != nil {
httpServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
logger.Info("Received no-sites message - no sites available for connection")
// if stopRegister != nil {
// stopRegister()
// stopRegister = nil
// }
// select {
// case <-stopHolepunch:
// // Channel already closed, do nothing
// default:
// close(stopHolepunch)
// }
logger.Info("No sites available - stopped registration and holepunch processes")
})
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
olm.Close()
})
olm.OnConnect(func() error {
logger.Info("Websocket Connected")
if httpServer != nil {
httpServer.SetConnectionStatus(true)
}
if connected {
logger.Debug("Already connected, skipping registration")
return nil
}
publicKey := privateKey.PublicKey()
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": olmVersion,
}, 1*time.Second)
go keepSendingPing(olm)
logger.Info("Sent registration message")
return nil
})
olm.OnTokenUpdate(func(token string) {
olmToken = token
})
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)
}
defer olm.Close()
// Wait for interrupt signal or context cancellation
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case <-sigCh:
logger.Info("Received interrupt signal")
case <-ctx.Done():
logger.Info("Context cancelled")
}
select {
case <-stopHolepunch:
// Channel already closed, do nothing
default:
close(stopHolepunch)
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
select {
case <-stopPing:
// Channel already closed
default:
close(stopPing)
}
if uapiListener != nil {
uapiListener.Close()
}
if dev != nil {
dev.Close()
}
logger.Info("runOlmMain() exiting")
fmt.Printf("runOlmMain() exiting\n")
olm.Run(ctx, olmConfig)
}

View File

@@ -4,7 +4,7 @@
#define MyAppName "olm"
#define MyAppVersion "1.0.0"
#define MyAppPublisher "Fossorial Inc."
#define MyAppURL "https://fossorial.io"
#define MyAppURL "https://pangolin.net"
#define MyAppExeName "olm.exe"
[Setup]

View File

@@ -1,9 +1,8 @@
package main
package olm
import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"os/exec"
@@ -14,13 +13,10 @@ import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"github.com/vishvananda/netlink"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -82,11 +78,6 @@ const (
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
)
type fixedPortBind struct {
port uint16
conn.Bind
}
// PeerAction represents a request to add, update, or remove a peer
type PeerAction struct {
Action string `json:"action"` // "add", "update", or "remove"
@@ -124,16 +115,31 @@ type RelayPeerData struct {
PublicKey string `json:"publicKey"`
}
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
// Ignore the requested port and use our fixed port
return b.Bind.Open(b.port)
}
func NewFixedPortBind(port uint16) conn.Bind {
return &fixedPortBind{
port: port,
Bind: conn.NewDefaultBind(),
// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
return ""
}
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint // Already valid, no change needed
}
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
lastColon := strings.LastIndex(endpoint, ":")
if lastColon > 0 { // Ensure there is a colon and it's not the first character
hostPart := endpoint[:lastColon]
// Check if the host part is a literal IPv6 address
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
// It is! Reformat it with brackets.
portPart := endpoint[lastColon+1:]
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
}
}
// If it's not the specific malformed case, return it as is.
return endpoint
}
func fixKey(key string) string {
@@ -182,7 +188,7 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
}
}
func resolveDomain(domain string) (string, error) {
func ResolveDomain(domain string) (string, error) {
// First handle any protocol prefix
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
@@ -229,273 +235,6 @@ func resolveDomain(domain string) (string, error) {
return ipAddr, nil
}
func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
if serverPubKey == "" || olmToken == "" {
return nil
}
payload := struct {
OlmID string `json:"olmId"`
Token string `json:"token"`
}{
OlmID: olmID,
Token: olmToken,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %v", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
}
_, err = conn.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to send UDP packet: %v", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange (replacing deprecated ScalarMult)
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}
func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID string, sourcePort uint16) {
if len(exitNodes) == 0 {
logger.Warn("No exit nodes provided for hole punching")
return
}
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %d exit nodes", len(exitNodes))
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
// Create the UDP connection once and reuse it for all exit nodes
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to bind UDP socket: %v", err)
return
}
defer conn.Close()
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := resolveDomain(exitNode.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch for all exit nodes")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) {
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %s", endpoint)
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
host, err := resolveDomain(endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
// Create the UDP connection once and reuse it
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address: %v", err)
return
}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to bind UDP socket: %v", err)
return
}
defer conn.Close()
// Execute once immediately before starting the loop
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds")
return
case <-ticker.C:
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
}
}
}
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
@@ -535,6 +274,7 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
func sendPing(olm *websocket.Client) error {
err := olm.SendMessage("olm/ping", map[string]interface{}{
"timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
@@ -571,7 +311,7 @@ func keepSendingPing(olm *websocket.Client) {
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
siteHost, err := resolveDomain(siteConfig.Endpoint)
siteHost, err := ResolveDomain(siteConfig.Endpoint)
if err != nil {
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
}
@@ -628,7 +368,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable
primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}

892
olm/olm.go Normal file
View File

@@ -0,0 +1,892 @@
package olm
import (
"context"
"encoding/json"
"net"
"os"
"runtime"
"strconv"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/api"
"github.com/fosrl/olm/bind"
"github.com/fosrl/olm/holepunch"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Config struct {
// Connection settings
Endpoint string
ID string
Secret string
UserToken string
// Network settings
MTU int
DNS string
InterfaceName string
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
// Advanced
Holepunch bool
TlsClientCert string
// Parsed values (not in JSON)
PingIntervalDuration time.Duration
PingTimeoutDuration time.Duration
// Source tracking (not in JSON)
sources map[string]string
Version string
OrgID string
// DoNotCreateNewClient bool
}
var (
privateKey wgtypes.Key
connected bool
dev *device.Device
wgData WgData
holePunchData HolePunchData
uapiListener net.Listener
tdev tun.Device
apiServer *api.API
olmClient *websocket.Client
tunnelCancel context.CancelFunc
tunnelRunning bool
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
)
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel))
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
logger.Debug("Failed to check for updates: %v", err)
}
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
if config.HTTPAddr != "" {
apiServer = api.NewAPI(config.HTTPAddr)
} else if config.SocketPath != "" {
apiServer = api.NewAPISocket(config.SocketPath)
}
apiServer.SetVersion(config.Version)
apiServer.SetOrgID(config.OrgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Listen for shutdown requests from the API
go func() {
<-apiServer.GetShutdownChannel()
logger.Info("Shutdown requested via API")
// Cancel the context to trigger graceful shutdown
cancel()
}()
var (
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
userToken = config.UserToken
)
// Main event loop that handles connect, disconnect, and reconnect
for {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
goto shutdown
case req := <-apiServer.GetConnectionChannel():
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Stop any existing tunnel before starting a new one
if olmClient != nil {
logger.Info("Stopping existing tunnel before starting new connection")
StopTunnel()
}
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
userToken := req.UserToken
// Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" {
logger.Info("Starting tunnel with new credentials")
tunnelRunning = true
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
}
case <-apiServer.GetDisconnectChannel():
logger.Info("Received disconnect request via API")
StopTunnel()
// Clear credentials so we wait for new connect call
id = ""
secret = ""
endpoint = ""
userToken = ""
default:
// If we have credentials and no tunnel is running, start it
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
logger.Info("Starting tunnel process with initial credentials")
tunnelRunning = true
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
} else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled
if !config.EnableAPI {
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
// exit the application because there is no way to provide the missing parameters
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
goto shutdown
}
}
// Sleep briefly to prevent tight loop
time.Sleep(100 * time.Millisecond)
}
}
shutdown:
Stop()
apiServer.Stop()
logger.Info("Olm service shutting down")
}
func TunnelProcess(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
// Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx)
tunnelCancel = cancel
defer func() {
tunnelCancel = nil
}()
// Recreate channels for this tunnel session
stopPing = make(chan struct{})
var (
interfaceName = config.InterfaceName
loggerLevel = parseLogLevel(config.LogLevel)
)
// Create a new olm client using the provided credentials
olm, err := websocket.NewClient(
id, // Use provided ID
secret, // Use provided secret
userToken, // Use provided user token OPTIONAL
endpoint, // Use provided endpoint
config.PingIntervalDuration,
config.PingTimeoutDuration,
)
if err != nil {
logger.Error("Failed to create olm: %v", err)
return
}
// Store the client reference globally
olmClient = olm
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Error("Failed to generate private key: %v", err)
return
}
// Create shared UDP socket for both holepunch and WireGuard
if sharedBind == nil {
sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
logger.Error("Error finding available port: %v", err)
return
}
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
udpConn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to create shared UDP socket: %v", err)
return
}
sharedBind, err = bind.New(udpConn)
if err != nil {
logger.Error("Failed to create shared bind: %v", err)
udpConn.Close()
return
}
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
sharedBind.AddRef()
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
}
// Create the holepunch manager
if holePunchManager == nil {
holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain)
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
for i, node := range holePunchData.ExitNodes {
exitNodes[i] = holepunch.ExitNode{
Endpoint: node.Endpoint,
PublicKey: node.PublicKey,
}
}
// Start hole punching using the manager
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
logger.Debug("Received message: %v", msg.Data)
type LegacyHolePunchData struct {
ServerPubKey string `json:"serverPubKey"`
Endpoint string `json:"endpoint"`
}
var legacyHolePunchData LegacyHolePunchData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Stop any existing hole punch operations
if holePunchManager != nil {
holePunchManager.Stop()
}
// Start hole punching for the exit node
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
if connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
// wait 10 milliseconds to ensure the previous connection is closed
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
if dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
return nil, err
}
return tun.CreateTUN(interfaceName, config.MTU)
}
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, config.MTU)
}
return tun.CreateTUN(interfaceName, config.MTU)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
fileUAPI, err := func() (*os.File, error) {
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}
return uapiOpen(interfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := uapiListener.Accept()
if err != nil {
return
}
go dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
if err = dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
apiServer.SetTunnelIP(wgData.TunnelIP)
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
// Find the site config to get endpoint information
var endpoint string
var isRelay bool
for _, site := range wgData.Sites {
if site.SiteId == siteID {
endpoint = site.Endpoint
// TODO: We'll need to track relay status separately
// For now, assume not using relay unless we get relay data
isRelay = !config.Holepunch
break
}
}
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else {
logger.Warn("Peer %d is disconnected", siteID)
}
},
fixKey(privateKey.String()),
olm,
dev,
config.Holepunch,
)
for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
peerMonitor.Start()
apiServer.SetRegistered(true)
connected = true
logger.Info("WireGuard device created.")
})
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData UpdatePeerData
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: updateData.SiteId,
Endpoint: updateData.Endpoint,
PublicKey: updateData.PublicKey,
ServerIP: updateData.ServerIP,
ServerPort: updateData.ServerPort,
RemoteSubnets: updateData.RemoteSubnets,
}
// Update the peer in WireGuard
if dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
for _, site := range wgData.Sites {
if site.SiteId == updateData.SiteId {
oldRemoteSubnets = site.RemoteSubnets
oldPublicKey = site.PublicKey
break
}
}
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
}
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// Remove old remote subnet routes if they changed
if oldRemoteSubnets != siteConfig.RemoteSubnets {
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
logger.Error("Failed to remove old remote subnet routes: %v", err)
// Continue anyway to add new routes
}
// Add new remote subnet routes
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add new remote subnet routes: %v", err)
return
}
}
// Update successful
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
for i := range wgData.Sites {
if wgData.Sites[i].SiteId == updateData.SiteId {
wgData.Sites[i] = siteConfig
break
}
}
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for adding a new peer
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addData AddPeerData
if err := json.Unmarshal(jsonData, &addData); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: addData.SiteId,
Endpoint: addData.Endpoint,
PublicKey: addData.PublicKey,
ServerIP: addData.ServerIP,
ServerPort: addData.ServerPort,
RemoteSubnets: addData.RemoteSubnets,
}
// Add the peer to WireGuard
if dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for new peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
// Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId)
// Update WgData with the new peer
wgData.Sites = append(wgData.Sites, siteConfig)
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for removing a peer
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData RemovePeerData
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
// Find the peer to remove
var peerToRemove *SiteConfig
var newSites []SiteConfig
for _, site := range wgData.Sites {
if site.SiteId == removeData.SiteId {
peerToRemove = &site
} else {
newSites = append(newSites, site)
}
}
if peerToRemove == nil {
logger.Error("Peer with site ID %d not found", removeData.SiteId)
return
}
// Remove the peer from WireGuard
if dev != nil {
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
}
// Remove route for the peer
err = removeRouteForServerIP(peerToRemove.ServerIP)
if err != nil {
logger.Error("Failed to remove route for peer: %v", err)
return
}
// Remove routes for remote subnets
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
logger.Error("Failed to remove routes for remote subnets: %v", err)
return
}
// Remove successful
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
// Update WgData to remove the peer
wgData.Sites = newSites
} else {
logger.Error("WireGuard device not initialized")
}
})
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
logger.Info("Received no-sites message - no sites available for connection")
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
logger.Info("No sites available - stopped registration and holepunch processes")
})
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
olm.Close()
})
olm.OnConnect(func() error {
logger.Info("Websocket Connected")
apiServer.SetConnectionStatus(true)
if connected {
logger.Debug("Already connected, skipping registration")
return nil
}
publicKey := privateKey.PublicKey()
if stopRegister == nil {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}, 1*time.Second)
}
go keepSendingPing(olm)
return nil
})
olm.OnTokenUpdate(func(token string) {
if holePunchManager != nil {
holePunchManager.SetToken(token)
}
})
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Error("Failed to connect to server: %v", err)
return
}
defer olm.Close()
// Listen for org switch requests from the API
go func() {
for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Update the config with the new orgId
config.OrgID = req.OrgID
// Mark as not connected to trigger re-registration
connected = false
Stop()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
// Wait for context cancellation
<-tunnelCtx.Done()
logger.Info("Tunnel process context cancelled, cleaning up")
}
func Stop() {
// Stop hole punch manager
if holePunchManager != nil {
holePunchManager.Stop()
}
if stopPing != nil {
select {
case <-stopPing:
// Channel already closed
default:
close(stopPing)
}
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
if peerMonitor != nil {
peerMonitor.Stop()
peerMonitor = nil
}
if uapiListener != nil {
uapiListener.Close()
uapiListener = nil
}
if dev != nil {
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
dev = nil
}
// Close TUN device
if tdev != nil {
tdev.Close()
tdev = nil
}
// Release the hole punch reference to the shared bind
if sharedBind != nil {
// Release hole punch reference (WireGuard already released its reference via dev.Close())
logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount())
sharedBind.Release()
sharedBind = nil
logger.Info("Released shared UDP bind")
}
logger.Info("Olm service stopped")
}
// StopTunnel stops just the tunnel process and websocket connection
// without shutting down the entire application
func StopTunnel() {
logger.Info("Stopping tunnel process")
// Cancel the tunnel context if it exists
if tunnelCancel != nil {
tunnelCancel()
// Give it a moment to clean up
time.Sleep(200 * time.Millisecond)
}
// Close the websocket connection
if olmClient != nil {
olmClient.Close()
olmClient = nil
}
Stop()
// Reset the connected state
connected = false
tunnelRunning = false
// Update API server status
apiServer.SetConnectionStatus(false)
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
logger.Info("Tunnel process stopped")
}

View File

@@ -1,6 +1,6 @@
//go:build !windows
package main
package olm
import (
"net"

View File

@@ -1,6 +1,6 @@
//go:build windows
package main
package olm
import (
"errors"

View File

@@ -8,7 +8,7 @@ import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/olm/websocket"
"github.com/fosrl/olm/wgtester"
"golang.zx2c4.com/wireguard/device"
)
@@ -205,11 +205,11 @@ func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
return
}
// Check for IPv6 and format the endpoint correctly
formattedEndpoint := relayEndpoint
if strings.Contains(relayEndpoint, ":") {
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
}
// Check for IPv6 and format the endpoint correctly
formattedEndpoint := relayEndpoint
if strings.Contains(relayEndpoint, ":") {
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
}
// Configure WireGuard to use the relay
wgConfig := fmt.Sprintf(`private_key=%s

634
websocket/client.go Normal file
View File

@@ -0,0 +1,634 @@
package websocket
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"software.sslmate.com/src/go-pkcs12"
"github.com/fosrl/newt/logger"
"github.com/gorilla/websocket"
)
type TokenResponse struct {
Data struct {
Token string `json:"token"`
} `json:"data"`
Success bool `json:"success"`
Message string `json:"message"`
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
}
// this is not json anymore
type Config struct {
ID string
Secret string
Endpoint string
TlsClientCert string // legacy PKCS12 file path
UserToken string // optional user token for websocket authentication
}
type Client struct {
config *Config
conn *websocket.Conn
baseURL string
handlers map[string]MessageHandler
done chan struct{}
handlersMux sync.RWMutex
reconnectInterval time.Duration
isConnected bool
reconnectMux sync.RWMutex
pingInterval time.Duration
pingTimeout time.Duration
onConnect func() error
onTokenUpdate func(token string)
writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig
configNeedsSave bool // Flag to track if config needs to be saved
}
type ClientOption func(*Client)
type MessageHandler func(message WSMessage)
// TLSConfig holds TLS configuration options
type TLSConfig struct {
// New separate certificate support
ClientCertFile string
ClientKeyFile string
CAFiles []string
// Existing PKCS12 support (deprecated)
PKCS12File string
}
// WithBaseURL sets the base URL for the client
func WithBaseURL(url string) ClientOption {
return func(c *Client) {
c.baseURL = url
}
}
// WithTLSConfig sets the TLS configuration for the client
func WithTLSConfig(config TLSConfig) ClientOption {
return func(c *Client) {
c.tlsConfig = config
// For backward compatibility, also set the legacy field
if config.PKCS12File != "" {
c.config.TlsClientCert = config.PKCS12File
}
}
}
func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
func (c *Client) OnTokenUpdate(callback func(token string)) {
c.onTokenUpdate = callback
}
// NewClient creates a new websocket client
func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{
ID: ID,
Secret: secret,
Endpoint: endpoint,
UserToken: userToken,
}
client := &Client{
config: config,
baseURL: endpoint, // default value
handlers: make(map[string]MessageHandler),
done: make(chan struct{}),
reconnectInterval: 3 * time.Second,
isConnected: false,
pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: "olm",
}
// Apply options before loading config
for _, opt := range opts {
if opt == nil {
continue
}
opt(client)
}
return client, nil
}
func (c *Client) GetConfig() *Config {
return c.config
}
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
go c.connectWithRetry()
return nil
}
// Close closes the WebSocket connection gracefully
func (c *Client) Close() error {
// Signal shutdown to all goroutines first
select {
case <-c.done:
// Already closed
return nil
default:
close(c.done)
}
// Set connection status to false
c.setConnected(false)
// Close the WebSocket connection gracefully
if c.conn != nil {
// Send close message
c.writeMux.Lock()
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.writeMux.Unlock()
// Close the connection
return c.conn.Close()
}
return nil
}
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessage(messageType string, data interface{}) error {
if c.conn == nil {
return fmt.Errorf("not connected")
}
msg := WSMessage{
Type: messageType,
Data: data,
}
logger.Debug("Sending message: %s, data: %+v", messageType, data)
c.writeMux.Lock()
defer c.writeMux.Unlock()
return c.conn.WriteJSON(msg)
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
stopChan := make(chan struct{})
go func() {
count := 0
maxAttempts := 10
err := c.SendMessage(messageType, data) // Send immediately
if err != nil {
logger.Error("Failed to send initial message: %v", err)
}
count++
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if count >= maxAttempts {
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
err = c.SendMessage(messageType, data)
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case <-stopChan:
return
}
}
}()
return func() {
close(stopChan)
}
}
// RegisterHandler registers a handler for a specific message type
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlersMux.Lock()
defer c.handlersMux.Unlock()
c.handlers[messageType] = handler
}
func (c *Client) getToken() (string, error) {
// Parse the base URL to ensure we have the correct hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return "", fmt.Errorf("failed to parse base URL: %w", err)
}
// Ensure we have the base URL without trailing slashes
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
var tlsConfig *tls.Config = nil
// Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
tlsConfig, err = c.setupTLS()
if err != nil {
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
}
}
// Check for environment variable to skip TLS verification
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
tlsConfig.InsecureSkipVerify = true
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
var tokenData map[string]interface{}
tokenData = map[string]interface{}{
"olmId": c.config.ID,
"secret": c.config.Secret,
}
jsonData, err := json.Marshal(tokenData)
if err != nil {
return "", fmt.Errorf("failed to marshal token request data: %w", err)
}
// Create a new request
req, err := http.NewRequest(
"POST",
baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
bytes.NewBuffer(jsonData),
)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// Make the request
client := &http.Client{}
if tlsConfig != nil {
client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to request new token: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
}
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.")
return "", fmt.Errorf("failed to decode token response: %w", err)
}
if !tokenResp.Success {
return "", fmt.Errorf("failed to get token: %s", tokenResp.Message)
}
if tokenResp.Data.Token == "" {
return "", fmt.Errorf("received empty token from server")
}
logger.Debug("Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, nil
}
func (c *Client) connectWithRetry() {
for {
select {
case <-c.done:
return
default:
err := c.establishConnection()
if err != nil {
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval)
continue
}
return
}
}
}
func (c *Client) establishConnection() error {
// Get token for authentication
token, err := c.getToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
if c.onTokenUpdate != nil {
c.onTokenUpdate(token)
}
// Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return fmt.Errorf("failed to parse base URL: %w", err)
}
// Determine WebSocket protocol based on HTTP protocol
wsProtocol := "wss"
if baseURL.Scheme == "http" {
wsProtocol = "ws"
}
// Create WebSocket URL
wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
u, err := url.Parse(wsURL)
if err != nil {
return fmt.Errorf("failed to parse WebSocket URL: %w", err)
}
// Add token to query parameters
q := u.Query()
q.Set("token", token)
q.Set("clientType", c.clientType)
if c.config.UserToken != "" {
q.Set("userToken", c.config.UserToken)
}
u.RawQuery = q.Encode()
// Connect to WebSocket
dialer := websocket.DefaultDialer
// Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
logger.Info("Setting up TLS configuration for WebSocket connection")
tlsConfig, err := c.setupTLS()
if err != nil {
return fmt.Errorf("failed to setup TLS configuration: %w", err)
}
dialer.TLSClientConfig = tlsConfig
}
// Check for environment variable to skip TLS verification for WebSocket connection
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
if dialer.TLSClientConfig == nil {
dialer.TLSClientConfig = &tls.Config{}
}
dialer.TLSClientConfig.InsecureSkipVerify = true
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
conn, _, err := dialer.Dial(u.String(), nil)
if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
c.conn = conn
c.setConnected(true)
// Start the ping monitor
go c.pingMonitor()
// Start the read pump with disconnect detection
go c.readPumpWithDisconnectDetection()
if c.onConnect != nil {
if err := c.onConnect(); err != nil {
logger.Error("OnConnect callback failed: %v", err)
}
}
return nil
}
// setupTLS configures TLS based on the TLS configuration
func (c *Client) setupTLS() (*tls.Config, error) {
tlsConfig := &tls.Config{}
// Handle new separate certificate configuration
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
logger.Info("Loading separate certificate files for mTLS")
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
// Load client certificate and key
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate pair: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
// Load CA certificates for remote validation if specified
if len(c.tlsConfig.CAFiles) > 0 {
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
caCertPool := x509.NewCertPool()
for _, caFile := range c.tlsConfig.CAFiles {
caCert, err := os.ReadFile(caFile)
if err != nil {
return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err)
}
// Try to parse as PEM first, then DER
if !caCertPool.AppendCertsFromPEM(caCert) {
// If PEM parsing failed, try DER
cert, err := x509.ParseCertificate(caCert)
if err != nil {
return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err)
}
caCertPool.AddCert(cert)
}
}
tlsConfig.RootCAs = caCertPool
}
return tlsConfig, nil
}
// Fallback to existing PKCS12 implementation for backward compatibility
if c.tlsConfig.PKCS12File != "" {
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
return c.setupPKCS12TLS()
}
// Legacy fallback using config.TlsClientCert
if c.config.TlsClientCert != "" {
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
return loadClientCertificate(c.config.TlsClientCert)
}
return nil, nil
}
// setupPKCS12TLS loads TLS configuration from PKCS12 file
func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
return loadClientCertificate(c.tlsConfig.PKCS12File)
}
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) pingMonitor() {
ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop()
for {
select {
case <-c.done:
return
case <-ticker.C:
if c.conn == nil {
return
}
c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
c.reconnect()
return
}
}
}
}
}
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
func (c *Client) readPumpWithDisconnectDetection() {
defer func() {
if c.conn != nil {
c.conn.Close()
}
// Only attempt reconnect if we're not shutting down
select {
case <-c.done:
// Shutting down, don't reconnect
return
default:
c.reconnect()
}
}()
for {
select {
case <-c.done:
return
default:
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
// Check if we're shutting down before logging error
select {
case <-c.done:
// Expected during shutdown, don't log as error
logger.Debug("WebSocket connection closed during shutdown")
return
default:
// Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("WebSocket read error: %v", err)
} else {
logger.Debug("WebSocket connection closed: %v", err)
}
return // triggers reconnect via defer
}
}
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
handler(msg)
}
c.handlersMux.RUnlock()
}
}
}
func (c *Client) reconnect() {
c.setConnected(false)
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
// Only reconnect if we're not shutting down
select {
case <-c.done:
return
default:
go c.connectWithRetry()
}
}
func (c *Client) setConnected(status bool) {
c.reconnectMux.Lock()
defer c.reconnectMux.Unlock()
c.isConnected = status
}
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path)
if err != nil {
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
}
// Parse PKCS12 with empty password for non-encrypted files
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
if err != nil {
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
}
// Create certificate
cert := tls.Certificate{
Certificate: [][]byte{certificate.Raw},
PrivateKey: privateKey,
}
// Optional: Add CA certificates if present
rootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
}
if len(caCerts) > 0 {
for _, caCert := range caCerts {
rootCAs.AddCert(caCert)
}
}
// Create TLS configuration
return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: rootCAs,
}, nil
}