Merge pull request #87 from fosrl/dev

1.4.0

Former-commit-id: 1212217421
This commit is contained in:
Owen Schwartz
2026-01-23 10:25:03 -08:00
committed by GitHub
29 changed files with 3073 additions and 1348 deletions

API.md

@@ -46,7 +46,18 @@ Initiates a new connection request to a Pangolin server.
"tlsClientCert": "string",
"pingInterval": "3s",
"pingTimeout": "5s",
"orgId": "string"
"orgId": "string",
"fingerprint": {
"username": "string",
"hostname": "string",
"platform": "string",
"osVersion": "string",
"kernelVersion": "string",
"arch": "string",
"deviceModel": "string",
"serialNumber": "string"
},
"postures": {}
}
```
@@ -67,6 +78,16 @@ Initiates a new connection request to a Pangolin server.
- `pingInterval`: Interval for pinging the server (default: 3s)
- `pingTimeout`: Timeout for each ping (default: 5s)
- `orgId`: Organization ID to connect to
- `fingerprint`: Device fingerprinting information (should be set before connecting)
- `username`: Current username on the device
- `hostname`: Device hostname
- `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
- `osVersion`: Operating system version
- `kernelVersion`: Kernel version
- `arch`: System architecture (e.g., amd64, arm64)
- `deviceModel`: Device model identifier
- `serialNumber`: Device serial number
- `postures`: Device posture/security information
**Response:**
- **Status Code:** `202 Accepted`
@@ -205,6 +226,56 @@ Switches to a different organization while maintaining the connection.
---
### PUT /metadata
Updates device fingerprinting and posture information. This endpoint can be called at any time to update metadata, but it's recommended to provide this information in the initial `/connect` request or immediately before connecting.
**Request Body:**
```json
{
"fingerprint": {
"username": "string",
"hostname": "string",
"platform": "string",
"osVersion": "string",
"kernelVersion": "string",
"arch": "string",
"deviceModel": "string",
"serialNumber": "string"
},
"postures": {}
}
```
**Optional Fields:**
- `fingerprint`: Device fingerprinting information
- `username`: Current username on the device
- `hostname`: Device hostname
- `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
- `osVersion`: Operating system version
- `kernelVersion`: Kernel version
- `arch`: System architecture (e.g., amd64, arm64)
- `deviceModel`: Device model identifier
- `serialNumber`: Device serial number
- `postures`: Device posture/security information (object with arbitrary key-value pairs)
**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`
```json
{
"status": "metadata updated"
}
```
**Error Responses:**
- `405 Method Not Allowed` - Non-PUT requests
- `400 Bad Request` - Invalid JSON
**Note:** It's recommended to call this endpoint BEFORE `/connect` to ensure fingerprinting information is available during the initial connection handshake.
---
### POST /exit
Initiates a graceful shutdown of the Olm process.
@@ -247,6 +318,22 @@ Simple health check endpoint to verify the API server is running.
## Usage Examples
### Update metadata before connecting (recommended)
```bash
curl -X PUT http://localhost:9452/metadata \
-H "Content-Type: application/json" \
-d '{
"fingerprint": {
"username": "john",
"hostname": "johns-laptop",
"platform": "macos",
"osVersion": "14.2.1",
"arch": "arm64",
"deviceModel": "MacBookPro18,3"
}
}'
```
### Connect to a peer
```bash
curl -X POST http://localhost:9452/connect \


@@ -5,6 +5,9 @@ all: local
local:
CGO_ENABLED=0 go build -o ./bin/olm
docker-build:
docker build -t fosrl/olm:latest .
docker-build-release:
@if [ -z "$(tag)" ]; then \
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \


@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"sync"
"time"
@@ -32,7 +33,12 @@ type ConnectionRequest struct {
// SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct {
OrgID string `json:"org_id"`
}
// PowerModeRequest represents a request to change power mode
type PowerModeRequest struct {
Mode string `json:"mode"` // "normal" or "low"
}
// PeerStatus represents the status of a peer connection
@@ -48,11 +54,18 @@ type PeerStatus struct {
HolepunchConnected bool `json:"holepunchConnected"`
}
// OlmError holds error information from registration failures
type OlmError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Connected bool `json:"connected"`
Registered bool `json:"registered"`
Terminated bool `json:"terminated"`
OlmError *OlmError `json:"error,omitempty"`
Version string `json:"version,omitempty"`
Agent string `json:"agent,omitempty"`
OrgID string `json:"orgId,omitempty"`
@@ -60,22 +73,34 @@ type StatusResponse struct {
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
}
type MetadataChangeRequest struct {
Fingerprint map[string]any `json:"fingerprint"`
Postures map[string]any `json:"postures"`
}
// API represents the HTTP server and its state
type API struct {
addr string
socketPath string
listener net.Listener
server *http.Server
onConnect func(ConnectionRequest) error
onSwitchOrg func(SwitchOrgRequest) error
onMetadataChange func(MetadataChangeRequest) error
onDisconnect func() error
onExit func() error
onRebind func() error
onPowerMode func(PowerModeRequest) error
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
isRegistered bool
isTerminated bool
olmError *OlmError
version string
agent string
orgID string
@@ -101,28 +126,49 @@ func NewAPISocket(socketPath string) *API {
return s
}
func NewAPIStub() *API {
s := &API{
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// SetHandlers sets the callback functions for handling API requests
func (s *API) SetHandlers(
onConnect func(ConnectionRequest) error,
onSwitchOrg func(SwitchOrgRequest) error,
onMetadataChange func(MetadataChangeRequest) error,
onDisconnect func() error,
onExit func() error,
onRebind func() error,
onPowerMode func(PowerModeRequest) error,
) {
s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg
s.onMetadataChange = onMetadataChange
s.onDisconnect = onDisconnect
s.onExit = onExit
s.onRebind = onRebind
s.onPowerMode = onPowerMode
}
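For illustration, a minimal sketch of wiring these callbacks before starting the server; the socket path and the no-op handler bodies here are hypothetical:

```go
api := NewAPISocket("/var/run/olm.sock") // hypothetical socket path
api.SetHandlers(
	func(req ConnectionRequest) error { return nil },     // onConnect
	func(req SwitchOrgRequest) error { return nil },      // onSwitchOrg
	func(req MetadataChangeRequest) error { return nil }, // onMetadataChange
	func() error { return nil },                          // onDisconnect
	func() error { return nil },                          // onExit
	func() error { return nil },                          // onRebind
	func(req PowerModeRequest) error { return nil },      // onPowerMode
)
if err := api.Start(); err != nil {
	logger.Fatal("Failed to start API server: %v", err)
}
```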
// Start starts the HTTP server
func (s *API) Start() error {
if s.socketPath == "" && s.addr == "" {
return fmt.Errorf("either socketPath or addr must be provided to start the API server")
}
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/metadata", s.handleMetadataChange)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit)
mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode)
s.server = &http.Server{
Handler: mux,
@@ -160,7 +206,7 @@ func (s *API) Stop() error {
// Close the server first, which will also close the listener gracefully
if s.server != nil {
_ = s.server.Close()
}
// Clean up socket file if using Unix socket
@@ -236,6 +282,27 @@ func (s *API) SetRegistered(registered bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isRegistered = registered
// Clear any registration error when successfully registered
if registered {
s.olmError = nil
}
}
// SetOlmError sets the registration error
func (s *API) SetOlmError(code string, message string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.olmError = &OlmError{
Code: code,
Message: message,
}
}
// ClearOlmError clears any registration error
func (s *API) ClearOlmError() {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.olmError = nil
}
func (s *API) SetTerminated(terminated bool) {
@@ -345,7 +412,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
@@ -358,12 +425,12 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
}
s.statusMu.RLock()
resp := StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
Terminated: s.isTerminated,
OlmError: s.olmError,
Version: s.version,
Agent: s.agent,
OrgID: s.orgID,
@@ -371,8 +438,18 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
NetworkSettings: network.GetSettings(),
}
s.statusMu.RUnlock()
data, err := json.Marshal(resp)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}
// handleHealth handles the /health endpoint
@@ -384,7 +461,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "ok",
})
}
@@ -401,7 +478,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
// Return a success response first
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "shutdown initiated",
})
@@ -450,7 +527,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "org switch request accepted",
})
}
@@ -484,16 +561,43 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}
// handleMetadataChange handles the /metadata endpoint
func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req MetadataChangeRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
logger.Info("Received metadata change request via API: %v", req)
_ = s.onMetadataChange(req)
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "metadata updated",
})
}
func (s *API) GetStatus() StatusResponse {
return StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
Terminated: s.isTerminated,
OlmError: s.olmError,
Version: s.version,
Agent: s.agent,
OrgID: s.orgID,
@@ -501,3 +605,74 @@ func (s *API) GetStatus() StatusResponse {
NetworkSettings: network.GetSettings(),
}
}
// handleRebind handles the /rebind endpoint
// This triggers a socket rebind, which is necessary when network connectivity changes
// (e.g., WiFi to cellular transition on macOS/iOS) and the old socket becomes stale.
func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received rebind request via API")
// Call the rebind handler if set
if s.onRebind != nil {
if err := s.onRebind(); err != nil {
http.Error(w, fmt.Sprintf("Rebind failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "Rebind handler not configured", http.StatusNotImplemented)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "socket rebound successfully",
})
}
// handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req PowerModeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate power mode
if req.Mode != "normal" && req.Mode != "low" {
http.Error(w, "Invalid power mode: must be 'normal' or 'low'", http.StatusBadRequest)
return
}
logger.Info("Received power mode change request via API: mode=%s", req.Mode)
// Call the power mode handler if set
if s.onPowerMode != nil {
if err := s.onPowerMode(req); err != nil {
http.Error(w, fmt.Sprintf("Power mode change failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "Power mode handler not configured", http.StatusNotImplemented)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": fmt.Sprintf("power mode changed to %s successfully", req.Mode),
})
}
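As a usage sketch, a client can drive this endpoint with a small JSON body; the address and port below follow the API.md examples and are an assumption about the local setup:

```go
package main

import (
	"log"
	"net/http"
	"strings"
)

func main() {
	// Assumes olm's HTTP API is listening on localhost:9452,
	// as in the API.md usage examples.
	resp, err := http.Post(
		"http://localhost:9452/power-mode",
		"application/json",
		strings.NewReader(`{"mode": "low"}`),
	)
	if err != nil {
		log.Fatalf("power-mode request failed: %v", err)
	}
	defer resp.Body.Close()
	log.Printf("response: %s", resp.Status) // expect 200 OK
}
```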


@@ -1,9 +1,12 @@
package device
import (
"io"
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun"
@@ -18,14 +21,68 @@ type FilterRule struct {
Handler PacketHandler
}
// closeAwareDevice wraps a tun.Device along with a flag
// indicating whether its Close method was called.
type closeAwareDevice struct {
isClosed atomic.Bool
tun.Device
closeEventCh chan struct{}
wg sync.WaitGroup
closeOnce sync.Once
}
func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice {
return &closeAwareDevice{
Device: tunDevice,
isClosed: atomic.Bool{},
closeEventCh: make(chan struct{}),
}
}
// redirectEvents redirects the Events() method of the underlying tun.Device
// to the given channel.
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case ev, ok := <-c.Device.Events():
if !ok {
return
}
if ev == tun.EventDown {
continue
}
select {
case out <- ev:
case <-c.closeEventCh:
return
}
case <-c.closeEventCh:
return
}
}
}()
}
// Close calls the underlying Device's Close method
// after setting isClosed to true.
func (c *closeAwareDevice) Close() (err error) {
c.closeOnce.Do(func() {
c.isClosed.Store(true)
close(c.closeEventCh)
err = c.Device.Close()
c.wg.Wait()
})
return err
}
func (c *closeAwareDevice) IsClosed() bool {
return c.isClosed.Load()
}
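A brief note on the design: Close can be reached both from MiddleDevice.Close and from AddDevice swapping out a stale device, so closeOnce makes it idempotent while wg.Wait guarantees the event-forwarding goroutine has exited. A sketch (someTun is a placeholder tun.Device):

```go
cad := newCloseAwareDevice(someTun) // someTun: placeholder tun.Device
cad.redirectEvents(make(chan tun.Event, 16))
_ = cad.Close()    // closes the device, signals closeEventCh, waits on wg
_ = cad.Close()    // safe no-op: closeOnce already ran
_ = cad.IsClosed() // true
```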
type readResult struct {
@@ -36,59 +93,137 @@ type readResult struct {
err error
}
// MiddleDevice wraps a TUN device with packet filtering capabilities
// and supports swapping the underlying device.
type MiddleDevice struct {
devices []*closeAwareDevice
mu sync.Mutex
cond *sync.Cond
rules []FilterRule
rulesMutex sync.RWMutex
readCh chan readResult
injectCh chan []byte
closed atomic.Bool
events chan tun.Event
}
// NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice {
d := &MiddleDevice{
devices: make([]*closeAwareDevice, 0),
rules: make([]FilterRule, 0),
readCh: make(chan readResult, 16),
injectCh: make(chan []byte, 100),
events: make(chan tun.Event, 16),
}
d.cond = sync.NewCond(&d.mu)
if device != nil {
d.AddDevice(device)
}
return d
}
// AddDevice adds a new underlying TUN device, closing any previous one
func (d *MiddleDevice) AddDevice(device tun.Device) {
d.mu.Lock()
if d.closed.Load() {
d.mu.Unlock()
_ = device.Close()
return
}
var toClose *closeAwareDevice
if len(d.devices) > 0 {
toClose = d.devices[len(d.devices)-1]
}
cad := newCloseAwareDevice(device)
cad.redirectEvents(d.events)
d.devices = []*closeAwareDevice{cad}
// Start pump for the new device
go d.pump(cad)
d.cond.Broadcast()
d.mu.Unlock()
if toClose != nil {
logger.Debug("MiddleDevice: Closing previous device")
if err := toClose.Close(); err != nil {
logger.Debug("MiddleDevice: Error closing previous device: %v", err)
}
}
}
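A minimal sketch of the swap path this enables, e.g. from a rebind handler after a network change; the interface name and MTU are illustrative:

```go
// rebindTUN creates a fresh TUN and swaps it into the MiddleDevice.
// AddDevice closes the stale device, and Read/Write calls blocked in
// waitForDevice pick up the replacement.
func rebindTUN(md *MiddleDevice) error {
	newTun, err := tun.CreateTUN("utun7", 1420) // illustrative name/MTU
	if err != nil {
		return err
	}
	md.AddDevice(newTun)
	return nil
}
```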
func (d *MiddleDevice) pump(dev *closeAwareDevice) {
const defaultOffset = 16
batchSize := dev.BatchSize()
logger.Debug("MiddleDevice: pump started for device")
// Recover from panic if readCh is closed while we're trying to send
defer func() {
if r := recover(); r != nil {
logger.Debug("MiddleDevice: pump recovered from panic (channel closed)")
}
}()
for {
// Check if this device is closed
if dev.IsClosed() {
logger.Debug("MiddleDevice: pump exiting, device is closed")
return
}
// Check if MiddleDevice itself is closed
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed")
return
}
// Allocate buffers for reading
// We allocate new buffers for each read to avoid race conditions
// since we pass them to the channel
bufs := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for i := range bufs {
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
}
n, err := dev.Read(bufs, sizes, defaultOffset)
// Check if device was closed during read
if dev.IsClosed() {
logger.Debug("MiddleDevice: pump exiting, device closed during read")
return
}
// Check if MiddleDevice was closed during read
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read")
return
}
// Try to send the result - check closed state first to avoid sending on closed channel
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, device closed before send")
return
}
// Now try to send the result
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
default:
// Channel full, check if we should exit
if dev.IsClosed() || d.closed.Load() {
return
}
// Try again with blocking
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
case <-dev.closeEventCh:
return
}
}
if err != nil {
logger.Debug("MiddleDevice: pump exiting due to read error: %v", err)
@@ -99,16 +234,28 @@ func (d *MiddleDevice) pump() {
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
func (d *MiddleDevice) InjectOutbound(packet []byte) {
if d.closed.Load() {
return
}
// Use defer/recover to handle panic from sending on closed channel
// This can happen during shutdown race conditions
defer func() {
if r := recover(); r != nil {
logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)")
}
}()
select {
case d.injectCh <- packet:
default:
// Channel full, drop packet
logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full")
}
}
// AddRule adds a packet filtering rule
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
d.rulesMutex.Lock()
defer d.rulesMutex.Unlock()
d.rules = append(d.rules, FilterRule{
DestIP: destIP,
Handler: handler,
@@ -117,8 +264,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
// RemoveRule removes all rules for a given destination IP
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.rulesMutex.Lock()
defer d.rulesMutex.Unlock()
newRules := make([]FilterRule, 0, len(d.rules))
for _, rule := range d.rules {
if rule.DestIP != destIP {
@@ -130,18 +277,120 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
// Close stops the device
func (d *MiddleDevice) Close() error {
if !d.closed.CompareAndSwap(false, true) {
return nil // already closed
}
logger.Debug("MiddleDevice: Closing underlying TUN device")
err := d.Device.Close()
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
return err
d.mu.Lock()
devices := d.devices
d.devices = nil
d.cond.Broadcast()
d.mu.Unlock()
// Close underlying devices first - this causes the pump goroutines to exit
// when their read operations return errors
var lastErr error
logger.Debug("MiddleDevice: Closing %d devices", len(devices))
for _, device := range devices {
if err := device.Close(); err != nil {
logger.Debug("MiddleDevice: Error closing device: %v", err)
lastErr = err
}
}
// Now close channels to unblock any remaining readers
// The pump should have exited by now, but close channels to be safe
close(d.readCh)
close(d.injectCh)
close(d.events)
return lastErr
}
// Events returns the events channel
func (d *MiddleDevice) Events() <-chan tun.Event {
return d.events
}
// File returns the underlying file descriptor
func (d *MiddleDevice) File() *os.File {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return nil
}
continue
}
file := dev.File()
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return file
}
}
// MTU returns the MTU of the underlying device
func (d *MiddleDevice) MTU() (int, error) {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
mtu, err := dev.MTU()
if err == nil {
return mtu, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return 0, err
}
}
// Name returns the name of the underlying device
func (d *MiddleDevice) Name() (string, error) {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return "", io.EOF
}
continue
}
name, err := dev.Name()
if err == nil {
return name, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return "", err
}
}
// BatchSize returns the batch size
func (d *MiddleDevice) BatchSize() int {
dev := d.peekLast()
if dev == nil {
return 1
}
return dev.BatchSize()
}
// extractDestIP extracts destination IP from packet (fast path)
@@ -176,18 +425,34 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
for {
if d.closed.Load() {
logger.Debug("MiddleDevice: Read returning io.EOF, device closed")
return 0, io.EOF
}
// Wait for a device to be available
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
// Now block waiting for data from readCh or injectCh
select {
case res, ok := <-d.readCh:
if !ok {
// Channel closed, device is shutting down
return 0, io.EOF
}
if res.err != nil {
// Check if device was swapped
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
return 0, res.err
}
@@ -195,19 +460,14 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
// Copy packets from result to provided buffers
count := 0
for i := 0; i < res.n && i < len(bufs); i++ {
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
pktData := src[srcOffset : srcOffset+srcSize]
if len(bufs[i]) < offset+len(pktData) {
continue
}
copy(bufs[i][offset:], pktData)
@@ -216,25 +476,26 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
}
n = count
case pkt, ok := <-d.injectCh:
if !ok {
// Channel closed, device is shutting down
return 0, io.EOF
}
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
}
// Apply filtering rules
d.rulesMutex.RLock()
rules := d.rules
d.rulesMutex.RUnlock()
if len(rules) == 0 {
return n, nil
@@ -247,7 +508,6 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
destIP, ok := extractDestIP(packet)
if !ok {
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
@@ -256,12 +516,10 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
continue
}
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
handled = true
break
}
@@ -269,7 +527,6 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
}
if !handled {
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
@@ -278,21 +535,34 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
}
}
return writeIdx, nil
}
}
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
for {
if d.closed.Load() {
return 0, io.EOF
}
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
d.rulesMutex.RLock()
rules := d.rules
d.rulesMutex.RUnlock()
var filteredBufs [][]byte
if len(rules) == 0 {
filteredBufs = bufs
} else {
filteredBufs = make([][]byte, 0, len(bufs))
for _, buf := range bufs {
if len(buf) <= offset {
continue
@@ -301,17 +571,14 @@ func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
packet := buf[offset:]
destIP, ok := extractDestIP(packet)
if !ok {
filteredBufs = append(filteredBufs, buf)
continue
}
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
handled = true
break
}
@@ -322,10 +589,75 @@ func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs = append(filteredBufs, buf)
}
}
if len(filteredBufs) == 0 {
return len(bufs), nil
}
n, err := dev.Write(filteredBufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}
func (d *MiddleDevice) waitForDevice() bool {
d.mu.Lock()
defer d.mu.Unlock()
for len(d.devices) == 0 && !d.closed.Load() {
d.cond.Wait()
}
return !d.closed.Load()
}
func (d *MiddleDevice) peekLast() *closeAwareDevice {
d.mu.Lock()
defer d.mu.Unlock()
if len(d.devices) == 0 {
return nil
}
return d.devices[len(d.devices)-1]
}
// WriteToTun writes packets directly to the underlying TUN device,
// bypassing WireGuard. This is useful for sending packets that should
// appear to come from the TUN interface (e.g., DNS responses from a proxy).
// Unlike Write(), this does not go through packet filtering rules.
func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) {
for {
if d.closed.Load() {
return 0, io.EOF
}
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
n, err := dev.Write(bufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}
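To illustrate the calling convention (a sketch: the 16-byte offset mirrors the defaultOffset used by the pump, and pkt is any pre-built IP packet):

```go
// writeRaw copies one packet into a buffer with the expected headroom
// and hands it straight to the TUN, bypassing WireGuard and the
// filter rules.
func writeRaw(md *MiddleDevice, pkt []byte) error {
	const offset = 16 // headroom the TUN write path expects
	buf := make([]byte, offset+len(pkt))
	copy(buf[offset:], pkt)
	_, err := md.WriteToTun([][]byte{buf}, offset)
	return err
}
```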


@@ -1,4 +1,4 @@
//go:build darwin
package device
@@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
}
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
device, err := tun.CreateTUNFromFile(file, 0)
if err != nil {
file.Close()
return nil, err

device/tun_linux.go (new file)

@@ -0,0 +1,50 @@
//go:build linux
package device
import (
"net"
"os"
"runtime"
"github.com/fosrl/newt/logger"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
if runtime.GOOS == "android" { // otherwise we get a permission denied
theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
return theTun, err
}
dupTunFd, err := unix.Dup(int(tunFd))
if err != nil {
logger.Error("Unable to dup tun fd: %v", err)
return nil, err
}
err = unix.SetNonblock(dupTunFd, true)
if err != nil {
unix.Close(dupTunFd)
return nil, err
}
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
device, err := tun.CreateTUNFromFile(file, mtuInt)
if err != nil {
file.Close()
return nil, err
}
return device, nil
}
func UapiOpen(interfaceName string) (*os.File, error) {
return ipc.UAPIOpen(interfaceName)
}
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, fileUAPI)
}


@@ -12,7 +12,6 @@ import (
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/device"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@@ -36,8 +35,7 @@ type DNSProxy struct {
upstreamDNS []string
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to send them out via the local network
mtu int
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
recordStore *DNSRecordStore // Local DNS records
// Tunnel DNS fields - for sending queries over WireGuard
@@ -53,7 +51,7 @@ type DNSProxy struct {
}
// NewDNSProxy creates a new DNS proxy
func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
proxyIP, err := PickIPFromSubnet(utilitySubnet)
if err != nil {
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
@@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
proxy := &DNSProxy{
proxyIP: proxyIP,
mtu: mtu,
middleDevice: middleDevice,
upstreamDNS: upstreamDns,
tunnelDNS: tunnelDns,
@@ -602,12 +599,12 @@ func (p *DNSProxy) runTunnelPacketSender() {
defer p.wg.Done()
logger.Debug("DNS tunnel packet sender goroutine started")
for {
// Use blocking ReadContext instead of polling - much more CPU efficient
// This will block until a packet is available or context is cancelled
pkt := p.tunnelEp.ReadContext(p.ctx)
if pkt == nil {
// Context was cancelled or endpoint closed
logger.Debug("DNS tunnel packet sender exiting")
// Drain any remaining packets
for {
@@ -618,12 +615,6 @@ func (p *DNSProxy) runTunnelPacketSender() {
pkt.DecRef()
}
return
}
// Extract packet data
@@ -647,8 +638,6 @@ func (p *DNSProxy) runTunnelPacketSender() {
pkt.DecRef()
}
}
// runPacketSender sends packets from netstack back to TUN
@@ -660,18 +649,12 @@ func (p *DNSProxy) runPacketSender() {
const offset = 16
for {
// Use blocking ReadContext instead of polling - much more CPU efficient
// This will block until a packet is available or context is cancelled
pkt := p.ep.ReadContext(p.ctx)
if pkt == nil {
// Context was cancelled or endpoint closed
return
}
// Extract packet data as slices
@@ -694,9 +677,9 @@ func (p *DNSProxy) runPacketSender() {
pos += len(slice)
}
// Write packet to TUN device via MiddleDevice
// offset=16 indicates packet data starts at position 16 in the buffer
_, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
if err != nil {
logger.Error("Failed to write DNS response to TUN: %v", err)
}


@@ -0,0 +1,16 @@
//go:build android
package olm
import "net/netip"
// SetupDNSOverride is a no-op on Android
// Android handles DNS through the VpnService API at the Java/Kotlin layer
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
return nil
}
// RestoreDNSOverride is a no-op on Android
func RestoreDNSOverride() error {
return nil
}


@@ -7,7 +7,6 @@ import (
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
// Uses scutil for DNS configuration
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
var err error
configurator, err = platform.NewDarwinDNSConfigurator()
if err != nil {
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
proxyIp,
}
logger.Info("Setting DNS servers to: %v", newDNS)


@@ -0,0 +1,15 @@
//go:build ios
package olm
import "net/netip"
// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
return nil
}
// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func RestoreDNSOverride() error {
return nil
}


@@ -7,7 +7,6 @@ import (
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
var err error
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
@@ -32,7 +27,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using systemd-resolved DNS configurator")
return setDNS(proxyIp, configurator)
}
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
@@ -40,7 +35,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using NetworkManager DNS configurator")
return setDNS(proxyIp, configurator)
}
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
@@ -48,7 +43,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using resolvconf DNS configurator")
return setDNS(proxyIp, configurator)
}
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
}
@@ -60,11 +55,11 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
}
logger.Info("Using file-based DNS configurator")
return setDNS(proxyIp, configurator)
}
// setDNS is a helper function to set DNS and log the results
func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error {
// Get current DNS servers before changing
currentDNS, err := conf.GetCurrentDNS()
if err != nil {
@@ -75,7 +70,7 @@ func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
proxyIp,
}
logger.Info("Setting DNS servers to: %v", newDNS)


@@ -7,7 +7,6 @@ import (
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
// Uses registry-based configuration (automatically extracts interface GUID)
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
var err error
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
if err != nil {
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
proxyIp,
}
logger.Info("Setting DNS servers to: %v", newDNS)

go.mod

@@ -4,7 +4,7 @@ go 1.25
require (
github.com/Microsoft/go-winio v0.6.2
github.com/fosrl/newt v1.9.0
github.com/godbus/dbus/v5 v5.2.2
github.com/gorilla/websocket v1.5.3
github.com/miekg/dns v1.1.70
@@ -30,3 +30,6 @@ require (
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
)
// To be used ONLY for local development
// replace github.com/fosrl/newt => ../newt

go.sum

@@ -1,7 +1,7 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=

main.go

@@ -10,7 +10,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/olm"
olmpkg "github.com/fosrl/olm/olm"
)
func main() {
@@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
}
// Create a new olm.Config struct and copy values from the main config
olmConfig := olmpkg.OlmConfig{
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
@@ -219,15 +219,20 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
Agent: "Olm CLI",
OnExit: cancel, // Pass cancel function directly to trigger shutdown
OnTerminated: cancel,
PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
}
olm, err := olmpkg.Init(ctx, olmConfig)
if err != nil {
logger.Fatal("Failed to initialize olm: %v", err)
}
if err := olm.StartApi(); err != nil {
logger.Fatal("Failed to start API server: %v", err)
}
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
tunnelConfig := olmpkg.TunnelConfig{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,

olm/connect.go (new file)

@@ -0,0 +1,275 @@
package olm
import (
"encoding/json"
"fmt"
"os"
"runtime"
"strconv"
"strings"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
olmDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
dnsOverride "github.com/fosrl/olm/dns/override"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
)
// OlmErrorData represents the error data sent from the server
type OlmErrorData struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
var wgData WgData
if o.connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if o.stopRegister != nil {
o.stopRegister()
o.stopRegister = nil
}
if o.updateRegister != nil {
o.updateRegister = nil
}
// if there is an existing tunnel then close it
if o.dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
o.dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
o.tdev, err = func() (tun.Device, error) {
if o.tunnelConfig.FileDescriptorTun != 0 {
return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU)
}
ifName := o.tunnelConfig.InterfaceName
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
ifName, err = network.FindUnusedUTUN()
if err != nil {
return nil, err
}
}
return tun.CreateTUN(ifName, o.tunnelConfig.MTU)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
// if config.FileDescriptorTun == 0 {
if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if an interface name was already set, this should be a no-op
o.tunnelConfig.InterfaceName = realInterfaceName
}
// }
// Wrap TUN device with packet filter for DNS proxy
o.middleDev = olmDevice.NewMiddleDevice(o.tdev)
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
// Use filtered device instead of raw TUN device
o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger))
if o.tunnelConfig.EnableUAPI {
fileUAPI, err := func() (*os.File, error) {
if o.tunnelConfig.FileDescriptorUAPI != 0 {
fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
}
return os.NewFile(uintptr(fd), ""), nil
}
return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := o.uapiListener.Accept()
if err != nil {
return
}
go o.dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
}
if err = o.dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
// Extract interface IP (strip CIDR notation if present)
interfaceIP := wgData.TunnelIP
if strings.Contains(interfaceIP, "/") {
interfaceIP = strings.Split(interfaceIP, "/")[0]
}
// Create and start DNS proxy
o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
}
if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil {
logger.Error("Failed to o.tunnelConfigure interface: %v", err)
}
if err := network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet
logger.Error("Failed to add route for utility subnet: %v", err)
}
// Create peer manager with integrated peer monitoring
o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
Device: o.dev,
DNSProxy: o.dnsProxy,
InterfaceName: o.tunnelConfig.InterfaceName,
PrivateKey: o.privateKey,
MiddleDev: o.middleDev,
LocalIP: interfaceIP,
SharedBind: o.sharedBind,
WSClient: o.websocket,
APIServer: o.apiServer,
})
for i := range wgData.Sites {
site := wgData.Sites[i]
var siteEndpoint string
// prefer the relay endpoint when present; it means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
if err := o.peerManager.AddPeer(site); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
o.peerManager.Start()
if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime
logger.Error("Failed to start DNS proxy: %v", err)
}
if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
logger.Error("Failed to setup DNS override: %v", err)
return
}
network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()})
}
o.apiServer.SetRegistered(true)
o.connected = true
// Start ping monitor now that we are registered and connected
o.websocket.StartPingMonitor()
// Invoke onConnected callback if configured
if o.olmConfig.OnConnected != nil {
go o.olmConfig.OnConnected()
}
logger.Info("WireGuard device created.")
}
func (o *Olm) handleOlmError(msg websocket.WSMessage) {
logger.Debug("Received olm error message: %v", msg.Data)
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling olm error data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling olm error data: %v", err)
return
}
logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message)
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
// Invoke onOlmError callback if configured
if o.olmConfig.OnOlmError != nil {
go o.olmConfig.OnOlmError(errorData.Code, errorData.Message)
}
}
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
logger.Info("Received terminate message")
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling terminate error data: %v", err)
} else {
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling terminate error data: %v", err)
} else {
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
}
}
o.apiServer.SetTerminated(true)
o.apiServer.SetConnectionStatus(false)
o.apiServer.SetRegistered(false)
o.apiServer.ClearPeerStatuses()
network.ClearNetworkSettings()
o.Close()
if o.olmConfig.OnTerminated != nil {
go o.olmConfig.OnTerminated()
}
}

olm/data.go (new file)

@@ -0,0 +1,347 @@
package olm
import (
"encoding/json"
"time"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
)
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addSubnetsData peers.PeerAdd
if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil {
logger.Error("Error unmarshaling add-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId)
return
}
// Add new subnets
for _, subnet := range addSubnetsData.RemoteSubnets {
if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
}
}
// Add new aliases
for _, alias := range addSubnetsData.Aliases {
if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil {
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
}
}
}
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeSubnetsData peers.RemovePeerData
if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil {
logger.Error("Error unmarshaling remove-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId)
return
}
// Remove subnets
for _, subnet := range removeSubnetsData.RemoteSubnets {
if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
}
}
// Remove aliases
for _, alias := range removeSubnetsData.Aliases {
if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil {
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
}
}
}
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateSubnetsData peers.UpdatePeerData
if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil {
logger.Error("Error unmarshaling update-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId)
return
}
// Add new subnets BEFORE removing old ones to preserve shared subnets
// This ensures that if an old and new subnet are the same on different peers,
// the route won't be temporarily removed
for _, subnet := range updateSubnetsData.NewRemoteSubnets {
if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
}
}
// Remove old subnets after new ones are added
for _, subnet := range updateSubnetsData.OldRemoteSubnets {
if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
}
}
// Add new aliases BEFORE removing old ones to preserve shared IP addresses
// This ensures that if an old and new alias share the same IP, the IP won't be
// temporarily removed from the allowed IPs list
for _, alias := range updateSubnetsData.NewAliases {
if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil {
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
}
}
// Remove old aliases after new ones are added
for _, alias := range updateSubnetsData.OldAliases {
if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil {
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
}
}
logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
}
// Handler for syncing peer configuration - reconciles expected state with actual state
func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Debug("Received sync message: %v", msg.Data)
if !o.connected {
logger.Warn("Not connected, ignoring sync request")
return
}
if o.peerManager == nil {
logger.Warn("Peer manager not initialized, ignoring sync request")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
var syncData SyncData
if err := json.Unmarshal(jsonData, &syncData); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
// Sync exit nodes for hole punching
o.syncExitNodes(syncData.ExitNodes)
// Build a map of expected peers from the incoming data
expectedPeers := make(map[int]peers.SiteConfig)
for _, site := range syncData.Sites {
expectedPeers[site.SiteId] = site
}
// Get all current peers
currentPeers := o.peerManager.GetAllPeers()
currentPeerMap := make(map[int]peers.SiteConfig)
for _, peer := range currentPeers {
currentPeerMap[peer.SiteId] = peer
}
// Find peers to remove (in current but not in expected)
for siteId := range currentPeerMap {
if _, exists := expectedPeers[siteId]; !exists {
logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId)
if err := o.peerManager.RemovePeer(siteId); err != nil {
logger.Error("Sync: Failed to remove peer %d: %v", siteId, err)
} else {
// Remove any exit nodes associated with this peer from hole punching
if o.holePunchManager != nil {
removed := o.holePunchManager.RemoveExitNodesByPeer(siteId)
if removed > 0 {
logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId)
}
}
}
}
}
// Find peers to add (in expected but not in current) and peers to update
for siteId, expectedSite := range expectedPeers {
if _, exists := currentPeerMap[siteId]; !exists {
// New peer - add it using the add flow (with holepunch)
logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch()
// // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
// logger.Error("Sync: Failed to add peer %d: %v", siteId, err)
// } else {
// logger.Info("Sync: Successfully added peer for site %d", siteId)
// }
// add the peer via the server
// this is important because newt also needs to be triggered to add the peer once the hole punch is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
}, 1*time.Second, 10)
} else {
// Existing peer - check if update is needed
currentSite := currentPeerMap[siteId]
needsUpdate := false
// Check if any fields have changed
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
needsUpdate = true
}
if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint {
needsUpdate = true
}
if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey {
needsUpdate = true
}
if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP {
needsUpdate = true
}
if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort {
needsUpdate = true
}
// Check remote subnets
if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) {
needsUpdate = true
}
// Check aliases
if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) {
needsUpdate = true
}
if needsUpdate {
logger.Info("Sync: Updating peer for site %d", siteId)
// Merge expected data with current data
siteConfig := currentSite
if expectedSite.Endpoint != "" {
siteConfig.Endpoint = expectedSite.Endpoint
}
if expectedSite.RelayEndpoint != "" {
siteConfig.RelayEndpoint = expectedSite.RelayEndpoint
}
if expectedSite.PublicKey != "" {
siteConfig.PublicKey = expectedSite.PublicKey
}
if expectedSite.ServerIP != "" {
siteConfig.ServerIP = expectedSite.ServerIP
}
if expectedSite.ServerPort != 0 {
siteConfig.ServerPort = expectedSite.ServerPort
}
if expectedSite.RemoteSubnets != nil {
siteConfig.RemoteSubnets = expectedSite.RemoteSubnets
}
if expectedSite.Aliases != nil {
siteConfig.Aliases = expectedSite.Aliases
}
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
logger.Error("Sync: Failed to update peer %d: %v", siteId, err)
} else {
// If the endpoint changed, trigger holepunch to refresh NAT mappings
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId)
o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval()
}
logger.Info("Sync: Successfully updated peer for site %d", siteId)
}
}
}
}
logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers))
}
// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager
func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) {
if o.holePunchManager == nil {
logger.Warn("Hole punch manager not initialized, skipping exit node sync")
return
}
// Build a map of expected exit nodes by endpoint
expectedExitNodeMap := make(map[string]SyncExitNode)
for _, exitNode := range expectedExitNodes {
expectedExitNodeMap[exitNode.Endpoint] = exitNode
}
// Get current exit nodes from hole punch manager
currentExitNodes := o.holePunchManager.GetExitNodes()
currentExitNodeMap := make(map[string]holepunch.ExitNode)
for _, exitNode := range currentExitNodes {
currentExitNodeMap[exitNode.Endpoint] = exitNode
}
// Find exit nodes to remove (in current but not in expected)
for endpoint := range currentExitNodeMap {
if _, exists := expectedExitNodeMap[endpoint]; !exists {
logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint)
o.holePunchManager.RemoveExitNode(endpoint)
}
}
// Find exit nodes to add (in expected but not in current)
for endpoint, expectedExitNode := range expectedExitNodeMap {
if _, exists := currentExitNodeMap[endpoint]; !exists {
logger.Info("Sync: Adding new exit node %s", endpoint)
relayPort := expectedExitNode.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
hpExitNode := holepunch.ExitNode{
Endpoint: expectedExitNode.Endpoint,
RelayPort: relayPort,
PublicKey: expectedExitNode.PublicKey,
SiteIds: expectedExitNode.SiteIds,
}
if o.holePunchManager.AddExitNode(hpExitNode) {
logger.Info("Sync: Successfully added exit node %s", endpoint)
}
o.holePunchManager.TriggerHolePunch()
}
}
logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap))
}
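// A minimal sketch of the reconciliation pattern shared by handleSync and
// syncExitNodes above: key both sets by a stable identifier, then compute
// removals (current minus expected) and additions (expected minus current).
// reconcileKeys is a hypothetical generalization for string keys, not part
// of this package.
func reconcileKeys(expected, current map[string]struct{}) (toAdd, toRemove []string) {
for k := range current {
if _, ok := expected[k]; !ok {
toRemove = append(toRemove, k) // present but no longer expected
}
}
for k := range expected {
if _, ok := current[k]; !ok {
toAdd = append(toAdd, k) // expected but not yet present
}
}
return toAdd, toRemove
}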

1286
olm/olm.go

File diff suppressed because it is too large

258
olm/peer.go Normal file
View File

@@ -0,0 +1,258 @@
package olm
import (
"encoding/json"
"time"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
)
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var siteConfig peers.SiteConfig
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger an immediate hole punch attempt so that, if the peer decides to relay, we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
}
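// A minimal sketch of the decode step the handlers in this file share:
// msg.Data arrives as interface{}, so it is re-marshaled to JSON and then
// unmarshaled into the concrete type. decodeInto is a hypothetical generic
// helper (requires Go 1.18+), shown only to illustrate the round-trip.
func decodeInto[T any](data interface{}) (T, error) {
var out T
raw, err := json.Marshal(data) // re-encode the generic value
if err != nil {
return out, err
}
err = json.Unmarshal(raw, &out) // decode into the concrete type
return out, err
}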
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData peers.PeerRemove
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil {
logger.Error("Failed to remove peer: %v", err)
return
}
// Remove any exit nodes associated with this peer from hole punching
if o.holePunchManager != nil {
removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId)
if removed > 0 {
logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId)
}
}
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
}
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData peers.SiteConfig
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Get existing peer from PeerManager
existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId)
if !exists {
logger.Warn("Peer with site ID %d not found", updateData.SiteId)
return
}
// Create updated site config by merging with existing data
siteConfig := existingPeer
if updateData.Endpoint != "" {
siteConfig.Endpoint = updateData.Endpoint
}
if updateData.RelayEndpoint != "" {
siteConfig.RelayEndpoint = updateData.RelayEndpoint
}
if updateData.PublicKey != "" {
siteConfig.PublicKey = updateData.PublicKey
}
if updateData.ServerIP != "" {
siteConfig.ServerIP = updateData.ServerIP
}
if updateData.ServerPort != 0 {
siteConfig.ServerPort = updateData.ServerPort
}
if updateData.RemoteSubnets != nil {
siteConfig.RemoteSubnets = updateData.RemoteSubnets
}
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// If the endpoint changed, trigger holepunch to refresh NAT mappings
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
_ = o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval()
}
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
}
func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if o.peerManager == nil {
logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData peers.RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err)
return
}
// Update HTTP server to mark this peer as using relay
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
}
func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
logger.Debug("Received unrelay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if o.peerManager == nil {
logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData peers.UnRelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as no longer using relay
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
}
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
logger.Debug("Received peer-handshake message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling handshake data: %v", err)
return
}
var handshakeData struct {
SiteId int `json:"siteId"`
ExitNode struct {
PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
} `json:"exitNode"`
}
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
logger.Error("Error unmarshaling handshake data: %v", err)
return
}
// Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists {
logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
return
}
relayPort := handshakeData.ExitNode.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
siteId := handshakeData.SiteId
exitNode := holepunch.ExitNode{
Endpoint: handshakeData.ExitNode.Endpoint,
RelayPort: relayPort,
PublicKey: handshakeData.ExitNode.PublicKey,
SiteIds: []int{siteId},
}
added := o.holePunchManager.AddExitNode(exitNode)
if added {
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
} else {
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
}
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // Start sending immediately again so the endpoint gets filled in on the cloud
// Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
}, 1*time.Second, 10)
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}
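// A minimal sketch of the retry/ack pattern used above, assuming the same Olm
// fields: SendMessageInterval resends once per second for up to 10 attempts
// (a maxAttempts of -1 retries indefinitely), and the returned stop function
// cancels the loop once the server acknowledges. Hypothetical caller:
func exampleRetryUntilAck(o *Olm, siteId int) {
stop, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": siteId,
}, 1*time.Second, 10)
o.stopPeerSend = stop // handleWgPeerAdd invokes this when the ack arrives
}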

View File

@@ -12,9 +12,22 @@ type WgData struct {
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
}
type GlobalConfig struct {
type SyncData struct {
Sites []peers.SiteConfig `json:"sites"`
ExitNodes []SyncExitNode `json:"exitNodes"`
}
type SyncExitNode struct {
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
PublicKey string `json:"publicKey"`
SiteIds []int `json:"siteIds"`
}
type OlmConfig struct {
// Logging
LogLevel string
LogFilePath string
// HTTP server
EnableAPI bool
@@ -23,11 +36,17 @@ type GlobalConfig struct {
Version string
Agent string
WakeUpDebounce time.Duration
// Debugging
PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")
// Callbacks
OnRegistered func()
OnConnected func()
OnTerminated func()
OnAuthError func(statusCode int, message string) // Called when auth fails (401/403)
OnOlmError func(code string, message string) // Called when registration fails
OnExit func() // Called when exit is requested via API
}
@@ -63,5 +82,8 @@ type TunnelConfig struct {
OverrideDNS bool
TunnelDNS bool
InitialFingerprint map[string]any
InitialPostures map[string]any
DisableRelay bool
}

View File

@@ -1,55 +1,47 @@
package olm
import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/olm/websocket"
"github.com/fosrl/olm/peers"
)
func sendPing(olm *websocket.Client) error {
err := olm.SendMessage("olm/ping", map[string]interface{}{
"timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
return err
// slicesEqual compares two string slices for equality (order-independent)
func slicesEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
logger.Debug("Sent ping message")
return nil
// Create a map to count occurrences in slice a
counts := make(map[string]int)
for _, v := range a {
counts[v]++
}
// Check if slice b has the same elements
for _, v := range b {
counts[v]--
if counts[v] < 0 {
return false
}
}
return true
}
func keepSendingPing(olm *websocket.Client) {
// Send ping immediately on startup
if err := sendPing(olm); err != nil {
logger.Error("Failed to send initial ping: %v", err)
} else {
logger.Info("Sent initial ping message")
}
// Set up ticker for one minute intervals
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-stopPing:
logger.Info("Stopping ping messages")
return
case <-ticker.C:
if err := sendPing(olm); err != nil {
logger.Error("Failed to send periodic ping: %v", err)
// aliasesEqual compares two Alias slices for equality (order-independent)
func aliasesEqual(a, b []peers.Alias) bool {
if len(a) != len(b) {
return false
}
// Create a map to count occurrences in slice a (using alias+address as key)
counts := make(map[string]int)
for _, v := range a {
key := v.Alias + "|" + v.AliasAddress
counts[key]++
}
// Check if slice b has the same elements
for _, v := range b {
key := v.Alias + "|" + v.AliasAddress
counts[key]--
if counts[key] < 0 {
return false
}
}
}
func GetNetworkSettingsJSON() (string, error) {
return network.GetJSON()
}
func GetNetworkSettingsIncrementor() int {
return network.GetIncrementor()
return true
}
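// Worked example for the multiset comparisons above (order is ignored, but
// element counts must match):
//
// slicesEqual([]string{"10.0.0.0/24", "10.1.0.0/24"}, []string{"10.1.0.0/24", "10.0.0.0/24"}) // true
// slicesEqual([]string{"a", "a", "b"}, []string{"a", "b", "b"}) // false: counts differ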

View File

@@ -50,6 +50,8 @@ type PeerManager struct {
// key is the CIDR string, value is a set of siteIds that want this IP
allowedIPClaims map[string]map[int]bool
APIServer *api.API
PersistentKeepalive int
}
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
@@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
return peer, ok
}
// GetPeerMonitor returns the internal peer monitor instance
func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor {
pm.mu.RLock()
defer pm.mu.RUnlock()
return pm.peerMonitor
}
func (pm *PeerManager) GetAllPeers() []SiteConfig {
pm.mu.RLock()
defer pm.mu.RUnlock()
@@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
return err
}
@@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
return nil
}
// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once
// without recreating them. Returns a map of siteId to error for any peers that failed to update.
func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
pm.mu.Lock() // full lock: this method writes pm.PersistentKeepalive
defer pm.mu.Unlock()
pm.PersistentKeepalive = interval
errors := make(map[int]error)
for siteId, peer := range pm.peers {
err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
if err != nil {
errors[siteId] = err
}
}
if len(errors) == 0 {
return nil
}
return errors
}
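// A minimal usage sketch for the bulk keepalive update above, relying on its
// contract that a nil map means every peer updated cleanly. Hypothetical
// caller, not part of this changeset:
func exampleSetKeepalive(pm *PeerManager, seconds int) {
if errs := pm.UpdateAllPeersPersistentKeepalive(seconds); errs != nil {
for siteId, err := range errs {
logger.Error("Keepalive update failed for site %d: %v", siteId, err)
}
}
}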
func (pm *PeerManager) RemovePeer(siteId int) error {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
wgConfig := promotedPeer
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
@@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
return err
}
@@ -324,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
promotedWgConfig := promotedPeer
promotedWgConfig.AllowedIps = promotedOwnedIPs
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}

View File

@@ -31,7 +31,6 @@ type PeerMonitor struct {
monitors map[int]*Client
mutex sync.Mutex
running bool
interval time.Duration
timeout time.Duration
maxAttempts int
wsClient *websocket.Client
@@ -42,7 +41,7 @@ type PeerMonitor struct {
stack *stack.Stack
ep *channel.Endpoint
activePorts map[uint16]bool
portsLock sync.Mutex
portsLock sync.RWMutex
nsCtx context.Context
nsCancel context.CancelFunc
nsWg sync.WaitGroup
@@ -50,17 +49,26 @@ type PeerMonitor struct {
// Holepunch testing fields
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchInterval time.Duration
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
// Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
holepunchMaxAttempts int // max consecutive failures before triggering relay
holepunchFailures map[int]int // siteID -> consecutive failure count
// Exponential backoff fields for holepunch monitor
defaultHolepunchMinInterval time.Duration // Minimum interval (initial)
defaultHolepunchMaxInterval time.Duration
holepunchMinInterval time.Duration // Minimum interval (initial)
holepunchMaxInterval time.Duration // Maximum interval (cap for backoff)
holepunchBackoffMultiplier float64 // Multiplier for each stable check
holepunchStableCount map[int]int // siteID -> consecutive stable status count
holepunchCurrentInterval time.Duration // Current interval with backoff applied
// Rapid initial test fields
rapidTestInterval time.Duration // interval between rapid test attempts
rapidTestTimeout time.Duration // timeout for each rapid test attempt
@@ -78,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*Client),
interval: 2 * time.Second, // Default check interval (faster)
timeout: 3 * time.Second,
maxAttempts: 3,
wsClient: wsClient,
@@ -88,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
nsCtx: ctx,
nsCancel: cancel,
sharedBind: sharedBind,
holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
holepunchTimeout: 2 * time.Second, // Faster timeout
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
@@ -101,6 +107,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
apiServer: apiServer,
wgConnectionStatus: make(map[int]bool),
// Exponential backoff settings for holepunch monitor
defaultHolepunchMinInterval: 2 * time.Second,
defaultHolepunchMaxInterval: 30 * time.Second,
holepunchMinInterval: 2 * time.Second,
holepunchMaxInterval: 30 * time.Second,
holepunchBackoffMultiplier: 1.5,
holepunchStableCount: make(map[int]int),
holepunchCurrentInterval: 2 * time.Second,
holepunchUpdateChan: make(chan struct{}, 1),
}
if err := pm.initNetstack(); err != nil {
@@ -116,41 +131,75 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
}
// SetPeerInterval changes how frequently peers are checked, bounded by minimum and maximum intervals
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.interval = interval
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetPacketInterval(interval)
client.SetPacketInterval(minInterval, maxInterval)
}
logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
}
// SetTimeout changes the timeout for waiting for responses
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
func (pm *PeerMonitor) ResetPeerInterval() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.timeout = timeout
// Update timeout for all existing monitors
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetTimeout(timeout)
client.ResetPacketInterval()
}
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
pm.holepunchMinInterval = minInterval
pm.holepunchMaxInterval = maxInterval
// Reset current interval to the new minimum
pm.holepunchCurrentInterval = minInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.maxAttempts = attempts
return pm.holepunchMinInterval, pm.holepunchMaxInterval
}
// Update max attempts for all existing monitors
for _, client := range pm.monitors {
client.SetMaxAttempts(attempts)
func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
pm.mutex.Lock()
pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
@@ -169,10 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
return err
}
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
pm.monitors[siteID] = client
pm.holepunchEndpoints[siteID] = holepunchEndpoint
@@ -470,31 +515,59 @@ func (pm *PeerMonitor) stopHolepunchMonitor() {
logger.Info("Stopped holepunch connection monitor")
}
// runHolepunchMonitor runs the holepunch monitoring loop
// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff
func (pm *PeerMonitor) runHolepunchMonitor() {
ticker := time.NewTicker(pm.holepunchInterval)
defer ticker.Stop()
pm.mutex.Lock()
pm.holepunchCurrentInterval = pm.holepunchMinInterval
pm.mutex.Unlock()
// Do initial check immediately
pm.checkHolepunchEndpoints()
timer := time.NewTimer(0) // Fire immediately for initial check
defer timer.Stop()
for {
select {
case <-pm.holepunchStopChan:
return
case <-ticker.C:
pm.checkHolepunchEndpoints()
case <-pm.holepunchUpdateChan:
// Interval settings changed, reset to minimum
pm.mutex.Lock()
pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C:
anyStatusChanged := pm.checkHolepunchEndpoints()
pm.mutex.Lock()
if anyStatusChanged {
// Reset to minimum interval on any status change
pm.holepunchCurrentInterval = pm.holepunchMinInterval
} else {
// Apply exponential backoff when stable
newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier)
if newInterval > pm.holepunchMaxInterval {
newInterval = pm.holepunchMaxInterval
}
pm.holepunchCurrentInterval = newInterval
}
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
}
}
}
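// The loop above is a timer-driven poll with exponential backoff: reset to
// the minimum interval whenever something changes, otherwise multiply by a
// factor up to a cap. A minimal standalone sketch of that core, with all
// names hypothetical:
func examplePollWithBackoff(check func() bool, stop <-chan struct{}, min, max time.Duration, factor float64) {
interval := min
timer := time.NewTimer(0) // fire immediately for the initial check
defer timer.Stop()
for {
select {
case <-stop:
return
case <-timer.C:
if check() {
interval = min // status changed: poll quickly again
} else if next := time.Duration(float64(interval) * factor); next > max {
interval = max // stable: cap the backoff
} else {
interval = next // stable: back off gradually
}
timer.Reset(interval)
}
}
}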
// checkHolepunchEndpoints tests all holepunch endpoints
func (pm *PeerMonitor) checkHolepunchEndpoints() {
// Returns true if any endpoint's status changed
func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
pm.mutex.Lock()
// Check if we're still running before doing any work
if !pm.running {
pm.mutex.Unlock()
return
return false
}
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
for siteID, endpoint := range pm.holepunchEndpoints {
@@ -504,8 +577,10 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
maxAttempts := pm.holepunchMaxAttempts
pm.mutex.Unlock()
anyStatusChanged := false
for siteID, endpoint := range endpoints {
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
// logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
pm.mutex.Lock()
@@ -529,7 +604,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Unlock()
// Log status changes
if !exists || previousStatus != result.Success {
statusChanged := !exists || previousStatus != result.Success
if statusChanged {
anyStatusChanged = true
if result.Success {
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
} else {
@@ -562,7 +639,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Unlock()
if !stillRunning {
return // Stop processing if shutdown is in progress
return anyStatusChanged // Stop processing if shutdown is in progress
}
if !result.Success && !isRelayed && failureCount >= maxAttempts {
@@ -579,6 +656,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
}
}
}
return anyStatusChanged
}
// GetHolepunchStatus returns the current holepunch status for all endpoints
@@ -650,55 +729,55 @@ func (pm *PeerMonitor) Close() {
logger.Debug("PeerMonitor: Cleanup complete")
}
// TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
pm.mutex.Lock()
client, exists := pm.monitors[siteID]
pm.mutex.Unlock()
// // TestPeer tests connectivity to a specific peer
// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
// pm.mutex.Lock()
// client, exists := pm.monitors[siteID]
// pm.mutex.Unlock()
if !exists {
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
}
// if !exists {
// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
// }
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
defer cancel()
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
// defer cancel()
connected, rtt := client.TestConnection(ctx)
return connected, rtt, nil
}
// connected, rtt := client.TestPeerConnection(ctx)
// return connected, rtt, nil
// }
// TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
Connected bool
RTT time.Duration
} {
pm.mutex.Lock()
peers := make(map[int]*Client, len(pm.monitors))
for siteID, client := range pm.monitors {
peers[siteID] = client
}
pm.mutex.Unlock()
// // TestAllPeers tests connectivity to all peers
// func (pm *PeerMonitor) TestAllPeers() map[int]struct {
// Connected bool
// RTT time.Duration
// } {
// pm.mutex.Lock()
// peers := make(map[int]*Client, len(pm.monitors))
// for siteID, client := range pm.monitors {
// peers[siteID] = client
// }
// pm.mutex.Unlock()
results := make(map[int]struct {
Connected bool
RTT time.Duration
})
for siteID, client := range peers {
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
connected, rtt := client.TestConnection(ctx)
cancel()
// results := make(map[int]struct {
// Connected bool
// RTT time.Duration
// })
// for siteID, client := range peers {
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
// connected, rtt := client.TestPeerConnection(ctx)
// cancel()
results[siteID] = struct {
Connected bool
RTT time.Duration
}{
Connected: connected,
RTT: rtt,
}
}
// results[siteID] = struct {
// Connected bool
// RTT time.Duration
// }{
// Connected: connected,
// RTT: rtt,
// }
// }
return results
}
// return results
// }
// initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error {
@@ -770,9 +849,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool {
}
// Check if we are listening on this port
pm.portsLock.Lock()
pm.portsLock.RLock()
active := pm.activePorts[uint16(port)]
pm.portsLock.Unlock()
pm.portsLock.RUnlock()
if !active {
return false
@@ -803,13 +882,12 @@ func (pm *PeerMonitor) runPacketSender() {
defer pm.nsWg.Done()
logger.Debug("PeerMonitor: Packet sender goroutine started")
// Use a ticker to periodically check for packets without blocking indefinitely
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-pm.nsCtx.Done():
// Use blocking ReadContext instead of polling - much more CPU efficient
// This will block until a packet is available or context is cancelled
pkt := pm.ep.ReadContext(pm.nsCtx)
if pkt == nil {
// Context was cancelled or endpoint closed
logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
// Drain any remaining packets before exiting
for {
@@ -821,12 +899,6 @@ func (pm *PeerMonitor) runPacketSender() {
}
logger.Debug("PeerMonitor: Packet sender goroutine exiting")
return
case <-ticker.C:
// Try to read packets in batches
for i := 0; i < 10; i++ {
pkt := pm.ep.Read()
if pkt == nil {
break
}
// Extract packet data
@@ -850,8 +922,6 @@ func (pm *PeerMonitor) runPacketSender() {
pkt.DecRef()
}
}
}
}
// dial creates a UDP connection using the netstack

View File

@@ -32,10 +32,19 @@ type Client struct {
monitorLock sync.Mutex
connLock sync.Mutex // Protects connection operations
shutdownCh chan struct{}
updateCh chan struct{}
packetInterval time.Duration
timeout time.Duration
maxAttempts int
dialer Dialer
// Exponential backoff fields
defaultMinInterval time.Duration // Default minimum interval (initial)
defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
minInterval time.Duration // Minimum interval (initial)
maxInterval time.Duration // Maximum interval (cap for backoff)
backoffMultiplier float64 // Multiplier for each stable check
stableCountToBackoff int // Number of stable checks before backing off
}
// Dialer is a function that creates a connection
@@ -52,7 +61,14 @@ func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
return &Client{
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
updateCh: make(chan struct{}, 1),
packetInterval: 2 * time.Second,
defaultMinInterval: 2 * time.Second,
defaultMaxInterval: 30 * time.Second,
minInterval: 2 * time.Second,
maxInterval: 30 * time.Second,
backoffMultiplier: 1.5,
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
@@ -60,18 +76,42 @@ func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
}
// SetPacketInterval changes how frequently packets are sent in monitor mode
func (c *Client) SetPacketInterval(interval time.Duration) {
c.packetInterval = interval
func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
c.monitorLock.Lock()
c.packetInterval = minInterval
c.minInterval = minInterval
c.maxInterval = maxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// SetTimeout changes the timeout for waiting for responses
func (c *Client) SetTimeout(timeout time.Duration) {
c.timeout = timeout
}
func (c *Client) ResetPacketInterval() {
c.monitorLock.Lock()
c.packetInterval = c.defaultMinInterval
c.minInterval = c.defaultMinInterval
c.maxInterval = c.defaultMaxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (c *Client) SetMaxAttempts(attempts int) {
c.maxAttempts = attempts
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// UpdateServerAddr updates the server address and resets the connection
@@ -125,9 +165,10 @@ func (c *Client) ensureConnection() error {
return nil
}
// TestConnection checks if the connection to the server is working
// TestPeerConnection checks if the connection to the server is working
// Returns true if connected, false otherwise
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
// logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
if err := c.ensureConnection(); err != nil {
logger.Warn("Failed to ensure connection: %v", err)
return false, 0
@@ -138,6 +179,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
packet[4] = packetTypeRequest
// Reusable response buffer
responseBuffer := make([]byte, packetSize)
// Send multiple attempts as specified
for attempt := 0; attempt < c.maxAttempts; attempt++ {
select {
@@ -157,20 +201,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
return false, 0
}
// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
_, err := c.conn.Write(packet)
if err != nil {
c.connLock.Unlock()
logger.Info("Error sending packet: %v", err)
continue
}
// logger.Debug("Successfully sent monitor packet")
// Set read deadline
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
// Wait for response
responseBuffer := make([]byte, packetSize)
n, err := c.conn.Read(responseBuffer)
c.connLock.Unlock()
@@ -211,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return c.TestConnection(ctx)
return c.TestPeerConnection(ctx)
}
// MonitorCallback is the function type for connection status change callbacks
@@ -238,28 +279,61 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
go func() {
var lastConnected bool
firstRun := true
stableCount := 0
currentInterval := c.minInterval
ticker := time.NewTicker(c.packetInterval)
defer ticker.Stop()
timer := time.NewTimer(currentInterval)
defer timer.Stop()
for {
select {
case <-c.shutdownCh:
return
case <-ticker.C:
case <-c.updateCh:
// Interval settings changed, reset to minimum
c.monitorLock.Lock()
currentInterval = c.minInterval
c.monitorLock.Unlock()
// Reset backoff state
stableCount = 0
timer.Reset(currentInterval)
logger.Debug("Packet interval updated, reset to %v", currentInterval)
case <-timer.C:
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
connected, rtt := c.TestConnection(ctx)
connected, rtt := c.TestPeerConnection(ctx)
cancel()
statusChanged := connected != lastConnected
// Callback if status changed or it's the first check
if connected != lastConnected || firstRun {
if statusChanged || firstRun {
callback(ConnectionStatus{
Connected: connected,
RTT: rtt,
})
lastConnected = connected
firstRun = false
// Reset backoff on status change
stableCount = 0
currentInterval = c.minInterval
} else {
// Status is stable, increment counter
stableCount++
// Apply exponential backoff after stable threshold
if stableCount >= c.stableCountToBackoff {
newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier)
if newInterval > c.maxInterval {
newInterval = c.maxInterval
}
currentInterval = newInterval
}
}
// Reset timer with current interval
timer.Reset(currentInterval)
}
}
}()

View File

@@ -11,7 +11,7 @@ import (
)
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
var endpoint string
if relay && siteConfig.RelayEndpoint != "" {
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
@@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
}
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
configBuilder.WriteString("persistent_keepalive_interval=5\n")
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive))
config := configBuilder.String()
logger.Debug("Configuring peer with config: %s", config)
@@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [
return nil
}
// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it
func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error {
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
configBuilder.WriteString("update_only=true\n")
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval))
config := configBuilder.String()
logger.Debug("Updating persistent keepalive for peer with config: %s", config)
err := dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err)
}
return nil
}
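// For an interval of 25 seconds, the IpcSet payload built above is the
// WireGuard UAPI fragment below; update_only=true modifies the existing peer
// in place, leaving its endpoint and allowed IPs untouched:
//
// public_key=<hex-encoded public key>
// update_only=true
// persistent_keepalive_interval=25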
func formatEndpoint(endpoint string) string {
if strings.Contains(endpoint, ":") {
return endpoint

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -56,6 +57,7 @@ type ExitNode struct {
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
ConfigVersion int `json:"configVersion,omitempty"`
}
// this is not json anymore
@@ -77,6 +79,7 @@ type Client struct {
handlersMux sync.RWMutex
reconnectInterval time.Duration
isConnected bool
isDisconnected bool // Flag to track if client is intentionally disconnected
reconnectMux sync.RWMutex
pingInterval time.Duration
pingTimeout time.Duration
@@ -87,6 +90,19 @@ type Client struct {
clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig
configNeedsSave bool // Flag to track if config needs to be saved
configVersion int // Latest config version received from server
configVersionMux sync.RWMutex
token string // Cached authentication token
exitNodes []ExitNode // Cached exit nodes from token response
tokenMux sync.RWMutex // Protects token and exitNodes
forceNewToken bool // Flag to force fetching a new token on next connection
processingMessage bool // Flag to track if a message is currently being processed
processingMux sync.RWMutex // Protects processingMessage
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
getPingData func() map[string]any // Callback to get additional ping data
pingStarted bool // Flag to track if ping monitor has been started
pingStartedMux sync.Mutex // Protects pingStarted
pingDone chan struct{} // Channel to stop the ping monitor independently
}
type ClientOption func(*Client)
@@ -122,6 +138,13 @@ func WithTLSConfig(config TLSConfig) ClientOption {
}
}
// WithPingDataProvider sets a callback to provide additional data for ping messages
func WithPingDataProvider(fn func() map[string]any) ClientOption {
return func(c *Client) {
c.getPingData = fn
}
}
func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
@@ -154,6 +177,7 @@ func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.
pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: "olm",
pingDone: make(chan struct{}),
}
// Apply options before loading config
@@ -173,6 +197,9 @@ func (c *Client) GetConfig() *Config {
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
if c.isDisconnected {
c.isDisconnected = false
}
go c.connectWithRetry()
return nil
}
@@ -205,9 +232,31 @@ func (c *Client) Close() error {
return nil
}
// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
func (c *Client) Disconnect() error {
c.isDisconnected = true
c.setConnected(false)
// Stop the ping monitor
c.stopPingMonitor()
// Wait for any message currently being processed to complete
c.processingWg.Wait()
if c.conn != nil {
c.writeMux.Lock()
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.writeMux.Unlock()
err := c.conn.Close()
c.conn = nil
return err
}
return nil
}
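// A minimal sketch of the suspend/resume lifecycle these methods enable,
// with a hypothetical caller: Disconnect stops pings and waits for in-flight
// handlers, while Connect clears the disconnected flag and reconnects in the
// background.
func exampleSuspendResume(c *Client) error {
if err := c.Disconnect(); err != nil {
return err
}
// ... reconfigure the client here ...
return c.Connect()
}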
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessage(messageType string, data interface{}) error {
if c.conn == nil {
if c.isDisconnected || c.conn == nil {
return fmt.Errorf("not connected")
}
@@ -216,14 +265,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
Data: data,
}
logger.Debug("Sending message: %s, data: %+v", messageType, data)
logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
c.writeMux.Lock()
defer c.writeMux.Unlock()
return c.conn.WriteJSON(msg)
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
stopChan := make(chan struct{})
updateChan := make(chan interface{})
var dataMux sync.Mutex
@@ -231,30 +280,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
go func() {
count := 0
maxAttempts := 10
err := c.SendMessage(messageType, currentData) // Send immediately
send := func() {
if c.isDisconnected || c.conn == nil {
return
}
err := c.SendMessage(messageType, currentData)
if err != nil {
logger.Error("Failed to send initial message: %v", err)
logger.Error("websocket: Failed to send message: %v", err)
}
count++
}
send() // Send immediately
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if count >= maxAttempts {
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
if maxAttempts != -1 && count >= maxAttempts {
logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
dataMux.Lock()
err = c.SendMessage(messageType, currentData)
send()
dataMux.Unlock()
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case newData := <-updateChan:
dataMux.Lock()
// Merge newData into currentData if both are maps
@@ -277,6 +328,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
case <-stopChan:
return
}
// Suspend sending if disconnected
for c.isDisconnected {
select {
case <-stopChan:
return
case <-time.After(500 * time.Millisecond):
}
}
}
}()
return func() {
@@ -323,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
tlsConfig = &tls.Config{}
}
tlsConfig.InsecureSkipVerify = true
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
tokenData := map[string]interface{}{
@@ -352,7 +411,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// print out the request for debugging
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
// Make the request
client := &http.Client{}
@@ -369,7 +428,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
// Return AuthError for 401/403 status codes
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
@@ -385,7 +444,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.")
logger.Error("websocket: Failed to decode token response.")
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
}
@@ -397,7 +456,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
return "", nil, fmt.Errorf("received empty token from server")
}
logger.Debug("Received token: %s", tokenResp.Data.Token)
logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
}
@@ -411,7 +470,8 @@ func (c *Client) connectWithRetry() {
err := c.establishConnection()
if err != nil {
// Check if this is an auth error (401/403)
if authErr, ok := err.(*AuthError); ok {
var authErr *AuthError
if errors.As(err, &authErr) {
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
// Trigger auth error callback if set (this should terminate the tunnel)
if c.onAuthError != nil {
@@ -422,7 +482,7 @@ func (c *Client) connectWithRetry() {
continue
}
// For other errors (5xx, network issues), continue retrying
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval)
continue
}
@@ -432,15 +492,25 @@ func (c *Client) connectWithRetry() {
}
func (c *Client) establishConnection() error {
// Get token for authentication
// Get token for authentication - reuse cached token unless forced to get new one
c.tokenMux.Lock()
needNewToken := c.token == "" || c.forceNewToken
if needNewToken {
token, exitNodes, err := c.getToken()
if err != nil {
c.tokenMux.Unlock()
return fmt.Errorf("failed to get token: %w", err)
}
c.token = token
c.exitNodes = exitNodes
c.forceNewToken = false
if c.onTokenUpdate != nil {
c.onTokenUpdate(token, exitNodes)
}
}
token := c.token
c.tokenMux.Unlock()
// Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL)
@@ -475,7 +545,7 @@ func (c *Client) establishConnection() error {
// Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
logger.Info("Setting up TLS configuration for WebSocket connection")
logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
tlsConfig, err := c.setupTLS()
if err != nil {
return fmt.Errorf("failed to setup TLS configuration: %w", err)
@@ -489,25 +559,38 @@ func (c *Client) establishConnection() error {
dialer.TLSClientConfig = &tls.Config{}
}
dialer.TLSClientConfig.InsecureSkipVerify = true
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
conn, _, err := dialer.Dial(u.String(), nil)
conn, resp, err := dialer.Dial(u.String(), nil)
if err != nil {
// Check if this is an unauthorized error (401)
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
// Force getting a new token on next reconnect attempt
c.tokenMux.Lock()
c.forceNewToken = true
c.tokenMux.Unlock()
return &AuthError{
StatusCode: http.StatusUnauthorized,
Message: "WebSocket connection unauthorized",
}
}
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
c.conn = conn
c.setConnected(true)
// Start the ping monitor
go c.pingMonitor()
// Note: ping monitor is NOT started here - it will be started when
// StartPingMonitor() is called after registration completes
// Start the read pump with disconnect detection
go c.readPumpWithDisconnectDetection()
if c.onConnect != nil {
if err := c.onConnect(); err != nil {
logger.Error("OnConnect callback failed: %v", err)
logger.Error("websocket: OnConnect callback failed: %v", err)
}
}
@@ -520,9 +603,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Handle new separate certificate configuration
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
logger.Info("Loading separate certificate files for mTLS")
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
logger.Info("websocket: Loading separate certificate files for mTLS")
logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
// Load client certificate and key
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
@@ -533,7 +616,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Load CA certificates for remote validation if specified
if len(c.tlsConfig.CAFiles) > 0 {
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
caCertPool := x509.NewCertPool()
for _, caFile := range c.tlsConfig.CAFiles {
caCert, err := os.ReadFile(caFile)
@@ -559,13 +642,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Fallback to existing PKCS12 implementation for backward compatibility
if c.tlsConfig.PKCS12File != "" {
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
return c.setupPKCS12TLS()
}
// Legacy fallback using config.TlsClientCert
if c.config.TlsClientCert != "" {
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
return loadClientCertificate(c.config.TlsClientCert)
}
@@ -577,6 +660,59 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
return loadClientCertificate(c.tlsConfig.PKCS12File)
}
// sendPing sends a single ping message
func (c *Client) sendPing() {
if c.isDisconnected || c.conn == nil {
return
}
// Skip ping if a message is currently being processed
c.processingMux.RLock()
isProcessing := c.processingMessage
c.processingMux.RUnlock()
if isProcessing {
logger.Debug("websocket: Skipping ping, message is being processed")
return
}
// Send application-level ping with config version
c.configVersionMux.RLock()
configVersion := c.configVersion
c.configVersionMux.RUnlock()
pingData := map[string]any{
"timestamp": time.Now().Unix(),
"userToken": c.config.UserToken,
}
if c.getPingData != nil {
for k, v := range c.getPingData() {
pingData[k] = v
}
}
pingMsg := WSMessage{
Type: "olm/ping",
Data: pingData,
ConfigVersion: configVersion,
}
logger.Debug("websocket: Sending ping: %+v", pingMsg)
c.writeMux.Lock()
err := c.conn.WriteJSON(pingMsg)
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("websocket: Ping failed: %v", err)
c.reconnect()
return
}
}
}
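// A minimal sketch of wiring the ping-data callback via WithPingDataProvider
// above; whatever the callback returns is merged into every "olm/ping"
// payload. The uptime field and processStart parameter are hypothetical:
func examplePingProvider(processStart time.Time) ClientOption {
return WithPingDataProvider(func() map[string]any {
return map[string]any{"uptimeSeconds": int(time.Since(processStart).Seconds())}
})
}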
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) pingMonitor() {
ticker := time.NewTicker(c.pingInterval)
@@ -586,27 +722,63 @@ func (c *Client) pingMonitor() {
select {
case <-c.done:
return
case <-c.pingDone:
return
case <-ticker.C:
if c.conn == nil {
c.sendPing()
}
}
}
// StartPingMonitor starts the ping monitor goroutine.
// This should be called after the client is registered and connected.
// It is safe to call multiple times - only the first call will start the monitor.
func (c *Client) StartPingMonitor() {
c.pingStartedMux.Lock()
defer c.pingStartedMux.Unlock()
if c.pingStarted {
return
}
c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
c.reconnect()
c.pingStarted = true
// Create a new pingDone channel for this ping monitor instance
c.pingDone = make(chan struct{})
// Send an initial ping immediately
go func() {
c.sendPing()
c.pingMonitor()
}()
}
// stopPingMonitor stops the ping monitor goroutine if it's running.
func (c *Client) stopPingMonitor() {
c.pingStartedMux.Lock()
defer c.pingStartedMux.Unlock()
if !c.pingStarted {
return
}
}
}
}
// Close the pingDone channel to stop the monitor
close(c.pingDone)
c.pingStarted = false
}
// GetConfigVersion returns the current config version
func (c *Client) GetConfigVersion() int {
c.configVersionMux.RLock()
defer c.configVersionMux.RUnlock()
return c.configVersion
}
// setConfigVersion updates the config version if the new version is higher
func (c *Client) setConfigVersion(version int) {
c.configVersionMux.Lock()
defer c.configVersionMux.Unlock()
logger.Debug("websocket: setting config version to %d", version)
c.configVersion = version
}
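// On the wire, configVersion rides along in WSMessage; a ping serializes
// roughly as the JSON below (omitempty drops the field while the version is
// still 0). Values are illustrative:
//
// {"type":"olm/ping","data":{"timestamp":1700000000,"userToken":"..."},"configVersion":7}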
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
@@ -633,26 +805,47 @@ func (c *Client) readPumpWithDisconnectDetection() {
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
// Check if we're shutting down before logging error
// Check if we're shutting down or explicitly disconnected before logging error
select {
case <-c.done:
// Expected during shutdown, don't log as error
logger.Debug("WebSocket connection closed during shutdown")
logger.Debug("websocket: connection closed during shutdown")
return
default:
// Check if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: connection closed: client was explicitly disconnected")
return
}
// Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("WebSocket read error: %v", err)
logger.Error("websocket: read error: %v", err)
} else {
logger.Debug("WebSocket connection closed: %v", err)
logger.Debug("websocket: connection closed: %v", err)
}
return // triggers reconnect via defer
}
}
// Update config version from incoming message
c.setConfigVersion(msg.ConfigVersion)
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
// Mark that we're processing a message
c.processingMux.Lock()
c.processingMessage = true
c.processingMux.Unlock()
c.processingWg.Add(1)
handler(msg)
// Mark that we're done processing
c.processingWg.Done()
c.processingMux.Lock()
c.processingMessage = false
c.processingMux.Unlock()
}
c.handlersMux.RUnlock()
}
@@ -666,6 +859,12 @@ func (c *Client) reconnect() {
c.conn = nil
}
// Don't reconnect if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
return
}
// Only reconnect if we're not shutting down
select {
case <-c.done:
@@ -683,7 +882,7 @@ func (c *Client) setConnected(status bool) {
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path)
logger.Info("websocket: Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path)
if err != nil {