diff --git a/API.md b/API.md index dc2f7ff..e4a501f 100644 --- a/API.md +++ b/API.md @@ -46,7 +46,18 @@ Initiates a new connection request to a Pangolin server. "tlsClientCert": "string", "pingInterval": "3s", "pingTimeout": "5s", - "orgId": "string" + "orgId": "string", + "fingerprint": { + "username": "string", + "hostname": "string", + "platform": "string", + "osVersion": "string", + "kernelVersion": "string", + "arch": "string", + "deviceModel": "string", + "serialNumber": "string" + }, + "postures": {} } ``` @@ -67,6 +78,16 @@ Initiates a new connection request to a Pangolin server. - `pingInterval`: Interval for pinging the server (default: 3s) - `pingTimeout`: Timeout for each ping (default: 5s) - `orgId`: Organization ID to connect to +- `fingerprint`: Device fingerprinting information (should be set before connecting) + - `username`: Current username on the device + - `hostname`: Device hostname + - `platform`: Operating system platform (macos, windows, linux, ios, android, unknown) + - `osVersion`: Operating system version + - `kernelVersion`: Kernel version + - `arch`: System architecture (e.g., amd64, arm64) + - `deviceModel`: Device model identifier + - `serialNumber`: Device serial number +- `postures`: Device posture/security information **Response:** - **Status Code:** `202 Accepted` @@ -205,6 +226,56 @@ Switches to a different organization while maintaining the connection. --- +### PUT /metadata +Updates device fingerprinting and posture information. This endpoint can be called at any time to update metadata, but it's recommended to provide this information in the initial `/connect` request or immediately before connecting. + +**Request Body:** +```json +{ + "fingerprint": { + "username": "string", + "hostname": "string", + "platform": "string", + "osVersion": "string", + "kernelVersion": "string", + "arch": "string", + "deviceModel": "string", + "serialNumber": "string" + }, + "postures": {} +} +``` + +**Optional Fields:** +- `fingerprint`: Device fingerprinting information + - `username`: Current username on the device + - `hostname`: Device hostname + - `platform`: Operating system platform (macos, windows, linux, ios, android, unknown) + - `osVersion`: Operating system version + - `kernelVersion`: Kernel version + - `arch`: System architecture (e.g., amd64, arm64) + - `deviceModel`: Device model identifier + - `serialNumber`: Device serial number +- `postures`: Device posture/security information (object with arbitrary key-value pairs) + +**Response:** +- **Status Code:** `200 OK` +- **Content-Type:** `application/json` + +```json +{ + "status": "metadata updated" +} +``` + +**Error Responses:** +- `405 Method Not Allowed` - Non-PUT requests +- `400 Bad Request` - Invalid JSON + +**Note:** It's recommended to call this endpoint BEFORE `/connect` to ensure fingerprinting information is available during the initial connection handshake. + +--- + ### POST /exit Initiates a graceful shutdown of the Olm process. @@ -247,6 +318,22 @@ Simple health check endpoint to verify the API server is running. 
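+### POST /rebind
+Triggers a rebind of the WireGuard socket. This is necessary when network connectivity changes (e.g., a WiFi to cellular transition on macOS/iOS) and the old socket becomes stale.
+
+**Response:**
+- **Status Code:** `200 OK`
+- **Content-Type:** `application/json`
+
+```json
+{
+  "status": "socket rebound successfully"
+}
+```
+
+**Error Responses:**
+- `405 Method Not Allowed` - Non-POST requests
+- `500 Internal Server Error` - Rebind failed
+- `501 Not Implemented` - Rebind handler not configured
+
+---
+
+### POST /power-mode
+Switches the power mode between `normal` and `low`.
+
+**Request Body:**
+```json
+{
+  "mode": "low"
+}
+```
+
+**Response:**
+- **Status Code:** `200 OK`
+- **Content-Type:** `application/json`
+
+```json
+{
+  "status": "power mode changed to low successfully"
+}
+```
+
+**Error Responses:**
+- `405 Method Not Allowed` - Non-POST requests
+- `400 Bad Request` - Invalid JSON, or a mode other than `normal`/`low`
+- `500 Internal Server Error` - Power mode change failed
+- `501 Not Implemented` - Power mode handler not configured
+
+---
+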
## Usage Examples +### Update metadata before connecting (recommended) +```bash +curl -X PUT http://localhost:9452/metadata \ + -H "Content-Type: application/json" \ + -d '{ + "fingerprint": { + "username": "john", + "hostname": "johns-laptop", + "platform": "macos", + "osVersion": "14.2.1", + "arch": "arm64", + "deviceModel": "MacBookPro18,3" + } + }' +``` + ### Connect to a peer ```bash curl -X POST http://localhost:9452/connect \ diff --git a/Makefile b/Makefile index 7bb25cc..4732d29 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,9 @@ all: local local: CGO_ENABLED=0 go build -o ./bin/olm +docker-build: + docker build -t fosrl/olm:latest . + docker-build-release: @if [ -z "$(tag)" ]; then \ echo "Error: tag is required. Usage: make docker-build-release tag="; \ diff --git a/api/api.go b/api/api.go index 787f958..efd3346 100644 --- a/api/api.go +++ b/api/api.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "strconv" "sync" "time" @@ -32,7 +33,12 @@ type ConnectionRequest struct { // SwitchOrgRequest defines the structure for switching organizations type SwitchOrgRequest struct { - OrgID string `json:"orgId"` + OrgID string `json:"org_id"` +} + +// PowerModeRequest represents a request to change power mode +type PowerModeRequest struct { + Mode string `json:"mode"` // "normal" or "low" } // PeerStatus represents the status of a peer connection @@ -48,11 +54,18 @@ type PeerStatus struct { HolepunchConnected bool `json:"holepunchConnected"` } +// OlmError holds error information from registration failures +type OlmError struct { + Code string `json:"code"` + Message string `json:"message"` +} + // StatusResponse is returned by the status endpoint type StatusResponse struct { Connected bool `json:"connected"` Registered bool `json:"registered"` Terminated bool `json:"terminated"` + OlmError *OlmError `json:"error,omitempty"` Version string `json:"version,omitempty"` Agent string `json:"agent,omitempty"` OrgID string `json:"orgId,omitempty"` @@ -60,25 +73,37 @@ type StatusResponse struct { NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"` } +type MetadataChangeRequest struct { + Fingerprint map[string]any `json:"fingerprint"` + Postures map[string]any `json:"postures"` +} + // API represents the HTTP server and its state type API struct { - addr string - socketPath string - listener net.Listener - server *http.Server - onConnect func(ConnectionRequest) error - onSwitchOrg func(SwitchOrgRequest) error - onDisconnect func() error - onExit func() error + addr string + socketPath string + listener net.Listener + server *http.Server + + onConnect func(ConnectionRequest) error + onSwitchOrg func(SwitchOrgRequest) error + onMetadataChange func(MetadataChangeRequest) error + onDisconnect func() error + onExit func() error + onRebind func() error + onPowerMode func(PowerModeRequest) error + statusMu sync.RWMutex peerStatuses map[int]*PeerStatus connectedAt time.Time isConnected bool isRegistered bool isTerminated bool - version string - agent string - orgID string + olmError *OlmError + + version string + agent string + orgID string } // NewAPI creates a new HTTP server that listens on a TCP address @@ -101,28 +126,49 @@ func NewAPISocket(socketPath string) *API { return s } +func NewAPIStub() *API { + s := &API{ + peerStatuses: make(map[int]*PeerStatus), + } + + return s +} + // SetHandlers sets the callback functions for handling API requests func (s *API) SetHandlers( onConnect func(ConnectionRequest) error, onSwitchOrg func(SwitchOrgRequest) error, + onMetadataChange 
func(MetadataChangeRequest) error, onDisconnect func() error, onExit func() error, + onRebind func() error, + onPowerMode func(PowerModeRequest) error, ) { s.onConnect = onConnect s.onSwitchOrg = onSwitchOrg + s.onMetadataChange = onMetadataChange s.onDisconnect = onDisconnect s.onExit = onExit + s.onRebind = onRebind + s.onPowerMode = onPowerMode } // Start starts the HTTP server func (s *API) Start() error { + if s.socketPath == "" && s.addr == "" { + return fmt.Errorf("either socketPath or addr must be provided to start the API server") + } + mux := http.NewServeMux() mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/switch-org", s.handleSwitchOrg) + mux.HandleFunc("/metadata", s.handleMetadataChange) mux.HandleFunc("/disconnect", s.handleDisconnect) mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/health", s.handleHealth) + mux.HandleFunc("/rebind", s.handleRebind) + mux.HandleFunc("/power-mode", s.handlePowerMode) s.server = &http.Server{ Handler: mux, @@ -160,7 +206,7 @@ func (s *API) Stop() error { // Close the server first, which will also close the listener gracefully if s.server != nil { - s.server.Close() + _ = s.server.Close() } // Clean up socket file if using Unix socket @@ -236,6 +282,27 @@ func (s *API) SetRegistered(registered bool) { s.statusMu.Lock() defer s.statusMu.Unlock() s.isRegistered = registered + // Clear any registration error when successfully registered + if registered { + s.olmError = nil + } +} + +// SetOlmError sets the registration error +func (s *API) SetOlmError(code string, message string) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.olmError = &OlmError{ + Code: code, + Message: message, + } +} + +// ClearOlmError clears any registration error +func (s *API) ClearOlmError() { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.olmError = nil } func (s *API) SetTerminated(terminated bool) { @@ -345,7 +412,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusAccepted) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "connection request accepted", }) } @@ -358,12 +425,12 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { } s.statusMu.RLock() - defer s.statusMu.RUnlock() resp := StatusResponse{ Connected: s.isConnected, Registered: s.isRegistered, Terminated: s.isTerminated, + OlmError: s.olmError, Version: s.version, Agent: s.agent, OrgID: s.orgID, @@ -371,8 +438,18 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { NetworkSettings: network.GetSettings(), } + s.statusMu.RUnlock() + + data, err := json.Marshal(resp) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Length", strconv.Itoa(len(data))) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data) } // handleHealth handles the /health endpoint @@ -384,7 +461,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "ok", }) } @@ -401,7 +478,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { // Return a success response first 
w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "shutdown initiated", }) @@ -450,7 +527,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "org switch request accepted", }) } @@ -484,16 +561,43 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "disconnect initiated", }) } +// handleMetadataChange handles the /metadata endpoint +func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req MetadataChangeRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + logger.Info("Received metadata change request via API: %v", req) + + _ = s.onMetadataChange(req) + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "metadata updated", + }) +} + func (s *API) GetStatus() StatusResponse { return StatusResponse{ Connected: s.isConnected, Registered: s.isRegistered, Terminated: s.isTerminated, + OlmError: s.olmError, Version: s.version, Agent: s.agent, OrgID: s.orgID, @@ -501,3 +605,74 @@ func (s *API) GetStatus() StatusResponse { NetworkSettings: network.GetSettings(), } } + +// handleRebind handles the /rebind endpoint +// This triggers a socket rebind, which is necessary when network connectivity changes +// (e.g., WiFi to cellular transition on macOS/iOS) and the old socket becomes stale. 
+func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + logger.Info("Received rebind request via API") + + // Call the rebind handler if set + if s.onRebind != nil { + if err := s.onRebind(); err != nil { + http.Error(w, fmt.Sprintf("Rebind failed: %v", err), http.StatusInternalServerError) + return + } + } else { + http.Error(w, "Rebind handler not configured", http.StatusNotImplemented) + return + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "socket rebound successfully", + }) +} + +// handlePowerMode handles the /power-mode endpoint +// This allows changing the power mode between "normal" and "low" +func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req PowerModeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // Validate power mode + if req.Mode != "normal" && req.Mode != "low" { + http.Error(w, "Invalid power mode: must be 'normal' or 'low'", http.StatusBadRequest) + return + } + + logger.Info("Received power mode change request via API: mode=%s", req.Mode) + + // Call the power mode handler if set + if s.onPowerMode != nil { + if err := s.onPowerMode(req); err != nil { + http.Error(w, fmt.Sprintf("Power mode change failed: %v", err), http.StatusInternalServerError) + return + } + } else { + http.Error(w, "Power mode handler not configured", http.StatusNotImplemented) + return + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": fmt.Sprintf("power mode changed to %s successfully", req.Mode), + }) +} diff --git a/device/middle_device.go b/device/middle_device.go index b031871..7dfbec8 100644 --- a/device/middle_device.go +++ b/device/middle_device.go @@ -1,9 +1,12 @@ package device import ( + "io" "net/netip" "os" "sync" + "sync/atomic" + "time" "github.com/fosrl/newt/logger" "golang.zx2c4.com/wireguard/tun" @@ -18,14 +21,68 @@ type FilterRule struct { Handler PacketHandler } -// MiddleDevice wraps a TUN device with packet filtering capabilities -type MiddleDevice struct { +// closeAwareDevice wraps a tun.Device along with a flag +// indicating whether its Close method was called. +type closeAwareDevice struct { + isClosed atomic.Bool tun.Device - rules []FilterRule - mutex sync.RWMutex - readCh chan readResult - injectCh chan []byte - closed chan struct{} + closeEventCh chan struct{} + wg sync.WaitGroup + closeOnce sync.Once +} + +func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice { + return &closeAwareDevice{ + Device: tunDevice, + isClosed: atomic.Bool{}, + closeEventCh: make(chan struct{}), + } +} + +// redirectEvents redirects the Events() method of the underlying tun.Device +// to the given channel. 
+func (c *closeAwareDevice) redirectEvents(out chan tun.Event) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + select { + case ev, ok := <-c.Device.Events(): + if !ok { + return + } + + if ev == tun.EventDown { + continue + } + + select { + case out <- ev: + case <-c.closeEventCh: + return + } + case <-c.closeEventCh: + return + } + } + }() +} + +// Close calls the underlying Device's Close method +// after setting isClosed to true. +func (c *closeAwareDevice) Close() (err error) { + c.closeOnce.Do(func() { + c.isClosed.Store(true) + close(c.closeEventCh) + err = c.Device.Close() + c.wg.Wait() + }) + + return err +} + +func (c *closeAwareDevice) IsClosed() bool { + return c.isClosed.Load() } type readResult struct { @@ -36,58 +93,136 @@ type readResult struct { err error } +// MiddleDevice wraps a TUN device with packet filtering capabilities +// and supports swapping the underlying device. +type MiddleDevice struct { + devices []*closeAwareDevice + mu sync.Mutex + cond *sync.Cond + rules []FilterRule + rulesMutex sync.RWMutex + readCh chan readResult + injectCh chan []byte + closed atomic.Bool + events chan tun.Event +} + // NewMiddleDevice creates a new filtered TUN device wrapper func NewMiddleDevice(device tun.Device) *MiddleDevice { d := &MiddleDevice{ - Device: device, + devices: make([]*closeAwareDevice, 0), rules: make([]FilterRule, 0), - readCh: make(chan readResult), + readCh: make(chan readResult, 16), injectCh: make(chan []byte, 100), - closed: make(chan struct{}), + events: make(chan tun.Event, 16), } - go d.pump() + d.cond = sync.NewCond(&d.mu) + + if device != nil { + d.AddDevice(device) + } + return d } -func (d *MiddleDevice) pump() { +// AddDevice adds a new underlying TUN device, closing any previous one +func (d *MiddleDevice) AddDevice(device tun.Device) { + d.mu.Lock() + if d.closed.Load() { + d.mu.Unlock() + _ = device.Close() + return + } + + var toClose *closeAwareDevice + if len(d.devices) > 0 { + toClose = d.devices[len(d.devices)-1] + } + + cad := newCloseAwareDevice(device) + cad.redirectEvents(d.events) + + d.devices = []*closeAwareDevice{cad} + + // Start pump for the new device + go d.pump(cad) + + d.cond.Broadcast() + d.mu.Unlock() + + if toClose != nil { + logger.Debug("MiddleDevice: Closing previous device") + if err := toClose.Close(); err != nil { + logger.Debug("MiddleDevice: Error closing previous device: %v", err) + } + } +} + +func (d *MiddleDevice) pump(dev *closeAwareDevice) { const defaultOffset = 16 - batchSize := d.Device.BatchSize() - logger.Debug("MiddleDevice: pump started") + batchSize := dev.BatchSize() + logger.Debug("MiddleDevice: pump started for device") + + // Recover from panic if readCh is closed while we're trying to send + defer func() { + if r := recover(); r != nil { + logger.Debug("MiddleDevice: pump recovered from panic (channel closed)") + } + }() for { - // Check closed first with priority - select { - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel") + // Check if this device is closed + if dev.IsClosed() { + logger.Debug("MiddleDevice: pump exiting, device is closed") + return + } + + // Check if MiddleDevice itself is closed + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed") return - default: } // Allocate buffers for reading - // We allocate new buffers for each read to avoid race conditions - // since we pass them to the channel bufs := make([][]byte, batchSize) sizes := make([]int, batchSize) for i := range bufs { bufs[i] = 
make([]byte, 2048) // Standard MTU + headroom } - n, err := d.Device.Read(bufs, sizes, defaultOffset) + n, err := dev.Read(bufs, sizes, defaultOffset) - // Check closed again after read returns - select { - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)") + // Check if device was closed during read + if dev.IsClosed() { + logger.Debug("MiddleDevice: pump exiting, device closed during read") + return + } + + // Check if MiddleDevice was closed during read + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read") + return + } + + // Try to send the result - check closed state first to avoid sending on closed channel + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, device closed before send") return - default: } - // Now try to send the result select { case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)") - return + default: + // Channel full, check if we should exit + if dev.IsClosed() || d.closed.Load() { + return + } + // Try again with blocking + select { + case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: + case <-dev.closeEventCh: + return + } } if err != nil { @@ -99,16 +234,28 @@ func (d *MiddleDevice) pump() { // InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN) func (d *MiddleDevice) InjectOutbound(packet []byte) { + if d.closed.Load() { + return + } + // Use defer/recover to handle panic from sending on closed channel + // This can happen during shutdown race conditions + defer func() { + if r := recover(); r != nil { + logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)") + } + }() select { case d.injectCh <- packet: - case <-d.closed: + default: + // Channel full, drop packet + logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full") } } // AddRule adds a packet filtering rule func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { - d.mutex.Lock() - defer d.mutex.Unlock() + d.rulesMutex.Lock() + defer d.rulesMutex.Unlock() d.rules = append(d.rules, FilterRule{ DestIP: destIP, Handler: handler, @@ -117,8 +264,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { // RemoveRule removes all rules for a given destination IP func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { - d.mutex.Lock() - defer d.mutex.Unlock() + d.rulesMutex.Lock() + defer d.rulesMutex.Unlock() newRules := make([]FilterRule, 0, len(d.rules)) for _, rule := range d.rules { if rule.DestIP != destIP { @@ -130,18 +277,120 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { // Close stops the device func (d *MiddleDevice) Close() error { - select { - case <-d.closed: - // Already closed - return nil - default: - logger.Debug("MiddleDevice: Closing, signaling closed channel") - close(d.closed) + if !d.closed.CompareAndSwap(false, true) { + return nil // already closed } - logger.Debug("MiddleDevice: Closing underlying TUN device") - err := d.Device.Close() - logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err) - return err + + d.mu.Lock() + devices := d.devices + d.devices = nil + d.cond.Broadcast() + d.mu.Unlock() + + // Close underlying devices first - this causes the pump goroutines to exit + // when their read operations return errors + var lastErr error + 
logger.Debug("MiddleDevice: Closing %d devices", len(devices)) + for _, device := range devices { + if err := device.Close(); err != nil { + logger.Debug("MiddleDevice: Error closing device: %v", err) + lastErr = err + } + } + + // Now close channels to unblock any remaining readers + // The pump should have exited by now, but close channels to be safe + close(d.readCh) + close(d.injectCh) + close(d.events) + + return lastErr +} + +// Events returns the events channel +func (d *MiddleDevice) Events() <-chan tun.Event { + return d.events +} + +// File returns the underlying file descriptor +func (d *MiddleDevice) File() *os.File { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return nil + } + continue + } + + file := dev.File() + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return file + } +} + +// MTU returns the MTU of the underlying device +func (d *MiddleDevice) MTU() (int, error) { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } + continue + } + + mtu, err := dev.MTU() + if err == nil { + return mtu, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return 0, err + } +} + +// Name returns the name of the underlying device +func (d *MiddleDevice) Name() (string, error) { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return "", io.EOF + } + continue + } + + name, err := dev.Name() + if err == nil { + return name, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return "", err + } +} + +// BatchSize returns the batch size +func (d *MiddleDevice) BatchSize() int { + dev := d.peekLast() + if dev == nil { + return 1 + } + return dev.BatchSize() } // extractDestIP extracts destination IP from packet (fast path) @@ -176,156 +425,239 @@ func extractDestIP(packet []byte) (netip.Addr, bool) { // Read intercepts packets going UP from the TUN device (towards WireGuard) func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { - // Check if already closed first (non-blocking) - select { - case <-d.closed: - logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)") - return 0, os.ErrClosed - default: - } - - // Now block waiting for data - select { - case res := <-d.readCh: - if res.err != nil { - logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err) - return 0, res.err + for { + if d.closed.Load() { + logger.Debug("MiddleDevice: Read returning io.EOF, device closed") + return 0, io.EOF } - // Copy packets from result to provided buffers - count := 0 - for i := 0; i < res.n && i < len(bufs); i++ { - // Handle offset mismatch if necessary - // We assume the pump used defaultOffset (16) - // If caller asks for different offset, we need to shift - src := res.bufs[i] - srcOffset := res.offset - srcSize := res.sizes[i] - - // Calculate where the packet data starts and ends in src - pktData := src[srcOffset : srcOffset+srcSize] - - // Ensure dest buffer is large enough - if len(bufs[i]) < offset+len(pktData) { - continue // Skip if buffer too small + // Wait for a device to be available + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF } - - copy(bufs[i][offset:], pktData) - sizes[i] = len(pktData) - count++ - } - n = count - - case pkt := <-d.injectCh: - if len(bufs) == 0 { - return 0, nil - } - if len(bufs[0]) < offset+len(pkt) { - return 0, nil // Buffer too small - } - copy(bufs[0][offset:], pkt) 
- sizes[0] = len(pkt) - n = 1 - - case <-d.closed: - logger.Debug("MiddleDevice: Read returning os.ErrClosed") - return 0, os.ErrClosed // Signal that device is closed - } - - d.mutex.RLock() - rules := d.rules - d.mutex.RUnlock() - - if len(rules) == 0 { - return n, nil - } - - // Process packets and filter out handled ones - writeIdx := 0 - for readIdx := 0; readIdx < n; readIdx++ { - packet := bufs[readIdx][offset : offset+sizes[readIdx]] - - destIP, ok := extractDestIP(packet) - if !ok { - // Can't parse, keep packet - if writeIdx != readIdx { - bufs[writeIdx] = bufs[readIdx] - sizes[writeIdx] = sizes[readIdx] - } - writeIdx++ continue } - // Check if packet matches any rule - handled := false - for _, rule := range rules { - if rule.DestIP == destIP { - if rule.Handler(packet) { - // Packet was handled and should be dropped - handled = true - break + // Now block waiting for data from readCh or injectCh + select { + case res, ok := <-d.readCh: + if !ok { + // Channel closed, device is shutting down + return 0, io.EOF + } + if res.err != nil { + // Check if device was swapped + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err) + return 0, res.err + } + + // Copy packets from result to provided buffers + count := 0 + for i := 0; i < res.n && i < len(bufs); i++ { + src := res.bufs[i] + srcOffset := res.offset + srcSize := res.sizes[i] + + pktData := src[srcOffset : srcOffset+srcSize] + + if len(bufs[i]) < offset+len(pktData) { + continue + } + + copy(bufs[i][offset:], pktData) + sizes[i] = len(pktData) + count++ + } + n = count + + case pkt, ok := <-d.injectCh: + if !ok { + // Channel closed, device is shutting down + return 0, io.EOF + } + if len(bufs) == 0 { + return 0, nil + } + if len(bufs[0]) < offset+len(pkt) { + return 0, nil + } + copy(bufs[0][offset:], pkt) + sizes[0] = len(pkt) + n = 1 + } + + // Apply filtering rules + d.rulesMutex.RLock() + rules := d.rules + d.rulesMutex.RUnlock() + + if len(rules) == 0 { + return n, nil + } + + // Process packets and filter out handled ones + writeIdx := 0 + for readIdx := 0; readIdx < n; readIdx++ { + packet := bufs[readIdx][offset : offset+sizes[readIdx]] + + destIP, ok := extractDestIP(packet) + if !ok { + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ + continue + } + + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + handled = true + break + } } } - } - if !handled { - // Keep packet - if writeIdx != readIdx { - bufs[writeIdx] = bufs[readIdx] - sizes[writeIdx] = sizes[readIdx] + if !handled { + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ } - writeIdx++ } - } - return writeIdx, err + return writeIdx, nil + } } // Write intercepts packets going DOWN to the TUN device (from WireGuard) func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) { - d.mutex.RLock() - rules := d.rules - d.mutex.RUnlock() + for { + if d.closed.Load() { + return 0, io.EOF + } - if len(rules) == 0 { - return d.Device.Write(bufs, offset) - } - - // Filter packets going down - filteredBufs := make([][]byte, 0, len(bufs)) - for _, buf := range bufs { - if len(buf) <= offset { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } continue } - packet := buf[offset:] - destIP, ok := extractDestIP(packet) - if !ok { - // Can't parse, keep packet - 
filteredBufs = append(filteredBufs, buf) - continue - } + d.rulesMutex.RLock() + rules := d.rules + d.rulesMutex.RUnlock() - // Check if packet matches any rule - handled := false - for _, rule := range rules { - if rule.DestIP == destIP { - if rule.Handler(packet) { - // Packet was handled and should be dropped - handled = true - break + var filteredBufs [][]byte + if len(rules) == 0 { + filteredBufs = bufs + } else { + filteredBufs = make([][]byte, 0, len(bufs)) + for _, buf := range bufs { + if len(buf) <= offset { + continue + } + + packet := buf[offset:] + destIP, ok := extractDestIP(packet) + if !ok { + filteredBufs = append(filteredBufs, buf) + continue + } + + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + handled = true + break + } + } + } + + if !handled { + filteredBufs = append(filteredBufs, buf) } } } - if !handled { - filteredBufs = append(filteredBufs, buf) + if len(filteredBufs) == 0 { + return len(bufs), nil } - } - if len(filteredBufs) == 0 { - return len(bufs), nil // All packets were handled - } + n, err := dev.Write(filteredBufs, offset) + if err == nil { + return n, nil + } - return d.Device.Write(filteredBufs, offset) + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } } + +func (d *MiddleDevice) waitForDevice() bool { + d.mu.Lock() + defer d.mu.Unlock() + + for len(d.devices) == 0 && !d.closed.Load() { + d.cond.Wait() + } + return !d.closed.Load() +} + +func (d *MiddleDevice) peekLast() *closeAwareDevice { + d.mu.Lock() + defer d.mu.Unlock() + + if len(d.devices) == 0 { + return nil + } + + return d.devices[len(d.devices)-1] +} + +// WriteToTun writes packets directly to the underlying TUN device, +// bypassing WireGuard. This is useful for sending packets that should +// appear to come from the TUN interface (e.g., DNS responses from a proxy). +// Unlike Write(), this does not go through packet filtering rules. 
+func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) { + for { + if d.closed.Load() { + return 0, io.EOF + } + + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } + continue + } + + n, err := dev.Write(bufs, offset) + if err == nil { + return n, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } +} \ No newline at end of file diff --git a/device/tun_unix.go b/device/tun_darwin.go similarity index 91% rename from device/tun_unix.go rename to device/tun_darwin.go index c9bab60..df87d53 100644 --- a/device/tun_unix.go +++ b/device/tun_darwin.go @@ -1,4 +1,4 @@ -//go:build !windows +//go:build darwin package device @@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { } file := os.NewFile(uintptr(dupTunFd), "/dev/tun") - device, err := tun.CreateTUNFromFile(file, mtuInt) + device, err := tun.CreateTUNFromFile(file, 0) if err != nil { file.Close() return nil, err diff --git a/device/tun_linux.go b/device/tun_linux.go new file mode 100644 index 0000000..902f269 --- /dev/null +++ b/device/tun_linux.go @@ -0,0 +1,50 @@ +//go:build linux + +package device + +import ( + "net" + "os" + "runtime" + + "github.com/fosrl/newt/logger" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" +) + +func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { + if runtime.GOOS == "android" { // otherwise we get a permission denied + theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd)) + return theTun, err + } + + dupTunFd, err := unix.Dup(int(tunFd)) + if err != nil { + logger.Error("Unable to dup tun fd: %v", err) + return nil, err + } + + err = unix.SetNonblock(dupTunFd, true) + if err != nil { + unix.Close(dupTunFd) + return nil, err + } + + file := os.NewFile(uintptr(dupTunFd), "/dev/tun") + device, err := tun.CreateTUNFromFile(file, mtuInt) + if err != nil { + file.Close() + return nil, err + } + + return device, nil +} + +func UapiOpen(interfaceName string) (*os.File, error) { + return ipc.UAPIOpen(interfaceName) +} + +func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { + return ipc.UAPIListen(interfaceName, fileUAPI) +} diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 6d56379..f65e923 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -12,7 +12,6 @@ import ( "github.com/fosrl/newt/util" "github.com/fosrl/olm/device" "github.com/miekg/dns" - "golang.zx2c4.com/wireguard/tun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -34,18 +33,17 @@ type DNSProxy struct { ep *channel.Endpoint proxyIP netip.Addr upstreamDNS []string - tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally + tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally mtu int - tunDevice tun.Device // Direct reference to underlying TUN device for responses - middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering + middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes recordStore *DNSRecordStore // Local DNS records // Tunnel DNS fields - for sending queries over WireGuard - tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries) - tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries - tunnelEp *channel.Endpoint + tunnelIP netip.Addr // WireGuard interface IP 
(source for tunneled queries) + tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries + tunnelEp *channel.Endpoint tunnelActivePorts map[uint16]bool - tunnelPortsLock sync.Mutex + tunnelPortsLock sync.Mutex ctx context.Context cancel context.CancelFunc @@ -53,7 +51,7 @@ type DNSProxy struct { } // NewDNSProxy creates a new DNS proxy -func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { +func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { proxyIP, err := PickIPFromSubnet(utilitySubnet) if err != nil { return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err) @@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in proxy := &DNSProxy{ proxyIP: proxyIP, mtu: mtu, - tunDevice: tunDevice, middleDevice: middleDevice, upstreamDNS: upstreamDns, tunnelDNS: tunnelDns, @@ -602,12 +599,12 @@ func (p *DNSProxy) runTunnelPacketSender() { defer p.wg.Done() logger.Debug("DNS tunnel packet sender goroutine started") - ticker := time.NewTicker(1 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-p.ctx.Done(): + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := p.tunnelEp.ReadContext(p.ctx) + if pkt == nil { + // Context was cancelled or endpoint closed logger.Debug("DNS tunnel packet sender exiting") // Drain any remaining packets for { @@ -618,36 +615,28 @@ func (p *DNSProxy) runTunnelPacketSender() { pkt.DecRef() } return - case <-ticker.C: - // Try to read packets - for i := 0; i < 10; i++ { - pkt := p.tunnelEp.Read() - if pkt == nil { - break - } - - // Extract packet data - slices := pkt.AsSlices() - if len(slices) > 0 { - var totalSize int - for _, slice := range slices { - totalSize += len(slice) - } - - buf := make([]byte, totalSize) - pos := 0 - for _, slice := range slices { - copy(buf[pos:], slice) - pos += len(slice) - } - - // Inject into MiddleDevice (outbound to WG) - p.middleDevice.InjectOutbound(buf) - } - - pkt.DecRef() - } } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + p.middleDevice.InjectOutbound(buf) + } + + pkt.DecRef() } } @@ -660,18 +649,12 @@ func (p *DNSProxy) runPacketSender() { const offset = 16 for { - select { - case <-p.ctx.Done(): - return - default: - } - - // Read packets from netstack endpoint - pkt := p.ep.Read() + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := p.ep.ReadContext(p.ctx) if pkt == nil { - // No packet available, small sleep to avoid busy loop - time.Sleep(1 * time.Millisecond) - continue + // Context was cancelled or endpoint closed + return } // Extract packet data as slices @@ -694,9 +677,9 @@ func (p *DNSProxy) runPacketSender() { pos += len(slice) } - // Write packet to TUN device + // Write packet to TUN device via MiddleDevice // offset=16 indicates packet data starts at position 16 in the buffer - _, err := 
p.tunDevice.Write([][]byte{buf}, offset) + _, err := p.middleDevice.WriteToTun([][]byte{buf}, offset) if err != nil { logger.Error("Failed to write DNS response to TUN: %v", err) } diff --git a/dns/dns_records.go b/dns/dns_records.go index ed57b77..5308b0e 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -322,4 +322,4 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool { } return matchWildcardInternal(pattern, domain, pi+1, di+1) -} \ No newline at end of file +} diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index 0bb18a1..f922afb 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -37,7 +37,7 @@ func TestWildcardMatching(t *testing.T) { domain: "autoco.internal.", expected: false, }, - + // Question mark wildcard tests { name: "host-0?.autoco.internal matches host-01.autoco.internal", @@ -63,7 +63,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-012.autoco.internal.", expected: false, }, - + // Combined wildcard tests { name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal", @@ -83,7 +83,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-01.autoco.internal.", expected: false, }, - + // Multiple asterisks { name: "*.*. autoco.internal matches any.thing.autoco.internal", @@ -97,7 +97,7 @@ func TestWildcardMatching(t *testing.T) { domain: "single.autoco.internal.", expected: false, }, - + // Asterisk in middle { name: "host-*.autoco.internal matches host-anything.autoco.internal", @@ -111,7 +111,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-.autoco.internal.", expected: true, }, - + // Multiple question marks { name: "host-??.autoco.internal matches host-01.autoco.internal", @@ -125,7 +125,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-1.autoco.internal.", expected: false, }, - + // Exact match (no wildcards) { name: "exact.autoco.internal matches exact.autoco.internal", @@ -139,7 +139,7 @@ func TestWildcardMatching(t *testing.T) { domain: "other.autoco.internal.", expected: false, }, - + // Edge cases { name: "* matches anything", @@ -154,7 +154,7 @@ func TestWildcardMatching(t *testing.T) { expected: true, }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := matchWildcard(tt.pattern, tt.domain) @@ -167,21 +167,21 @@ func TestWildcardMatching(t *testing.T) { func TestDNSRecordStoreWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard records wildcardIP := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", wildcardIP) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Add exact record exactIP := net.ParseIP("10.0.0.2") err = store.AddRecord("exact.autoco.internal", exactIP) if err != nil { t.Fatalf("Failed to add exact record: %v", err) } - + // Test exact match takes precedence ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -190,7 +190,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) { if !ips[0].Equal(exactIP) { t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) } - + // Test wildcard match ips = store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -199,7 +199,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) { if !ips[0].Equal(wildcardIP) { t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) } - + // Test non-match (base domain) ips = store.GetRecords("autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -209,14 +209,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) { func 
TestDNSRecordStoreComplexWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add complex wildcard pattern ip1 := net.ParseIP("10.0.0.1") err := store.AddRecord("*.host-0?.autoco.internal", ip1) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Test matching domain ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -225,13 +225,13 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { if len(ips) > 0 && !ips[0].Equal(ip1) { t.Errorf("Expected IP %v, got %v", ip1, ips[0]) } - + // Test non-matching domain (missing prefix) ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) } - + // Test non-matching domain (wrong ? position) ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -241,23 +241,23 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { func TestDNSRecordStoreRemoveWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard record ip := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Verify it exists ips := store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP before removal, got %d", len(ips)) } - + // Remove wildcard record store.RemoveRecord("*.autoco.internal", nil) - + // Verify it's gone ips = store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -267,40 +267,40 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { func TestDNSRecordStoreMultipleWildcards(t *testing.T) { store := NewDNSRecordStore() - + // Add multiple wildcard patterns that don't overlap ip1 := net.ParseIP("10.0.0.1") ip2 := net.ParseIP("10.0.0.2") ip3 := net.ParseIP("10.0.0.3") - + err := store.AddRecord("*.prod.autoco.internal", ip1) if err != nil { t.Fatalf("Failed to add first wildcard: %v", err) } - + err = store.AddRecord("*.dev.autoco.internal", ip2) if err != nil { t.Fatalf("Failed to add second wildcard: %v", err) } - + // Add a broader wildcard that matches both err = store.AddRecord("*.autoco.internal", ip3) if err != nil { t.Fatalf("Failed to add third wildcard: %v", err) } - + // Test domain matching only the prod pattern and the broad pattern ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) } - + // Test domain matching only the dev pattern and the broad pattern ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) } - + // Test domain matching only the broad pattern ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -310,14 +310,14 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add IPv6 wildcard record ip := net.ParseIP("2001:db8::1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add IPv6 wildcard record: %v", err) } - + // Test wildcard match for IPv6 ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) if len(ips) != 1 { @@ -330,21 +330,21 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { func TestHasRecordWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard record ip := net.ParseIP("10.0.0.1") err 
:= store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Test HasRecord with wildcard match if !store.HasRecord("host.autoco.internal.", RecordTypeA) { t.Error("Expected HasRecord to return true for wildcard match") } - + // Test HasRecord with non-match if store.HasRecord("autoco.internal.", RecordTypeA) { t.Error("Expected HasRecord to return false for base domain") } -} \ No newline at end of file +} diff --git a/dns/override/dns_override_android.go b/dns/override/dns_override_android.go new file mode 100644 index 0000000..d3fd78e --- /dev/null +++ b/dns/override/dns_override_android.go @@ -0,0 +1,16 @@ +//go:build android + +package olm + +import "net/netip" + +// SetupDNSOverride is a no-op on Android +// Android handles DNS through the VpnService API at the Java/Kotlin layer +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { + return nil +} + +// RestoreDNSOverride is a no-op on Android +func RestoreDNSOverride() error { + return nil +} \ No newline at end of file diff --git a/dns/override/dns_override_darwin.go b/dns/override/dns_override_darwin.go index 6ccc3fb..c1c3789 100644 --- a/dns/override/dns_override_darwin.go +++ b/dns/override/dns_override_darwin.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on macOS // Uses scutil for DNS configuration -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error configurator, err = platform.NewDarwinDNSConfigurator() if err != nil { @@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/override/dns_override_ios.go b/dns/override/dns_override_ios.go new file mode 100644 index 0000000..6c95c71 --- /dev/null +++ b/dns/override/dns_override_ios.go @@ -0,0 +1,15 @@ +//go:build ios + +package olm + +import "net/netip" + +// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { + return nil +} + +// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system +func RestoreDNSOverride() error { + return nil +} \ No newline at end of file diff --git a/dns/override/dns_override_unix.go b/dns/override/dns_override_unix.go index c3b31e8..12cb692 100644 --- a/dns/override/dns_override_unix.go +++ b/dns/override/dns_override_unix.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD // Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error // Detect 
which DNS manager is in use by checking /etc/resolv.conf and runtime availability @@ -32,7 +27,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName) if err == nil { logger.Info("Using systemd-resolved DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err) @@ -40,7 +35,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName) if err == nil { logger.Info("Using NetworkManager DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err) @@ -48,7 +43,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName) if err == nil { logger.Info("Using resolvconf DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create resolvconf configurator: %v, falling back", err) } @@ -60,11 +55,11 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { } logger.Info("Using file-based DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } // setDNS is a helper function to set DNS and log the results -func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { +func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error { // Get current DNS servers before changing currentDNS, err := conf.GetCurrentDNS() if err != nil { @@ -75,7 +70,7 @@ func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/override/dns_override_windows.go b/dns/override/dns_override_windows.go index a564079..16bbca1 100644 --- a/dns/override/dns_override_windows.go +++ b/dns/override/dns_override_windows.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on Windows // Uses registry-based configuration (automatically extracts interface GUID) -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error configurator, err = platform.NewWindowsDNSConfigurator(interfaceName) if err != nil { @@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/platform/darwin.go b/dns/platform/darwin.go index 61cc81b..8054c57 100644 --- a/dns/platform/darwin.go +++ b/dns/platform/darwin.go @@ -416,4 +416,4 @@ func (d *DarwinDNSConfigurator) clearState() error { logger.Debug("Cleared DNS state file") return nil -} \ No newline at end of file +} diff --git a/go.mod b/go.mod index 
5261037..09a5bc4 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.25 require ( github.com/Microsoft/go-winio v0.6.2 - github.com/fosrl/newt v1.8.1 + github.com/fosrl/newt v1.9.0 github.com/godbus/dbus/v5 v5.2.2 github.com/gorilla/websocket v1.5.3 github.com/miekg/dns v1.1.70 @@ -30,3 +30,6 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect ) + +// To be used ONLY for local development +// replace github.com/fosrl/newt => ../newt diff --git a/go.sum b/go.sum index c0a2bf7..be51e01 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/fosrl/newt v1.8.1 h1:oP3xBEISoO/TENsHccqqs6LXpoOWCt6aiP75CfIWpvk= -github.com/fosrl/newt v1.8.1/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= +github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE= +github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= diff --git a/main.go b/main.go index f6c6973..2bf8dcd 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" - "github.com/fosrl/olm/olm" + olmpkg "github.com/fosrl/olm/olm" ) func main() { @@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt } // Create a new olm.Config struct and copy values from the main config - olmConfig := olm.GlobalConfig{ + olmConfig := olmpkg.OlmConfig{ LogLevel: config.LogLevel, EnableAPI: config.EnableAPI, HTTPAddr: config.HTTPAddr, @@ -219,15 +219,20 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt Agent: "Olm CLI", OnExit: cancel, // Pass cancel function directly to trigger shutdown OnTerminated: cancel, + PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE + } + + olm, err := olmpkg.Init(ctx, olmConfig) + if err != nil { + logger.Fatal("Failed to initialize olm: %v", err) } - olm.Init(ctx, olmConfig) if err := olm.StartApi(); err != nil { logger.Fatal("Failed to start API server: %v", err) } if config.ID != "" && config.Secret != "" && config.Endpoint != "" { - tunnelConfig := olm.TunnelConfig{ + tunnelConfig := olmpkg.TunnelConfig{ Endpoint: config.Endpoint, ID: config.ID, Secret: config.Secret, diff --git a/olm/connect.go b/olm/connect.go new file mode 100644 index 0000000..575a8fd --- /dev/null +++ b/olm/connect.go @@ -0,0 +1,275 @@ +package olm + +import ( + "encoding/json" + "fmt" + "os" + "runtime" + "strconv" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + olmDevice "github.com/fosrl/olm/device" + "github.com/fosrl/olm/dns" + dnsOverride "github.com/fosrl/olm/dns/override" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" +) + +// OlmErrorData represents the error data sent from the server +type OlmErrorData struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func (o *Olm) handleConnect(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + + var wgData WgData + + if o.connected 
{
+		logger.Info("Already connected. Ignoring new connection request.")
+		return
+	}
+
+	if o.stopRegister != nil {
+		o.stopRegister()
+		o.stopRegister = nil
+	}
+
+	if o.updateRegister != nil {
+		o.updateRegister = nil
+	}
+
+	// if there is an existing tunnel then close it
+	if o.dev != nil {
+		logger.Info("Got new message. Closing existing tunnel!")
+		o.dev.Close()
+	}
+
+	jsonData, err := json.Marshal(msg.Data)
+	if err != nil {
+		logger.Error("Error marshaling data: %v", err)
+		return
+	}
+
+	if err := json.Unmarshal(jsonData, &wgData); err != nil {
+		logger.Error("Error unmarshaling target data: %v", err)
+		return
+	}
+
+	o.tdev, err = func() (tun.Device, error) {
+		if o.tunnelConfig.FileDescriptorTun != 0 {
+			return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU)
+		}
+		ifName := o.tunnelConfig.InterfaceName
+		if runtime.GOOS == "darwin" { // this is if we don't pass a fd
+			ifName, err = network.FindUnusedUTUN()
+			if err != nil {
+				return nil, err
+			}
+		}
+		return tun.CreateTUN(ifName, o.tunnelConfig.MTU)
+	}()
+	if err != nil {
+		logger.Error("Failed to create TUN device: %v", err)
+		return
+	}
+
+	// if config.FileDescriptorTun == 0 {
+	if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface name was already set, this is effectively a no-op
+		o.tunnelConfig.InterfaceName = realInterfaceName
+	}
+	// }
+
+	// Wrap TUN device with packet filter for DNS proxy
+	o.middleDev = olmDevice.NewMiddleDevice(o.tdev)
+
+	wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
+	// Use filtered device instead of raw TUN device
+	o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger))
+
+	if o.tunnelConfig.EnableUAPI {
+		fileUAPI, err := func() (*os.File, error) {
+			if o.tunnelConfig.FileDescriptorUAPI != 0 {
+				fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32)
+				if err != nil {
+					return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
+				}
+				return os.NewFile(uintptr(fd), ""), nil
+			}
+			return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName)
+		}()
+		if err != nil {
+			logger.Error("UAPI listen error: %v", err)
+			os.Exit(1)
+			return
+		}
+
+		o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI)
+		if err != nil {
+			logger.Error("Failed to listen on uapi socket: %v", err)
+			os.Exit(1)
+		}
+
+		go func() {
+			for {
+				conn, err := o.uapiListener.Accept()
+				if err != nil {
+					return
+				}
+				go o.dev.IpcHandle(conn)
+			}
+		}()
+		logger.Info("UAPI listener started")
+	}
+
+	if err = o.dev.Up(); err != nil {
+		logger.Error("Failed to bring up WireGuard device: %v", err)
+	}
+
+	// Extract interface IP (strip CIDR notation if present)
+	interfaceIP := wgData.TunnelIP
+	if strings.Contains(interfaceIP, "/") {
+		interfaceIP = strings.Split(interfaceIP, "/")[0]
+	}
+
+	// Create and start DNS proxy
+	o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP)
+	if err != nil {
+		logger.Error("Failed to create DNS proxy: %v", err)
+	}
+
+	if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil {
+		logger.Error("Failed to configure interface: %v", err)
+	}
+
+	if err = network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet
+		logger.Error("Failed to add route for utility subnet: %v", err)
+	}
+
+	// Create peer manager with integrated
peer monitoring + o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ + Device: o.dev, + DNSProxy: o.dnsProxy, + InterfaceName: o.tunnelConfig.InterfaceName, + PrivateKey: o.privateKey, + MiddleDev: o.middleDev, + LocalIP: interfaceIP, + SharedBind: o.sharedBind, + WSClient: o.websocket, + APIServer: o.apiServer, + }) + + for i := range wgData.Sites { + site := wgData.Sites[i] + var siteEndpoint string + // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer + if site.RelayEndpoint != "" { + siteEndpoint = site.RelayEndpoint + } else { + siteEndpoint = site.Endpoint + } + + o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) + + if err := o.peerManager.AddPeer(site); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Configured peer %s", site.PublicKey) + } + + o.peerManager.Start() + + if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime + logger.Error("Failed to start DNS proxy: %v", err) + } + + if o.tunnelConfig.OverrideDNS { + // Set up DNS override to use our DNS proxy + if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { + logger.Error("Failed to setup DNS override: %v", err) + return + } + + network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()}) + } + + o.apiServer.SetRegistered(true) + + o.connected = true + + // Start ping monitor now that we are registered and connected + o.websocket.StartPingMonitor() + + // Invoke onConnected callback if configured + if o.olmConfig.OnConnected != nil { + go o.olmConfig.OnConnected() + } + + logger.Info("WireGuard device created.") +} + +func (o *Olm) handleOlmError(msg websocket.WSMessage) { + logger.Debug("Received olm error message: %v", msg.Data) + + var errorData OlmErrorData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling olm error data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &errorData); err != nil { + logger.Error("Error unmarshaling olm error data: %v", err) + return + } + + logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message) + + // Set the olm error in the API server so it can be exposed via status + o.apiServer.SetOlmError(errorData.Code, errorData.Message) + + // Invoke onOlmError callback if configured + if o.olmConfig.OnOlmError != nil { + go o.olmConfig.OnOlmError(errorData.Code, errorData.Message) + } +} + +func (o *Olm) handleTerminate(msg websocket.WSMessage) { + logger.Info("Received terminate message") + + var errorData OlmErrorData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling terminate error data: %v", err) + } else { + if err := json.Unmarshal(jsonData, &errorData); err != nil { + logger.Error("Error unmarshaling terminate error data: %v", err) + } else { + logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message) + // Set the olm error in the API server so it can be exposed via status + o.apiServer.SetOlmError(errorData.Code, errorData.Message) + } + } + + o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearPeerStatuses() + + network.ClearNetworkSettings() + + o.Close() + + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() + } +} diff --git a/olm/data.go b/olm/data.go new file mode 100644 index 0000000..35798c6 --- /dev/null +++ 
b/olm/data.go
@@ -0,0 +1,347 @@
+package olm
+
+import (
+	"encoding/json"
+	"time"
+
+	"github.com/fosrl/newt/holepunch"
+	"github.com/fosrl/newt/logger"
+	"github.com/fosrl/olm/peers"
+	"github.com/fosrl/olm/websocket"
+)
+
+func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
+	logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
+
+	jsonData, err := json.Marshal(msg.Data)
+	if err != nil {
+		logger.Error("Error marshaling data: %v", err)
+		return
+	}
+
+	var addSubnetsData peers.PeerAdd
+	if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil {
+		logger.Error("Error unmarshaling add-remote-subnets data: %v", err)
+		return
+	}
+
+	if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists {
+		logger.Debug("Peer %d not found for adding remote subnets and aliases", addSubnetsData.SiteId)
+		return
+	}
+
+	// Add new subnets
+	for _, subnet := range addSubnetsData.RemoteSubnets {
+		if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil {
+			logger.Error("Failed to add allowed IP %s: %v", subnet, err)
+		}
+	}
+
+	// Add new aliases
+	for _, alias := range addSubnetsData.Aliases {
+		if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil {
+			logger.Error("Failed to add alias %s: %v", alias.Alias, err)
+		}
+	}
+}
+
+func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
+	logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
+
+	jsonData, err := json.Marshal(msg.Data)
+	if err != nil {
+		logger.Error("Error marshaling data: %v", err)
+		return
+	}
+
+	var removeSubnetsData peers.RemovePeerData
+	if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil {
+		logger.Error("Error unmarshaling remove-remote-subnets data: %v", err)
+		return
+	}
+
+	if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists {
+		logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId)
+		return
+	}
+
+	// Remove subnets
+	for _, subnet := range removeSubnetsData.RemoteSubnets {
+		if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil {
+			logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
+		}
+	}
+
+	// Remove aliases
+	for _, alias := range removeSubnetsData.Aliases {
+		if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil {
+			logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
+		}
+	}
+}
+
+func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
+	logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
+
+	jsonData, err := json.Marshal(msg.Data)
+	if err != nil {
+		logger.Error("Error marshaling data: %v", err)
+		return
+	}
+
+	var updateSubnetsData peers.UpdatePeerData
+	if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil {
+		logger.Error("Error unmarshaling update-remote-subnets data: %v", err)
+		return
+	}
+
+	if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists {
+		logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId)
+		return
+	}
+
+	// Add new subnets BEFORE removing old ones to preserve shared subnets
+	// This ensures that if an old and new subnet are the same on different peers,
+	// the route won't be temporarily removed
+	for _, subnet := range updateSubnetsData.NewRemoteSubnets {
+		if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
+			logger.Error("Failed to add allowed IP %s: %v", 
subnet, err) + } + } + + // Remove old subnets after new ones are added + for _, subnet := range updateSubnetsData.OldRemoteSubnets { + if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases BEFORE removing old ones to preserve shared IP addresses + // This ensures that if an old and new alias share the same IP, the IP won't be + // temporarily removed from the allowed IPs list + for _, alias := range updateSubnetsData.NewAliases { + if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } + + // Remove old aliases after new ones are added + for _, alias := range updateSubnetsData.OldAliases { + if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } + + logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) +} + +// Handler for syncing peer configuration - reconciles expected state with actual state +func (o *Olm) handleSync(msg websocket.WSMessage) { + logger.Debug("Received sync message: %v", msg.Data) + + if !o.connected { + logger.Warn("Not connected, ignoring sync request") + return + } + + if o.peerManager == nil { + logger.Warn("Peer manager not initialized, ignoring sync request") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling sync data: %v", err) + return + } + + var syncData SyncData + if err := json.Unmarshal(jsonData, &syncData); err != nil { + logger.Error("Error unmarshaling sync data: %v", err) + return + } + + // Sync exit nodes for hole punching + o.syncExitNodes(syncData.ExitNodes) + + // Build a map of expected peers from the incoming data + expectedPeers := make(map[int]peers.SiteConfig) + for _, site := range syncData.Sites { + expectedPeers[site.SiteId] = site + } + + // Get all current peers + currentPeers := o.peerManager.GetAllPeers() + currentPeerMap := make(map[int]peers.SiteConfig) + for _, peer := range currentPeers { + currentPeerMap[peer.SiteId] = peer + } + + // Find peers to remove (in current but not in expected) + for siteId := range currentPeerMap { + if _, exists := expectedPeers[siteId]; !exists { + logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId) + if err := o.peerManager.RemovePeer(siteId); err != nil { + logger.Error("Sync: Failed to remove peer %d: %v", siteId, err) + } else { + // Remove any exit nodes associated with this peer from hole punching + if o.holePunchManager != nil { + removed := o.holePunchManager.RemoveExitNodesByPeer(siteId) + if removed > 0 { + logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId) + } + } + } + } + } + + // Find peers to add (in expected but not in current) and peers to update + for siteId, expectedSite := range expectedPeers { + if _, exists := currentPeerMap[siteId]; !exists { + // New peer - add it using the add flow (with holepunch) + logger.Info("Sync: Adding new peer for site %d", siteId) + + o.holePunchManager.TriggerHolePunch() + + // // TODO: do we need to send the message to the cloud to add the peer that way? 
+ // if err := o.peerManager.AddPeer(expectedSite); err != nil { + // logger.Error("Sync: Failed to add peer %d: %v", siteId, err) + // } else { + // logger.Info("Sync: Successfully added peer for site %d", siteId) + // } + + // add the peer via the server + // this is important because newt needs to get triggered as well to add the peer once the hp is complete + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": expectedSite.SiteId, + }, 1*time.Second, 10) + + } else { + // Existing peer - check if update is needed + currentSite := currentPeerMap[siteId] + needsUpdate := false + + // Check if any fields have changed + if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint { + needsUpdate = true + } + if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint { + needsUpdate = true + } + if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey { + needsUpdate = true + } + if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP { + needsUpdate = true + } + if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort { + needsUpdate = true + } + // Check remote subnets + if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) { + needsUpdate = true + } + // Check aliases + if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) { + needsUpdate = true + } + + if needsUpdate { + logger.Info("Sync: Updating peer for site %d", siteId) + + // Merge expected data with current data + siteConfig := currentSite + if expectedSite.Endpoint != "" { + siteConfig.Endpoint = expectedSite.Endpoint + } + if expectedSite.RelayEndpoint != "" { + siteConfig.RelayEndpoint = expectedSite.RelayEndpoint + } + if expectedSite.PublicKey != "" { + siteConfig.PublicKey = expectedSite.PublicKey + } + if expectedSite.ServerIP != "" { + siteConfig.ServerIP = expectedSite.ServerIP + } + if expectedSite.ServerPort != 0 { + siteConfig.ServerPort = expectedSite.ServerPort + } + if expectedSite.RemoteSubnets != nil { + siteConfig.RemoteSubnets = expectedSite.RemoteSubnets + } + if expectedSite.Aliases != nil { + siteConfig.Aliases = expectedSite.Aliases + } + + if err := o.peerManager.UpdatePeer(siteConfig); err != nil { + logger.Error("Sync: Failed to update peer %d: %v", siteId, err) + } else { + // If the endpoint changed, trigger holepunch to refresh NAT mappings + if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint { + logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId) + o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetServerHolepunchInterval() + } + logger.Info("Sync: Successfully updated peer for site %d", siteId) + } + } + } + } + + logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers)) +} + +// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager +func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) { + if o.holePunchManager == nil { + logger.Warn("Hole punch manager not initialized, skipping exit node sync") + return + } + + // Build a map of expected exit nodes by endpoint + expectedExitNodeMap := make(map[string]SyncExitNode) + for _, exitNode := range expectedExitNodes { + expectedExitNodeMap[exitNode.Endpoint] = exitNode + 
} + + // Get current exit nodes from hole punch manager + currentExitNodes := o.holePunchManager.GetExitNodes() + currentExitNodeMap := make(map[string]holepunch.ExitNode) + for _, exitNode := range currentExitNodes { + currentExitNodeMap[exitNode.Endpoint] = exitNode + } + + // Find exit nodes to remove (in current but not in expected) + for endpoint := range currentExitNodeMap { + if _, exists := expectedExitNodeMap[endpoint]; !exists { + logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint) + o.holePunchManager.RemoveExitNode(endpoint) + } + } + + // Find exit nodes to add (in expected but not in current) + for endpoint, expectedExitNode := range expectedExitNodeMap { + if _, exists := currentExitNodeMap[endpoint]; !exists { + logger.Info("Sync: Adding new exit node %s", endpoint) + + relayPort := expectedExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + hpExitNode := holepunch.ExitNode{ + Endpoint: expectedExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: expectedExitNode.PublicKey, + SiteIds: expectedExitNode.SiteIds, + } + + if o.holePunchManager.AddExitNode(hpExitNode) { + logger.Info("Sync: Successfully added exit node %s", endpoint) + } + o.holePunchManager.TriggerHolePunch() + } + } + + logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap)) +} diff --git a/olm/olm.go b/olm/olm.go index f84ee4f..cd8a844 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -2,13 +2,12 @@ package olm import ( "context" - "encoding/json" "fmt" "net" + "net/http" + _ "net/http/pprof" "os" - "runtime" - "strconv" - "strings" + "sync" "time" "github.com/fosrl/newt/bind" @@ -28,40 +27,58 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -var ( - privateKey wgtypes.Key - connected bool - dev *device.Device - uapiListener net.Listener - tdev tun.Device - middleDev *olmDevice.MiddleDevice +type Olm struct { + privateKey wgtypes.Key + logFile *os.File + + connected bool + tunnelRunning bool + + uapiListener net.Listener + dev *device.Device + tdev tun.Device + middleDev *olmDevice.MiddleDevice + sharedBind *bind.SharedBind + dnsProxy *dns.DNSProxy apiServer *api.API - olmClient *websocket.Client - tunnelCancel context.CancelFunc - tunnelRunning bool - sharedBind *bind.SharedBind + websocket *websocket.Client holePunchManager *holepunch.Manager - globalConfig GlobalConfig - tunnelConfig TunnelConfig - globalCtx context.Context - stopRegister func() - stopPeerSend func() - updateRegister func(newData interface{}) - stopPing chan struct{} peerManager *peers.PeerManager -) + // Power mode management + currentPowerMode string + powerModeMu sync.Mutex + wakeUpTimer *time.Timer + wakeUpDebounce time.Duration + + olmCtx context.Context + tunnelCancel context.CancelFunc + + olmConfig OlmConfig + tunnelConfig TunnelConfig + + // Metadata to send alongside pings + fingerprint map[string]any + postures map[string]any + metaMu sync.Mutex + + stopRegister func() + updateRegister func(newData any) + + stopPeerSend func() +} // initTunnelInfo creates the shared UDP socket and holepunch manager. // This is used during initial tunnel setup and when switching organizations. 
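+// The generated private key is kept on the Olm struct; its public half is what
+// the holepunch manager and the later "olm/wg/register" message advertise.
+// The shared UDP socket is reference-counted because WireGuard and the hole
+// punch senders both use the same bind.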
-func initTunnelInfo(clientID string) error {
-	var err error
-	privateKey, err = wgtypes.GeneratePrivateKey()
+func (o *Olm) initTunnelInfo(clientID string) error {
+	privateKey, err := wgtypes.GeneratePrivateKey()
 	if err != nil {
 		logger.Error("Failed to generate private key: %v", err)
 		return err
 	}
+	o.privateKey = privateKey
+
 	sourcePort, err := util.FindAvailableUDPPort(49152, 65535)
 	if err != nil {
 		return fmt.Errorf("failed to find available UDP port: %w", err)
@@ -77,57 +94,92 @@ func initTunnelInfo(clientID string) error {
 		return fmt.Errorf("failed to create UDP socket: %w", err)
 	}
-	sharedBind, err = bind.New(udpConn)
+	sharedBind, err := bind.New(udpConn)
 	if err != nil {
-		udpConn.Close()
+		_ = udpConn.Close()
 		return fmt.Errorf("failed to create shared bind: %w", err)
 	}
+	o.sharedBind = sharedBind
+
 	// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
 	sharedBind.AddRef()
 	logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
 	// Create the holepunch manager
-	holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String())
+	o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String())
 	return nil
}
-func Init(ctx context.Context, config GlobalConfig) {
-	globalConfig = config
-	globalCtx = ctx
-
+func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
 	logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
+	// Start pprof server if enabled
+	if config.PprofAddr != "" {
+		go func() {
+			logger.Info("Starting pprof server on %s", config.PprofAddr)
+			if err := http.ListenAndServe(config.PprofAddr, nil); err != nil {
+				logger.Error("Failed to start pprof server: %v", err)
+			}
+		}()
+	}
+
+	var logFile *os.File
+	if config.LogFilePath != "" {
+		file, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
+		if err != nil {
+			logger.Fatal("Failed to open log file: %v", err)
+			return nil, err
+		}
+
+		logger.SetOutput(file)
+		logFile = file
+	}
+
+	if config.WakeUpDebounce == 0 {
+		config.WakeUpDebounce = 3 * time.Second
+	}
+
 	logger.Debug("Checking permissions for native interface")
 	err := permissions.CheckNativeInterfacePermissions()
 	if err != nil {
 		logger.Fatal("Insufficient permissions to create native TUN interface: %v", err)
-		return
+		return nil, err
 	}
+	var apiServer *api.API
 	if config.HTTPAddr != "" {
 		apiServer = api.NewAPI(config.HTTPAddr)
 	} else if config.SocketPath != "" {
 		apiServer = api.NewAPISocket(config.SocketPath)
+	} else {
+		// stub so apiServer is never nil; it can't be started without either the socket path or HTTP addr
+		apiServer = api.NewAPIStub()
 	}
 	apiServer.SetVersion(config.Version)
 	apiServer.SetAgent(config.Agent)
-	// Set up API handlers
-	apiServer.SetHandlers(
+	newOlm := &Olm{
+		logFile:   logFile,
+		olmCtx:    ctx,
+		apiServer: apiServer,
+		olmConfig: config,
+	}
+
+	newOlm.registerAPICallbacks()
+
+	return newOlm, nil
+}
+
+func (o *Olm) registerAPICallbacks() {
+	o.apiServer.SetHandlers(
 		// onConnect
 		func(req api.ConnectionRequest) error {
 			logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
-			// Stop any existing tunnel before starting a new one
-			if olmClient != nil {
-				logger.Info("Stopping existing tunnel before starting new connection")
-				StopTunnel()
-			}
-
 			tunnelConfig := TunnelConfig{
 				Endpoint: req.Endpoint,
 				ID:       req.ID,
@@ -181,7 +233,7 @@ func Init(ctx context.Context, config GlobalConfig) {
 			// Start the 
tunnel process with the new credentials if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" { logger.Info("Starting tunnel with new credentials") - go StartTunnel(tunnelConfig) + go o.StartTunnel(tunnelConfig) } return nil @@ -189,706 +241,183 @@ func Init(ctx context.Context, config GlobalConfig) { // onSwitchOrg func(req api.SwitchOrgRequest) error { logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) - return SwitchOrg(req.OrgID) + return o.SwitchOrg(req.OrgID) + }, + // onMetadataChange + func(req api.MetadataChangeRequest) error { + logger.Info("Received change metadata request via API") + + if req.Fingerprint != nil { + o.SetFingerprint(req.Fingerprint) + } + + if req.Postures != nil { + o.SetPostures(req.Postures) + } + + return nil }, // onDisconnect func() error { logger.Info("Processing disconnect request via API") - return StopTunnel() + return o.StopTunnel() }, // onExit func() error { logger.Info("Processing shutdown request via API") - Close() - if globalConfig.OnExit != nil { - globalConfig.OnExit() + o.Close() + if o.olmConfig.OnExit != nil { + o.olmConfig.OnExit() } return nil }, + // onRebind + func() error { + logger.Info("Processing rebind request via API") + return o.RebindSocket() + }, + // onPowerMode + func(req api.PowerModeRequest) error { + logger.Info("Processing power mode change request via API: mode=%s", req.Mode) + return o.SetPowerMode(req.Mode) + }, ) } -func StartTunnel(config TunnelConfig) { - if tunnelRunning { +func (o *Olm) StartTunnel(config TunnelConfig) { + if o.tunnelRunning { logger.Info("Tunnel already running") return } - - tunnelRunning = true // Also set it here in case it is called externally - tunnelConfig = config - - // Reset terminated status when tunnel starts - apiServer.SetTerminated(false) - + // debug print out the whole config logger.Debug("Starting tunnel with config: %+v", config) - // Create a cancellable context for this tunnel process - tunnelCtx, cancel := context.WithCancel(globalCtx) - tunnelCancel = cancel - defer func() { - tunnelCancel = nil - }() + o.tunnelRunning = true // Also set it here in case it is called externally + o.tunnelConfig = config - // Recreate channels for this tunnel session - stopPing = make(chan struct{}) + // Reset terminated status when tunnel starts + o.apiServer.SetTerminated(false) + + fingerprint := config.InitialFingerprint + if fingerprint == nil { + fingerprint = make(map[string]any) + } + + postures := config.InitialPostures + if postures == nil { + postures = make(map[string]any) + } + + o.SetFingerprint(fingerprint) + o.SetPostures(postures) + + // Create a cancellable context for this tunnel process + tunnelCtx, cancel := context.WithCancel(o.olmCtx) + o.tunnelCancel = cancel var ( - interfaceName = config.InterfaceName - id = config.ID - secret = config.Secret - userToken = config.UserToken + err error + id = config.ID + secret = config.Secret + userToken = config.UserToken ) - apiServer.SetOrgID(config.OrgID) + o.tunnelConfig.InterfaceName = config.InterfaceName - // Create a new olm client using the provided credentials - olm, err := websocket.NewClient( - id, // Use provided ID - secret, // Use provided secret - userToken, // Use provided user token OPTIONAL + o.apiServer.SetOrgID(config.OrgID) + + // Create a new o.websocket client using the provided credentials + o.websocket, err = websocket.NewClient( + id, + secret, + userToken, config.OrgID, - config.Endpoint, // Use provided endpoint - config.PingIntervalDuration, + 
config.Endpoint, + 30*time.Second, // 30 seconds config.PingTimeoutDuration, + websocket.WithPingDataProvider(func() map[string]any { + o.metaMu.Lock() + defer o.metaMu.Unlock() + return map[string]any{ + "fingerprint": o.fingerprint, + "postures": o.postures, + } + }), ) if err != nil { logger.Error("Failed to create olm: %v", err) return } - // Store the client reference globally - olmClient = olm - // Create shared UDP socket and holepunch manager - if err := initTunnelInfo(id); err != nil { + if err := o.initTunnelInfo(id); err != nil { logger.Error("%v", err) return } - olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - - var wgData WgData - - if connected { - logger.Info("Already connected. Ignoring new connection request.") - return - } - - if stopRegister != nil { - stopRegister() - stopRegister = nil - } - - if updateRegister != nil { - updateRegister = nil - } - - // if there is an existing tunnel then close it - if dev != nil { - logger.Info("Got new message. Closing existing tunnel!") - dev.Close() - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &wgData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - tdev, err = func() (tun.Device, error) { - if config.FileDescriptorTun != 0 { - return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) - } - var ifName = interfaceName - if runtime.GOOS == "darwin" { // this is if we dont pass a fd - ifName, err = network.FindUnusedUTUN() - if err != nil { - return nil, err - } - } - return tun.CreateTUN(ifName, config.MTU) - }() - - if err != nil { - logger.Error("Failed to create TUN device: %v", err) - return - } - - if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { - interfaceName = realInterfaceName - } - } - - // Wrap TUN device with packet filter for DNS proxy - middleDev = olmDevice.NewMiddleDevice(tdev) - - wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") - // Use filtered device instead of raw TUN device - dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger)) - - if config.EnableUAPI { - fileUAPI, err := func() (*os.File, error) { - if config.FileDescriptorUAPI != 0 { - fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32) - if err != nil { - return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) - } - return os.NewFile(uintptr(fd), ""), nil - } - return olmDevice.UapiOpen(interfaceName) - }() - if err != nil { - logger.Error("UAPI listen error: %v", err) - os.Exit(1) - return - } - - uapiListener, err = olmDevice.UapiListen(interfaceName, fileUAPI) - if err != nil { - logger.Error("Failed to listen on uapi socket: %v", err) - os.Exit(1) - } - - go func() { - for { - conn, err := uapiListener.Accept() - if err != nil { - - return - } - go dev.IpcHandle(conn) - } - }() - logger.Info("UAPI listener started") - } - - if err = dev.Up(); err != nil { - logger.Error("Failed to bring up WireGuard device: %v", err) - } - - // Extract interface IP (strip CIDR notation if present) - interfaceIP := wgData.TunnelIP - if strings.Contains(interfaceIP, "/") { - interfaceIP = strings.Split(interfaceIP, "/")[0] - } - - // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) - if err != nil { 
- logger.Error("Failed to create DNS proxy: %v", err) - } - - if err = network.ConfigureInterface(interfaceName, wgData.TunnelIP, config.MTU); err != nil { - logger.Error("Failed to configure interface: %v", err) - } - - if network.AddRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet - logger.Error("Failed to add route for utility subnet: %v", err) - } - - // Create peer manager with integrated peer monitoring - peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ - Device: dev, - DNSProxy: dnsProxy, - InterfaceName: interfaceName, - PrivateKey: privateKey, - MiddleDev: middleDev, - LocalIP: interfaceIP, - SharedBind: sharedBind, - WSClient: olm, - APIServer: apiServer, - }) - - for i := range wgData.Sites { - site := wgData.Sites[i] - var siteEndpoint string - // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer - if site.RelayEndpoint != "" { - siteEndpoint = site.RelayEndpoint - } else { - siteEndpoint = site.Endpoint - } - - apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) - - if err := peerManager.AddPeer(site); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - logger.Info("Configured peer %s", site.PublicKey) - } - - peerManager.Start() - - if err := dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime - logger.Error("Failed to start DNS proxy: %v", err) - } - - if config.OverrideDNS { - // Set up DNS override to use our DNS proxy - if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil { - logger.Error("Failed to setup DNS override: %v", err) - return - } - } - - apiServer.SetRegistered(true) - - connected = true - - // Invoke onConnected callback if configured - if globalConfig.OnConnected != nil { - go globalConfig.OnConnected() - } - - logger.Info("WireGuard device created.") - }) - - olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateData peers.SiteConfig - if err := json.Unmarshal(jsonData, &updateData); err != nil { - logger.Error("Error unmarshaling update data: %v", err) - return - } - - // Get existing peer from PeerManager - existingPeer, exists := peerManager.GetPeer(updateData.SiteId) - if !exists { - logger.Warn("Peer with site ID %d not found", updateData.SiteId) - return - } - - // Create updated site config by merging with existing data - siteConfig := existingPeer - - if updateData.Endpoint != "" { - siteConfig.Endpoint = updateData.Endpoint - } - if updateData.RelayEndpoint != "" { - siteConfig.RelayEndpoint = updateData.RelayEndpoint - } - if updateData.PublicKey != "" { - siteConfig.PublicKey = updateData.PublicKey - } - if updateData.ServerIP != "" { - siteConfig.ServerIP = updateData.ServerIP - } - if updateData.ServerPort != 0 { - siteConfig.ServerPort = updateData.ServerPort - } - if updateData.RemoteSubnets != nil { - siteConfig.RemoteSubnets = updateData.RemoteSubnets - } - - if err := peerManager.UpdatePeer(siteConfig); err != nil { - logger.Error("Failed to update peer: %v", err) - return - } - - // If the endpoint changed, trigger holepunch to refresh NAT mappings - if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { - logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT 
mappings", updateData.SiteId) - holePunchManager.TriggerHolePunch() - holePunchManager.ResetInterval() - } - - // Update successful - logger.Info("Successfully updated peer for site %d", updateData.SiteId) - }) - - // Handler for adding a new peer - olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-peer message: %v", msg.Data) - - if stopPeerSend != nil { - stopPeerSend() - stopPeerSend = nil - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var siteConfig peers.SiteConfig - if err := json.Unmarshal(jsonData, &siteConfig); err != nil { - logger.Error("Error unmarshaling add data: %v", err) - return - } - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it - - if err := peerManager.AddPeer(siteConfig); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - // Add successful - logger.Info("Successfully added peer for site %d", siteConfig.SiteId) - }) - - // Handler for removing a peer - olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeData peers.PeerRemove - if err := json.Unmarshal(jsonData, &removeData); err != nil { - logger.Error("Error unmarshaling remove data: %v", err) - return - } - - if err := peerManager.RemovePeer(removeData.SiteId); err != nil { - logger.Error("Failed to remove peer: %v", err) - return - } - - // Remove any exit nodes associated with this peer from hole punching - if holePunchManager != nil { - removed := holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) - if removed > 0 { - logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) - } - } - - // Remove successful - logger.Info("Successfully removed peer for site %d", removeData.SiteId) - }) - - // Handler for adding remote subnets to a peer - olm.RegisterHandler("olm/wg/peer/data/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var addSubnetsData peers.PeerAdd - if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { - logger.Error("Error unmarshaling add-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(addSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId) - return - } - - // Add new subnets - for _, subnet := range addSubnetsData.RemoteSubnets { - if err := peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases - for _, alias := range addSubnetsData.Aliases { - if err := peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for removing remote subnets from a peer - olm.RegisterHandler("olm/wg/peer/data/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := 
json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeSubnetsData peers.RemovePeerData - if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { - logger.Error("Error unmarshaling remove-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(removeSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) - return - } - - // Remove subnets - for _, subnet := range removeSubnetsData.RemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Remove aliases - for _, alias := range removeSubnetsData.Aliases { - if err := peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for updating remote subnets of a peer (remove old, add new in one operation) - olm.RegisterHandler("olm/wg/peer/data/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateSubnetsData peers.UpdatePeerData - if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { - logger.Error("Error unmarshaling update-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(updateSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", updateSubnetsData.SiteId) - return - } - - // Add new subnets BEFORE removing old ones to preserve shared subnets - // This ensures that if an old and new subnet are the same on different peers, - // the route won't be temporarily removed - for _, subnet := range updateSubnetsData.NewRemoteSubnets { - if err := peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Remove old subnets after new ones are added - for _, subnet := range updateSubnetsData.OldRemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases BEFORE removing old ones to preserve shared IP addresses - // This ensures that if an old and new alias share the same IP, the IP won't be - // temporarily removed from the allowed IPs list - for _, alias := range updateSubnetsData.NewAliases { - if err := peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - - // Remove old aliases after new ones are added - for _, alias := range updateSubnetsData.OldAliases { - if err := peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - - logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) - }) - - olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { - logger.Debug("Received relay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring relay message: peerManager is nil 
(shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.RelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) - - peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort) - }) - - olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) { - logger.Debug("Received unrelay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.UnRelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.Endpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false) - - peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) - }) + // Handlers for managing connection status + o.websocket.RegisterHandler("olm/wg/connect", o.handleConnect) + o.websocket.RegisterHandler("olm/error", o.handleOlmError) + o.websocket.RegisterHandler("olm/terminate", o.handleTerminate) + + // Handlers for managing peers + o.websocket.RegisterHandler("olm/wg/peer/add", o.handleWgPeerAdd) + o.websocket.RegisterHandler("olm/wg/peer/remove", o.handleWgPeerRemove) + o.websocket.RegisterHandler("olm/wg/peer/update", o.handleWgPeerUpdate) + o.websocket.RegisterHandler("olm/wg/peer/relay", o.handleWgPeerRelay) + o.websocket.RegisterHandler("olm/wg/peer/unrelay", o.handleWgPeerUnrelay) + + // Handlers for managing remote subnets to a peer + o.websocket.RegisterHandler("olm/wg/peer/data/add", o.handleWgPeerAddData) + o.websocket.RegisterHandler("olm/wg/peer/data/remove", o.handleWgPeerRemoveData) + o.websocket.RegisterHandler("olm/wg/peer/data/update", o.handleWgPeerUpdateData) // Handler for peer handshake - adds exit node to holepunch rotation and notifies server - olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) { - logger.Debug("Received peer-handshake message: %v", msg.Data) + o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) + o.websocket.RegisterHandler("olm/sync", o.handleSync) - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling handshake data: %v", err) - return - } - - var handshakeData struct { - SiteId int `json:"siteId"` - ExitNode struct { - PublicKey string `json:"publicKey"` - Endpoint string `json:"endpoint"` - RelayPort uint16 `json:"relayPort"` - } `json:"exitNode"` - } - - if err := json.Unmarshal(jsonData, &handshakeData); err != nil { - logger.Error("Error unmarshaling handshake data: %v", err) - return - } - - // Get existing peer 
from PeerManager - _, exists := peerManager.GetPeer(handshakeData.SiteId) - if exists { - logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) - return - } - - relayPort := handshakeData.ExitNode.RelayPort - if relayPort == 0 { - relayPort = 21820 // default relay port - } - - siteId := handshakeData.SiteId - exitNode := holepunch.ExitNode{ - Endpoint: handshakeData.ExitNode.Endpoint, - RelayPort: relayPort, - PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, - } - - added := holePunchManager.AddExitNode(exitNode) - if added { - logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) - } else { - logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) - } - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud - - // Send handshake acknowledgment back to server with retry - stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second) - - logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) - }) - - olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() - network.ClearNetworkSettings() - Close() - - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() - } - }) - - olm.RegisterHandler("pong", func(msg websocket.WSMessage) { - logger.Debug("Received pong message") - }) - - olm.OnConnect(func() error { + o.websocket.OnConnect(func() error { logger.Info("Websocket Connected") - apiServer.SetConnectionStatus(true) + o.apiServer.SetConnectionStatus(true) - if connected { + if o.connected { + o.websocket.StartPingMonitor() + logger.Debug("Already connected, skipping registration") return nil } - publicKey := privateKey.PublicKey() + publicKey := o.privateKey.PublicKey() // delay for 500ms to allow for time for the hp to get processed time.Sleep(500 * time.Millisecond) - if stopRegister == nil { + if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) - stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "relay": !config.Holepunch, - "olmVersion": globalConfig.Version, - "olmAgent": globalConfig.Agent, - "orgId": config.OrgID, - "userToken": userToken, - }, 1*time.Second) + o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{ + "publicKey": publicKey.String(), + "relay": !config.Holepunch, + "olmVersion": o.olmConfig.Version, + "olmAgent": o.olmConfig.Agent, + "orgId": config.OrgID, + "userToken": userToken, + "fingerprint": o.fingerprint, + "postures": o.postures, + }, 1*time.Second, 10) // Invoke onRegistered callback if configured - if globalConfig.OnRegistered != nil { - go globalConfig.OnRegistered() + if o.olmConfig.OnRegistered != nil { + go o.olmConfig.OnRegistered() } } - go keepSendingPing(olm) - return nil }) - olm.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { - holePunchManager.SetToken(token) + o.websocket.OnTokenUpdate(func(token string, exitNodes 
[]websocket.ExitNode) { + o.holePunchManager.SetToken(token) logger.Debug("Got exit nodes for hole punching: %v", exitNodes) @@ -912,114 +441,115 @@ func StartTunnel(config TunnelConfig) { // Start hole punching using the manager logger.Info("Starting hole punch for %d exit nodes", len(exitNodes)) - if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { + if err := o.holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { logger.Warn("Failed to start hole punch: %v", err) } }) - olm.OnAuthError(func(statusCode int, message string) { + o.websocket.OnAuthError(func(statusCode int, message string) { logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message) - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() + o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() + o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() - Close() + o.Close() - if globalConfig.OnAuthError != nil { - go globalConfig.OnAuthError(statusCode, message) + if o.olmConfig.OnAuthError != nil { + go o.olmConfig.OnAuthError(statusCode, message) } - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() } }) // Connect to the WebSocket server - if err := olm.Connect(); err != nil { + if err := o.websocket.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) return } - defer olm.Close() + defer func() { _ = o.websocket.Close() }() // Wait for context cancellation <-tunnelCtx.Done() logger.Info("Tunnel process context cancelled, cleaning up") } -func Close() { +func (o *Olm) Close() { + // send a disconnect message to the cloud to show disconnected + if o.websocket != nil { + o.websocket.SendMessage("olm/disconnecting", map[string]any{}) + // Close the websocket connection after sending disconnect + _ = o.websocket.Close() + o.websocket = nil + } + // Restore original DNS configuration // we do this first to avoid any DNS issues if something else gets stuck if err := dnsOverride.RestoreDNSOverride(); err != nil { logger.Error("Failed to restore DNS: %v", err) } - // Stop hole punch manager - if holePunchManager != nil { - holePunchManager.Stop() - holePunchManager = nil + if o.holePunchManager != nil { + o.holePunchManager.Stop() + o.holePunchManager = nil } - if stopPing != nil { - select { - case <-stopPing: - // Channel already closed - default: - close(stopPing) - } + if o.stopRegister != nil { + o.stopRegister() + o.stopRegister = nil } - if stopRegister != nil { - stopRegister() - stopRegister = nil + // Close() also calls Stop() internally + if o.peerManager != nil { + o.peerManager.Close() + o.peerManager = nil } - if updateRegister != nil { - updateRegister = nil + if o.uapiListener != nil { + _ = o.uapiListener.Close() + o.uapiListener = nil } - if peerManager != nil { - peerManager.Close() // Close() also calls Stop() internally - peerManager = nil - } - - if uapiListener != nil { - uapiListener.Close() - uapiListener = nil + if o.logFile != nil { + _ = o.logFile.Close() + o.logFile = nil } // Stop DNS proxy first - it uses the middleDev for packet filtering - logger.Debug("Stopping DNS proxy") - if dnsProxy != nil { - dnsProxy.Stop() - dnsProxy = nil + if o.dnsProxy != nil { + logger.Debug("Stopping DNS proxy") + o.dnsProxy.Stop() + o.dnsProxy = nil } // Close 
MiddleDevice first - this closes the TUN and signals the closed channel // This unblocks the pump goroutine and allows WireGuard's TUN reader to exit - logger.Debug("Closing MiddleDevice") - if middleDev != nil { - middleDev.Close() - middleDev = nil + // Note: o.tdev is closed by o.middleDev.Close() since middleDev wraps it + if o.middleDev != nil { + logger.Debug("Closing MiddleDevice") + _ = o.middleDev.Close() + o.middleDev = nil } - // Note: tdev is closed by middleDev.Close() since middleDev wraps it - tdev = nil // Now close WireGuard device - its TUN reader should have exited by now - logger.Debug("Closing WireGuard device") - if dev != nil { - dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference - dev = nil + // This will call sharedBind.Close() which releases WireGuard's reference + if o.dev != nil { + logger.Debug("Closing WireGuard device") + o.dev.Close() + o.dev = nil } - // Release the hole punch reference to the shared bind - if sharedBind != nil { - // Release hole punch reference (WireGuard already released its reference via dev.Close()) - logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount()) - sharedBind.Release() - sharedBind = nil + // Release the hole punch reference to the shared bind (WireGuard already + // released its reference via dev.Close()) + if o.sharedBind != nil { + logger.Debug("Releasing shared bind (refcount before release: %d)", o.sharedBind.GetRefCount()) + _ = o.sharedBind.Release() logger.Info("Released shared UDP bind") + o.sharedBind = nil } logger.Info("Olm service stopped") @@ -1027,78 +557,332 @@ func Close() { // StopTunnel stops just the tunnel process and websocket connection // without shutting down the entire application -func StopTunnel() error { +func (o *Olm) StopTunnel() error { logger.Info("Stopping tunnel process") + if !o.tunnelRunning { + logger.Debug("Tunnel not running, nothing to stop") + return nil + } + // Cancel the tunnel context if it exists - if tunnelCancel != nil { - tunnelCancel() + if o.tunnelCancel != nil { + o.tunnelCancel() // Give it a moment to clean up time.Sleep(200 * time.Millisecond) } - // Close the websocket connection - if olmClient != nil { - olmClient.Close() - olmClient = nil - } - - Close() + // Close() will handle sending disconnect message and closing websocket + o.Close() // Reset the connected state - connected = false - tunnelRunning = false + o.connected = false + o.tunnelRunning = false // Update API server status - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() network.ClearNetworkSettings() - apiServer.ClearPeerStatuses() + o.apiServer.ClearPeerStatuses() logger.Info("Tunnel process stopped") return nil } -func StopApi() error { - if apiServer != nil { - err := apiServer.Stop() +func (o *Olm) StopApi() error { + if o.apiServer != nil { + err := o.apiServer.Stop() if err != nil { return fmt.Errorf("failed to stop API server: %w", err) } } + return nil } -func StartApi() error { - if apiServer != nil { - err := apiServer.Start() +func (o *Olm) StartApi() error { + if o.apiServer != nil { + err := o.apiServer.Start() if err != nil { return fmt.Errorf("failed to start API server: %w", err) } } + return nil } -func GetStatus() api.StatusResponse { - return apiServer.GetStatus() +func (o *Olm) GetStatus() api.StatusResponse { + return o.apiServer.GetStatus() } -func SwitchOrg(orgID string) error { +func 
(o *Olm) SwitchOrg(orgID string) error { logger.Info("Processing org switch request to orgId: %s", orgID) // stop the tunnel - if err := StopTunnel(); err != nil { + if err := o.StopTunnel(); err != nil { return fmt.Errorf("failed to stop existing tunnel: %w", err) } // Update the org ID in the API server and global config - apiServer.SetOrgID(orgID) + o.apiServer.SetOrgID(orgID) - tunnelConfig.OrgID = orgID + o.tunnelConfig.OrgID = orgID // Restart the tunnel with the same config but new org ID - go StartTunnel(tunnelConfig) + go o.StartTunnel(o.tunnelConfig) return nil } + +func (o *Olm) SetFingerprint(data map[string]any) { + o.metaMu.Lock() + defer o.metaMu.Unlock() + + o.fingerprint = data +} + +func (o *Olm) SetPostures(data map[string]any) { + o.metaMu.Lock() + defer o.metaMu.Unlock() + + o.postures = data +} + +// SetPowerMode switches between normal and low power modes +// In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes +// In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored +// Wake-up has a 3-second debounce to prevent rapid flip-flopping; sleep is immediate +func (o *Olm) SetPowerMode(mode string) error { + // Validate mode + if mode != "normal" && mode != "low" { + return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode) + } + + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + + // If already in the requested mode, return early + if o.currentPowerMode == mode { + // Cancel any pending wake-up timer if we're already in normal mode + if mode == "normal" && o.wakeUpTimer != nil { + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } + logger.Debug("Already in %s power mode", mode) + return nil + } + + if mode == "low" { + // Low Power Mode: Cancel any pending wake-up and immediately go to sleep + + // Cancel pending wake-up timer if any + if o.wakeUpTimer != nil { + logger.Debug("Cancelling pending wake-up timer") + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } + + logger.Info("Switching to low power mode") + + // Mark as disconnected so we re-register on reconnect + o.connected = false + + // Update API server connection status + if o.apiServer != nil { + o.apiServer.SetConnectionStatus(false) + } + + if o.websocket != nil { + logger.Info("Disconnecting websocket for low power mode") + if err := o.websocket.Disconnect(); err != nil { + logger.Error("Error disconnecting websocket: %v", err) + } + } + + lowPowerInterval := 10 * time.Minute + + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval) + peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval) + logger.Info("Set monitoring intervals to 10 minutes for low power mode") + } + o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable + } + + if o.holePunchManager != nil { + o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval) + } + + o.currentPowerMode = "low" + logger.Info("Switched to low power mode") + + } else { + // Normal Power Mode: Start debounce timer before actually waking up + + // If there's already a pending wake-up timer, don't start another + if o.wakeUpTimer != nil { + logger.Debug("Wake-up already pending, ignoring duplicate request") + return nil + } + + logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce) + + o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() { + o.powerModeMu.Lock() + 
defer o.powerModeMu.Unlock() + + // Clear the timer reference + o.wakeUpTimer = nil + + // Double-check we're still in low power mode (could have changed) + if o.currentPowerMode == "normal" { + logger.Debug("Already in normal mode after debounce, skipping wake-up") + return + } + + logger.Info("Debounce complete, switching to normal power mode") + + logger.Info("Reconnecting websocket for normal power mode") + if o.websocket != nil { + if err := o.websocket.Connect(); err != nil { + logger.Error("Failed to reconnect websocket: %v", err) + return + } + } + + // Restore intervals and reconnect websocket + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.ResetPeerHolepunchInterval() + peerMonitor.ResetPeerInterval() + } + + o.peerManager.UpdateAllPeersPersistentKeepalive(5) + } + + if o.holePunchManager != nil { + o.holePunchManager.ResetServerHolepunchInterval() + } + + o.currentPowerMode = "normal" + logger.Info("Switched to normal power mode") + }) + } + + return nil +} + +// RebindSocket recreates the UDP socket when network connectivity changes. +// This is necessary on macOS/iOS when transitioning between WiFi and cellular, +// as the old socket becomes stale and can no longer route packets. +// Call this method when detecting a network path change. +func (o *Olm) RebindSocket() error { + if o.sharedBind == nil { + return fmt.Errorf("shared bind is not initialized") + } + + // Close the old socket first to release the port, then try to rebind to the same port + currentPort, err := o.sharedBind.CloseSocket() + if err != nil { + return fmt.Errorf("failed to close old socket: %w", err) + } + + logger.Info("Rebinding UDP socket (released port: %d)", currentPort) + + // Create a new UDP socket + var newConn *net.UDPConn + var newPort uint16 + + // First try to bind to the same port (now available since we closed the old socket) + localAddr := &net.UDPAddr{ + Port: int(currentPort), + IP: net.IPv4zero, + } + + newConn, err = net.ListenUDP("udp4", localAddr) + if err != nil { + // If we can't reuse the port, find a new one + logger.Warn("Could not rebind to port %d, finding new port: %v", currentPort, err) + newPort, err = util.FindAvailableUDPPort(49152, 65535) + if err != nil { + return fmt.Errorf("failed to find available UDP port: %w", err) + } + + localAddr = &net.UDPAddr{ + Port: int(newPort), + IP: net.IPv4zero, + } + + // Use udp4 explicitly to avoid IPv6 dual-stack issues + newConn, err = net.ListenUDP("udp4", localAddr) + if err != nil { + return fmt.Errorf("failed to create new UDP socket: %w", err) + } + } else { + newPort = currentPort + } + + // Rebind the shared bind with the new connection + if err := o.sharedBind.Rebind(newConn); err != nil { + newConn.Close() + return fmt.Errorf("failed to rebind shared bind: %w", err) + } + + logger.Info("Successfully rebound UDP socket on port %d", newPort) + + // Check if we're in low power mode before triggering hole punch + o.powerModeMu.Lock() + isLowPower := o.currentPowerMode == "low" + o.powerModeMu.Unlock() + + // Only trigger hole punch if not in low power mode + if !isLowPower && o.holePunchManager != nil { + o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetServerHolepunchInterval() + logger.Info("Triggered hole punch after socket rebind") + } else if isLowPower { + logger.Info("Skipping hole punch trigger due to low power mode") + } + + return nil +} + +func (o *Olm) AddDevice(fd uint32) error { + if o.middleDev == nil { + return fmt.Errorf("middle 
device is not initialized") + } + + if o.tunnelConfig.MTU == 0 { + return fmt.Errorf("tunnel MTU is not set") + } + + tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU) + if err != nil { + return fmt.Errorf("failed to create TUN device from fd: %v", err) + } + + // Update interface name if available + if realInterfaceName, err2 := tdev.Name(); err2 == nil { + o.tunnelConfig.InterfaceName = realInterfaceName + } + + // Replace the existing TUN device in the middle device with the new one + o.middleDev.AddDevice(tdev) + + logger.Info("Added device from file descriptor %d", fd) + + return nil +} + +func GetNetworkSettingsJSON() (string, error) { + return network.GetJSON() +} + +func GetNetworkSettingsIncrementor() int { + return network.GetIncrementor() +} diff --git a/olm/peer.go b/olm/peer.go new file mode 100644 index 0000000..56e298d --- /dev/null +++ b/olm/peer.go @@ -0,0 +1,258 @@ +package olm + +import ( + "encoding/json" + "time" + + "github.com/fosrl/newt/holepunch" + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" +) + +func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { + logger.Debug("Received add-peer message: %v", msg.Data) + + if o.stopPeerSend != nil { + o.stopPeerSend() + o.stopPeerSend = nil + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var siteConfig peers.SiteConfig + if err := json.Unmarshal(jsonData, &siteConfig); err != nil { + logger.Error("Error unmarshaling add data: %v", err) + return + } + + _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it + + if err := o.peerManager.AddPeer(siteConfig); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Successfully added peer for site %d", siteConfig.SiteId) +} + +func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { + logger.Debug("Received remove-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeData peers.PeerRemove + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return + } + + if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil { + logger.Error("Failed to remove peer: %v", err) + return + } + + // Remove any exit nodes associated with this peer from hole punching + if o.holePunchManager != nil { + removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) + if removed > 0 { + logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) + } + } + + logger.Info("Successfully removed peer for site %d", removeData.SiteId) +} + +func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) { + logger.Debug("Received update-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateData peers.SiteConfig + if err := json.Unmarshal(jsonData, &updateData); err != nil { + logger.Error("Error unmarshaling update data: %v", err) + return + } + + // Get existing peer from PeerManager + existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId) + if !exists { + logger.Warn("Peer with site ID %d not found", updateData.SiteId) 
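+ // Nothing to merge the partial update into, so drop it.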
+ return + } + + // Create updated site config by merging with existing data + siteConfig := existingPeer + + if updateData.Endpoint != "" { + siteConfig.Endpoint = updateData.Endpoint + } + if updateData.RelayEndpoint != "" { + siteConfig.RelayEndpoint = updateData.RelayEndpoint + } + if updateData.PublicKey != "" { + siteConfig.PublicKey = updateData.PublicKey + } + if updateData.ServerIP != "" { + siteConfig.ServerIP = updateData.ServerIP + } + if updateData.ServerPort != 0 { + siteConfig.ServerPort = updateData.ServerPort + } + if updateData.RemoteSubnets != nil { + siteConfig.RemoteSubnets = updateData.RemoteSubnets + } + + if err := o.peerManager.UpdatePeer(siteConfig); err != nil { + logger.Error("Failed to update peer: %v", err) + return + } + + // If the endpoint changed, trigger holepunch to refresh NAT mappings + if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { + logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) + _ = o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetServerHolepunchInterval() + } + + logger.Info("Successfully updated peer for site %d", updateData.SiteId) +} + +func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) { + logger.Debug("Received relay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if o.peerManager == nil { + logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData peers.RelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling relay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) + if err != nil { + logger.Error("Failed to resolve primary relay endpoint: %v", err) + return + } + + // Update HTTP server to mark this peer as using relay + o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) + + o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort) +} + +func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { + logger.Debug("Received unrelay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if o.peerManager == nil { + logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData peers.UnRelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling unrelay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.Endpoint) + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + // Update HTTP server to mark this peer as no longer relayed + o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false) + + o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) +} + +func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { + logger.Debug("Received peer-handshake message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling handshake data: %v", err) + return + } + + var handshakeData struct { + SiteId int `json:"siteId"` + ExitNode struct { +
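// Inline shape of the server's handshake payload; RelayPort may be zero, in which case a default is applied below. +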
PublicKey string `json:"publicKey"` + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + } `json:"exitNode"` + } + + if err := json.Unmarshal(jsonData, &handshakeData); err != nil { + logger.Error("Error unmarshaling handshake data: %v", err) + return + } + + // Get existing peer from PeerManager + _, exists := o.peerManager.GetPeer(handshakeData.SiteId) + if exists { + logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) + return + } + + relayPort := handshakeData.ExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + siteId := handshakeData.SiteId + exitNode := holepunch.ExitNode{ + Endpoint: handshakeData.ExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: handshakeData.ExitNode.PublicKey, + SiteIds: []int{siteId}, + } + + added := o.holePunchManager.AddExitNode(exitNode) + if added { + logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) + } else { + logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) + } + + o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud + + // Send handshake acknowledgment back to server with retry + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + }, 1*time.Second, 10) + + logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) +} diff --git a/olm/types.go b/olm/types.go index b7153af..198b222 100644 --- a/olm/types.go +++ b/olm/types.go @@ -12,9 +12,22 @@ type WgData struct { UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses } -type GlobalConfig struct { +type SyncData struct { + Sites []peers.SiteConfig `json:"sites"` + ExitNodes []SyncExitNode `json:"exitNodes"` +} + +type SyncExitNode struct { + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + PublicKey string `json:"publicKey"` + SiteIds []int `json:"siteIds"` +} + +type OlmConfig struct { // Logging - LogLevel string + LogLevel string + LogFilePath string // HTTP server EnableAPI bool @@ -23,11 +36,17 @@ type GlobalConfig struct { Version string Agent string + WakeUpDebounce time.Duration + + // Debugging + PprofAddr string // Address to serve pprof on (e.g., "localhost:6060") + // Callbacks OnRegistered func() OnConnected func() OnTerminated func() OnAuthError func(statusCode int, message string) // Called when auth fails (401/403) + OnOlmError func(code string, message string) // Called when registration fails OnExit func() // Called when exit is requested via API } @@ -63,5 +82,8 @@ type TunnelConfig struct { OverrideDNS bool TunnelDNS bool + InitialFingerprint map[string]any + InitialPostures map[string]any + DisableRelay bool } diff --git a/olm/util.go b/olm/util.go index 6bfd171..73572dc 100644 --- a/olm/util.go +++ b/olm/util.go @@ -1,55 +1,47 @@ package olm import ( - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/network" - "github.com/fosrl/olm/websocket" + "github.com/fosrl/olm/peers" ) -func sendPing(olm *websocket.Client) error { - err := olm.SendMessage("olm/ping", map[string]interface{}{ - "timestamp": time.Now().Unix(), - "userToken": olm.GetConfig().UserToken, - }) - if err != nil { - logger.Error("Failed to send ping message: %v", err) - return 
err +// slicesEqual compares two string slices for equality (order-independent) +func slicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false } - logger.Debug("Sent ping message") - return nil -} - -func keepSendingPing(olm *websocket.Client) { - // Send ping immediately on startup - if err := sendPing(olm); err != nil { - logger.Error("Failed to send initial ping: %v", err) - } else { - logger.Info("Sent initial ping message") + // Create a map to count occurrences in slice a + counts := make(map[string]int) + for _, v := range a { + counts[v]++ } - - // Set up ticker for one minute intervals - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-stopPing: - logger.Info("Stopping ping messages") - return - case <-ticker.C: - if err := sendPing(olm); err != nil { - logger.Error("Failed to send periodic ping: %v", err) - } + // Check if slice b has the same elements + for _, v := range b { + counts[v]-- + if counts[v] < 0 { + return false } } + return true } -func GetNetworkSettingsJSON() (string, error) { - return network.GetJSON() -} - -func GetNetworkSettingsIncrementor() int { - return network.GetIncrementor() +// aliasesEqual compares two Alias slices for equality (order-independent) +func aliasesEqual(a, b []peers.Alias) bool { + if len(a) != len(b) { + return false + } + // Create a map to count occurrences in slice a (using alias+address as key) + counts := make(map[string]int) + for _, v := range a { + key := v.Alias + "|" + v.AliasAddress + counts[key]++ + } + // Check if slice b has the same elements + for _, v := range b { + key := v.Alias + "|" + v.AliasAddress + counts[key]-- + if counts[key] < 0 { + return false + } + } + return true } diff --git a/peers/manager.go b/peers/manager.go index af781e5..0566775 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -50,6 +50,8 @@ type PeerManager struct { // key is the CIDR string, value is a set of siteIds that want this IP allowedIPClaims map[string]map[int]bool APIServer *api.API + + PersistentKeepalive int } // NewPeerManager creates a new PeerManager with an internal PeerMonitor @@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) { return peer, ok } +// GetPeerMonitor returns the internal peer monitor instance +func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor { + pm.mu.RLock() + defer pm.mu.RUnlock() + return pm.peerMonitor +} + func (pm *PeerManager) GetAllPeers() []SiteConfig { pm.mu.RLock() defer pm.mu.RUnlock() @@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { return nil } +// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once +// without recreating them. Returns a map of siteId to error for any peers that failed to update. 
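+// An interval of 0 disables persistent keepalive entirely, as used when entering low power mode.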
+func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error { + // Take the write lock: PersistentKeepalive is mutated below. + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.PersistentKeepalive = interval + + errors := make(map[int]error) + + for siteId, peer := range pm.peers { + err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval) + if err != nil { + errors[siteId] = err + } + } + + if len(errors) == 0 { + return nil + } + return errors +} + func (pm *PeerManager) RemovePeer(siteId int) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error { ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId) wgConfig := promotedPeer wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } @@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -324,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId) promotedWgConfig := promotedPeer promotedWgConfig.AllowedIps = promotedOwnedIPs - if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 27bc408..28d92ef 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -31,8 +31,7 @@ type PeerMonitor struct { monitors map[int]*Client mutex sync.Mutex running bool - interval time.Duration - timeout time.Duration + timeout time.Duration maxAttempts int wsClient *websocket.Client @@ -42,7 +41,7 @@ type PeerMonitor struct { stack *stack.Stack ep *channel.Endpoint activePorts map[uint16]bool - portsLock sync.Mutex + portsLock sync.RWMutex nsCtx context.Context nsCancel context.CancelFunc nsWg sync.WaitGroup @@ -50,17 +49,26 @@ type PeerMonitor struct { // Holepunch testing fields sharedBind *bind.SharedBind holepunchTester *holepunch.HolepunchTester - holepunchInterval time.Duration holepunchTimeout time.Duration holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchStatus map[int]bool // siteID -> connected status - holepunchStopChan chan struct{} + holepunchStopChan chan struct{} + holepunchUpdateChan chan struct{} // Relay tracking fields relayedPeers map[int]bool // siteID -> whether the peer is currently relayed holepunchMaxAttempts int // max consecutive failures before triggering relay holepunchFailures map[int]int // siteID -> consecutive failure count + // Exponential backoff fields for holepunch monitor + defaultHolepunchMinInterval time.Duration // Minimum interval (initial) + defaultHolepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchMinInterval
time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied + // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts rapidTestTimeout time.Duration // timeout for each rapid test attempt @@ -78,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), - interval: 2 * time.Second, // Default check interval (faster) timeout: 3 * time.Second, maxAttempts: 3, wsClient: wsClient, @@ -88,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe nsCtx: ctx, nsCancel: cancel, sharedBind: sharedBind, - holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds holepunchTimeout: 2 * time.Second, // Faster timeout holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), @@ -101,6 +107,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total apiServer: apiServer, wgConnectionStatus: make(map[int]bool), + // Exponential backoff settings for holepunch monitor + defaultHolepunchMinInterval: 2 * time.Second, + defaultHolepunchMaxInterval: 30 * time.Second, + holepunchMinInterval: 2 * time.Second, + holepunchMaxInterval: 30 * time.Second, + holepunchBackoffMultiplier: 1.5, + holepunchStableCount: make(map[int]int), + holepunchCurrentInterval: 2 * time.Second, + holepunchUpdateChan: make(chan struct{}, 1), } if err := pm.initNetstack(); err != nil { @@ -116,41 +131,75 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe } // SetInterval changes how frequently peers are checked -func (pm *PeerMonitor) SetInterval(interval time.Duration) { +func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.interval = interval - // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetPacketInterval(interval) + client.SetPacketInterval(minInterval, maxInterval) } + + logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval) } -// SetTimeout changes the timeout for waiting for responses -func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { +func (pm *PeerMonitor) ResetPeerInterval() { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.timeout = timeout - - // Update timeout for all existing monitors + // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetTimeout(timeout) + client.ResetPacketInterval() } } -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (pm *PeerMonitor) SetMaxAttempts(attempts int) { +// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) { + pm.mutex.Lock() + pm.holepunchMinInterval = minInterval + pm.holepunchMaxInterval = maxInterval + // Reset current interval to the new minimum + pm.holepunchCurrentInterval = minInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Set 
holepunch interval to min: %s, max: %s", minInterval, maxInterval) + + // Signal the goroutine to apply the new interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } +} + +// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.maxAttempts = attempts + return pm.holepunchMinInterval, pm.holepunchMaxInterval +} - // Update max attempts for all existing monitors - for _, client := range pm.monitors { - client.SetMaxAttempts(attempts) +func (pm *PeerMonitor) ResetPeerHolepunchInterval() { + pm.mutex.Lock() + pm.holepunchMinInterval = pm.defaultHolepunchMinInterval + pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval + pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval) + + // Signal the goroutine to apply the new interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } } } @@ -169,10 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st return err } - client.SetPacketInterval(pm.interval) - client.SetTimeout(pm.timeout) - client.SetMaxAttempts(pm.maxAttempts) - pm.monitors[siteID] = client pm.holepunchEndpoints[siteID] = holepunchEndpoint @@ -470,31 +515,59 @@ func (pm *PeerMonitor) stopHolepunchMonitor() { logger.Info("Stopped holepunch connection monitor") } -// runHolepunchMonitor runs the holepunch monitoring loop +// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff func (pm *PeerMonitor) runHolepunchMonitor() { - ticker := time.NewTicker(pm.holepunchInterval) - defer ticker.Stop() + pm.mutex.Lock() + pm.holepunchCurrentInterval = pm.holepunchMinInterval + pm.mutex.Unlock() - // Do initial check immediately - pm.checkHolepunchEndpoints() + timer := time.NewTimer(0) // Fire immediately for initial check + defer timer.Stop() for { select { case <-pm.holepunchStopChan: return - case <-ticker.C: - pm.checkHolepunchEndpoints() + case <-pm.holepunchUpdateChan: + // Interval settings changed, reset to minimum + pm.mutex.Lock() + pm.holepunchCurrentInterval = pm.holepunchMinInterval + currentInterval := pm.holepunchCurrentInterval + pm.mutex.Unlock() + + timer.Reset(currentInterval) + logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) + case <-timer.C: + anyStatusChanged := pm.checkHolepunchEndpoints() + + pm.mutex.Lock() + if anyStatusChanged { + // Reset to minimum interval on any status change + pm.holepunchCurrentInterval = pm.holepunchMinInterval + } else { + // Apply exponential backoff when stable + newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier) + if newInterval > pm.holepunchMaxInterval { + newInterval = pm.holepunchMaxInterval + } + pm.holepunchCurrentInterval = newInterval + } + currentInterval := pm.holepunchCurrentInterval + pm.mutex.Unlock() + + timer.Reset(currentInterval) } } } // checkHolepunchEndpoints tests all holepunch endpoints -func (pm *PeerMonitor) checkHolepunchEndpoints() { +// Returns true if any endpoint's status changed +func (pm *PeerMonitor) 
checkHolepunchEndpoints() bool { pm.mutex.Lock() // Check if we're still running before doing any work if !pm.running { pm.mutex.Unlock() - return + return false } endpoints := make(map[int]string, len(pm.holepunchEndpoints)) for siteID, endpoint := range pm.holepunchEndpoints { @@ -504,8 +577,10 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { maxAttempts := pm.holepunchMaxAttempts pm.mutex.Unlock() + anyStatusChanged := false + for siteID, endpoint := range endpoints { - // logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) + // logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint) result := pm.holepunchTester.TestEndpoint(endpoint, timeout) pm.mutex.Lock() @@ -529,7 +604,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { pm.mutex.Unlock() // Log status changes - if !exists || previousStatus != result.Success { + statusChanged := !exists || previousStatus != result.Success + if statusChanged { + anyStatusChanged = true if result.Success { logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT) } else { @@ -562,7 +639,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { pm.mutex.Unlock() if !stillRunning { - return // Stop processing if shutdown is in progress + return anyStatusChanged // Stop processing if shutdown is in progress } if !result.Success && !isRelayed && failureCount >= maxAttempts { @@ -579,6 +656,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { } } } + + return anyStatusChanged } // GetHolepunchStatus returns the current holepunch status for all endpoints @@ -650,55 +729,55 @@ func (pm *PeerMonitor) Close() { logger.Debug("PeerMonitor: Cleanup complete") } -// TestPeer tests connectivity to a specific peer -func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { - pm.mutex.Lock() - client, exists := pm.monitors[siteID] - pm.mutex.Unlock() +// // TestPeer tests connectivity to a specific peer +// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { +// pm.mutex.Lock() +// client, exists := pm.monitors[siteID] +// pm.mutex.Unlock() - if !exists { - return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) - } +// if !exists { +// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) +// } - ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - defer cancel() +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// defer cancel() - connected, rtt := client.TestConnection(ctx) - return connected, rtt, nil -} +// connected, rtt := client.TestPeerConnection(ctx) +// return connected, rtt, nil +// } -// TestAllPeers tests connectivity to all peers -func (pm *PeerMonitor) TestAllPeers() map[int]struct { - Connected bool - RTT time.Duration -} { - pm.mutex.Lock() - peers := make(map[int]*Client, len(pm.monitors)) - for siteID, client := range pm.monitors { - peers[siteID] = client - } - pm.mutex.Unlock() +// // TestAllPeers tests connectivity to all peers +// func (pm *PeerMonitor) TestAllPeers() map[int]struct { +// Connected bool +// RTT time.Duration +// } { +// pm.mutex.Lock() +// peers := make(map[int]*Client, len(pm.monitors)) +// for siteID, client := range pm.monitors { +// peers[siteID] = client +// } +// pm.mutex.Unlock() - results := make(map[int]struct { - Connected bool - RTT time.Duration - }) - for siteID, client := range peers { - ctx, cancel := context.WithTimeout(context.Background(), 
pm.timeout*time.Duration(pm.maxAttempts)) - connected, rtt := client.TestConnection(ctx) - cancel() +// results := make(map[int]struct { +// Connected bool +// RTT time.Duration +// }) +// for siteID, client := range peers { +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// connected, rtt := client.TestPeerConnection(ctx) +// cancel() - results[siteID] = struct { - Connected bool - RTT time.Duration - }{ - Connected: connected, - RTT: rtt, - } - } +// results[siteID] = struct { +// Connected bool +// RTT time.Duration +// }{ +// Connected: connected, +// RTT: rtt, +// } +// } - return results -} +// return results +// } // initNetstack initializes the gvisor netstack func (pm *PeerMonitor) initNetstack() error { @@ -770,9 +849,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool { } // Check if we are listening on this port - pm.portsLock.Lock() + pm.portsLock.RLock() active := pm.activePorts[uint16(port)] - pm.portsLock.Unlock() + pm.portsLock.RUnlock() if !active { return false @@ -803,13 +882,12 @@ func (pm *PeerMonitor) runPacketSender() { defer pm.nsWg.Done() logger.Debug("PeerMonitor: Packet sender goroutine started") - // Use a ticker to periodically check for packets without blocking indefinitely - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-pm.nsCtx.Done(): + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := pm.ep.ReadContext(pm.nsCtx) + if pkt == nil { + // Context was cancelled or endpoint closed logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets") // Drain any remaining packets before exiting for { @@ -821,36 +899,28 @@ func (pm *PeerMonitor) runPacketSender() { } logger.Debug("PeerMonitor: Packet sender goroutine exiting") return - case <-ticker.C: - // Try to read packets in batches - for i := 0; i < 10; i++ { - pkt := pm.ep.Read() - if pkt == nil { - break - } - - // Extract packet data - slices := pkt.AsSlices() - if len(slices) > 0 { - var totalSize int - for _, slice := range slices { - totalSize += len(slice) - } - - buf := make([]byte, totalSize) - pos := 0 - for _, slice := range slices { - copy(buf[pos:], slice) - pos += len(slice) - } - - // Inject into MiddleDevice (outbound to WG) - pm.middleDev.InjectOutbound(buf) - } - - pkt.DecRef() - } } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + pm.middleDev.InjectOutbound(buf) + } + + pkt.DecRef() } } diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go index dac2008..e9f6f63 100644 --- a/peers/monitor/wgtester.go +++ b/peers/monitor/wgtester.go @@ -32,10 +32,19 @@ type Client struct { monitorLock sync.Mutex connLock sync.Mutex // Protects connection operations shutdownCh chan struct{} + updateCh chan struct{} packetInterval time.Duration timeout time.Duration maxAttempts int dialer Dialer + + // Exponential backoff fields + defaultMinInterval time.Duration // Default minimum interval (initial) + defaultMaxInterval time.Duration // Default maximum interval (cap for backoff) + minInterval time.Duration // Minimum interval (initial) + maxInterval time.Duration // Maximum 
interval (cap for backoff) + backoffMultiplier float64 // Multiplier for each stable check + stableCountToBackoff int // Number of stable checks before backing off } // Dialer is a function that creates a connection @@ -50,28 +59,59 @@ type ConnectionStatus struct { // NewClient creates a new connection test client func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ - serverAddr: serverAddr, - shutdownCh: make(chan struct{}), - packetInterval: 2 * time.Second, - timeout: 500 * time.Millisecond, // Timeout for individual packets - maxAttempts: 3, // Default max attempts - dialer: dialer, + serverAddr: serverAddr, + shutdownCh: make(chan struct{}), + updateCh: make(chan struct{}, 1), + packetInterval: 2 * time.Second, + defaultMinInterval: 2 * time.Second, + defaultMaxInterval: 30 * time.Second, + minInterval: 2 * time.Second, + maxInterval: 30 * time.Second, + backoffMultiplier: 1.5, + stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off + timeout: 500 * time.Millisecond, // Timeout for individual packets + maxAttempts: 3, // Default max attempts + dialer: dialer, }, nil } // SetPacketInterval changes how frequently packets are sent in monitor mode -func (c *Client) SetPacketInterval(interval time.Duration) { - c.packetInterval = interval +func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) { + c.monitorLock.Lock() + c.packetInterval = minInterval + c.minInterval = minInterval + c.maxInterval = maxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() + + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } -// SetTimeout changes the timeout for waiting for responses -func (c *Client) SetTimeout(timeout time.Duration) { - c.timeout = timeout -} +func (c *Client) ResetPacketInterval() { + c.monitorLock.Lock() + c.packetInterval = c.defaultMinInterval + c.minInterval = c.defaultMinInterval + c.maxInterval = c.defaultMaxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (c *Client) SetMaxAttempts(attempts int) { - c.maxAttempts = attempts + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } // UpdateServerAddr updates the server address and resets the connection @@ -125,9 +165,10 @@ func (c *Client) ensureConnection() error { return nil } -// TestConnection checks if the connection to the server is working +// TestPeerConnection checks if the connection to the server is working // Returns true if connected, false otherwise -func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { +func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) { + // logger.Debug("wgtester: testing connection to peer %s", c.serverAddr) if err := c.ensureConnection(); err != nil { logger.Warn("Failed to ensure connection: %v", err) return false, 0 @@ -138,6 +179,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { binary.BigEndian.PutUint32(packet[0:4], magicHeader) packet[4] = packetTypeRequest + // Reusable response buffer + responseBuffer := make([]byte, packetSize) + // Send multiple attempts as 
specified for attempt := 0; attempt < c.maxAttempts; attempt++ { select { @@ -157,20 +201,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { return false, 0 } - // logger.Debug("Attempting to send monitor packet to %s", c.serverAddr) _, err := c.conn.Write(packet) if err != nil { c.connLock.Unlock() logger.Info("Error sending packet: %v", err) continue } - // logger.Debug("Successfully sent monitor packet") // Set read deadline c.conn.SetReadDeadline(time.Now().Add(c.timeout)) // Wait for response - responseBuffer := make([]byte, packetSize) n, err := c.conn.Read(responseBuffer) c.connLock.Unlock() @@ -211,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - return c.TestConnection(ctx) + return c.TestPeerConnection(ctx) } // MonitorCallback is the function type for connection status change callbacks @@ -238,28 +279,61 @@ func (c *Client) StartMonitor(callback MonitorCallback) error { go func() { var lastConnected bool firstRun := true + stableCount := 0 + currentInterval := c.minInterval - ticker := time.NewTicker(c.packetInterval) - defer ticker.Stop() + timer := time.NewTimer(currentInterval) + defer timer.Stop() for { select { case <-c.shutdownCh: return - case <-ticker.C: + case <-c.updateCh: + // Interval settings changed, reset to minimum + c.monitorLock.Lock() + currentInterval = c.minInterval + c.monitorLock.Unlock() + + // Reset backoff state + stableCount = 0 + + timer.Reset(currentInterval) + logger.Debug("Packet interval updated, reset to %v", currentInterval) + case <-timer.C: ctx, cancel := context.WithTimeout(context.Background(), c.timeout) - connected, rtt := c.TestConnection(ctx) + connected, rtt := c.TestPeerConnection(ctx) cancel() + statusChanged := connected != lastConnected + // Callback if status changed or it's the first check - if connected != lastConnected || firstRun { + if statusChanged || firstRun { callback(ConnectionStatus{ Connected: connected, RTT: rtt, }) lastConnected = connected firstRun = false + // Reset backoff on status change + stableCount = 0 + currentInterval = c.minInterval + } else { + // Status is stable, increment counter + stableCount++ + + // Apply exponential backoff after stable threshold + if stableCount >= c.stableCountToBackoff { + newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier) + if newInterval > c.maxInterval { + newInterval = c.maxInterval + } + currentInterval = newInterval + } } + + // Reset timer with current interval + timer.Reset(currentInterval) } } }() diff --git a/peers/peer.go b/peers/peer.go index 9370b9d..8211fa4 100644 --- a/peers/peer.go +++ b/peers/peer.go @@ -11,7 +11,7 @@ import ( ) // ConfigurePeer sets up or updates a peer within the WireGuard device -func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error { +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error { var endpoint string if relay && siteConfig.RelayEndpoint != "" { endpoint = formatEndpoint(siteConfig.RelayEndpoint) @@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes } configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) - configBuilder.WriteString("persistent_keepalive_interval=5\n") + 
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive)) config := configBuilder.String() logger.Debug("Configuring peer with config: %s", config) @@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [ return nil } +// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it +func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error { + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("update_only=true\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval)) + + config := configBuilder.String() + logger.Debug("Updating persistent keepalive for peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err) + } + + return nil +} + func formatEndpoint(endpoint string) string { if strings.Contains(endpoint, ":") { return endpoint diff --git a/websocket/client.go b/websocket/client.go index 1c5afaf..c4e67b0 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -54,8 +55,9 @@ type ExitNode struct { } type WSMessage struct { - Type string `json:"type"` - Data interface{} `json:"data"` + Type string `json:"type"` + Data interface{} `json:"data"` + ConfigVersion int `json:"configVersion,omitempty"` } // this is not json anymore @@ -77,6 +79,7 @@ type Client struct { handlersMux sync.RWMutex reconnectInterval time.Duration isConnected bool + isDisconnected bool // Flag to track if client is intentionally disconnected reconnectMux sync.RWMutex pingInterval time.Duration pingTimeout time.Duration @@ -87,6 +90,19 @@ type Client struct { clientType string // Type of client (e.g., "newt", "olm") tlsConfig TLSConfig configNeedsSave bool // Flag to track if config needs to be saved + configVersion int // Latest config version received from server + configVersionMux sync.RWMutex + token string // Cached authentication token + exitNodes []ExitNode // Cached exit nodes from token response + tokenMux sync.RWMutex // Protects token and exitNodes + forceNewToken bool // Flag to force fetching a new token on next connection + processingMessage bool // Flag to track if a message is currently being processed + processingMux sync.RWMutex // Protects processingMessage + processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete + getPingData func() map[string]any // Callback to get additional ping data + pingStarted bool // Flag to track if ping monitor has been started + pingStartedMux sync.Mutex // Protects pingStarted + pingDone chan struct{} // Channel to stop the ping monitor independently } type ClientOption func(*Client) @@ -122,6 +138,13 @@ func WithTLSConfig(config TLSConfig) ClientOption { } } +// WithPingDataProvider sets a callback to provide additional data for ping messages +func WithPingDataProvider(fn func() map[string]any) ClientOption { + return func(c *Client) { + c.getPingData = fn + } +} + func (c *Client) OnConnect(callback func() error) { c.onConnect = callback } @@ -154,6 +177,7 @@ func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time. 
pingInterval: pingInterval, pingTimeout: pingTimeout, clientType: "olm", + pingDone: make(chan struct{}), } // Apply options before loading config @@ -173,6 +197,9 @@ func (c *Client) GetConfig() *Config { // Connect establishes the WebSocket connection func (c *Client) Connect() error { + if c.isDisconnected { + c.isDisconnected = false + } go c.connectWithRetry() return nil } @@ -205,9 +232,31 @@ func (c *Client) Close() error { return nil } +// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later. +func (c *Client) Disconnect() error { + c.isDisconnected = true + c.setConnected(false) + + // Stop the ping monitor + c.stopPingMonitor() + + // Wait for any message currently being processed to complete + c.processingWg.Wait() + + if c.conn != nil { + c.writeMux.Lock() + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + c.writeMux.Unlock() + err := c.conn.Close() + c.conn = nil + return err + } + return nil +} + // SendMessage sends a message through the WebSocket connection func (c *Client) SendMessage(messageType string, data interface{}) error { - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return fmt.Errorf("not connected") } @@ -216,14 +265,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { Data: data, } - logger.Debug("Sending message: %s, data: %+v", messageType, data) + logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data) c.writeMux.Lock() defer c.writeMux.Unlock() return c.conn.WriteJSON(msg) } -func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) { +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) { stopChan := make(chan struct{}) updateChan := make(chan interface{}) var dataMux sync.Mutex @@ -231,30 +280,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter go func() { count := 0 - maxAttempts := 10 - err := c.SendMessage(messageType, currentData) // Send immediately - if err != nil { - logger.Error("Failed to send initial message: %v", err) + send := func() { + if c.isDisconnected || c.conn == nil { + return + } + err := c.SendMessage(messageType, currentData) + if err != nil { + logger.Error("websocket: Failed to send message: %v", err) + } + count++ } - count++ + + send() // Send immediately ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: - if count >= maxAttempts { - logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) + if maxAttempts != -1 && count >= maxAttempts { + logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } dataMux.Lock() - err = c.SendMessage(messageType, currentData) + send() dataMux.Unlock() - if err != nil { - logger.Error("Failed to send message: %v", err) - } - count++ case newData := <-updateChan: dataMux.Lock() // Merge newData into currentData if both are maps @@ -277,6 +328,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter case <-stopChan: return } + // Suspend sending if disconnected + for c.isDisconnected { + select { + case <-stopChan: + return + case <-time.After(500 * time.Millisecond): + } + } } }() return func() { 
@@ -323,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { tlsConfig = &tls.Config{} } tlsConfig.InsecureSkipVerify = true - logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } tokenData := map[string]interface{}{ @@ -352,7 +411,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { req.Header.Set("X-CSRF-Token", "x-csrf-protection") // print out the request for debugging - logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) + logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) // Make the request client := &http.Client{} @@ -369,7 +428,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) // Return AuthError for 401/403 status codes if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { @@ -385,7 +444,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - logger.Error("Failed to decode token response.") + logger.Error("websocket: Failed to decode token response.") return "", nil, fmt.Errorf("failed to decode token response: %w", err) } @@ -397,7 +456,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { return "", nil, fmt.Errorf("received empty token from server") } - logger.Debug("Received token: %s", tokenResp.Data.Token) + logger.Debug("websocket: Received token: %s", tokenResp.Data.Token) return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil } @@ -411,7 +470,8 @@ func (c *Client) connectWithRetry() { err := c.establishConnection() if err != nil { // Check if this is an auth error (401/403) - if authErr, ok := err.(*AuthError); ok { + var authErr *AuthError + if errors.As(err, &authErr) { logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr) // Trigger auth error callback if set (this should terminate the tunnel) if c.onAuthError != nil { @@ -422,7 +482,7 @@ func (c *Client) connectWithRetry() { continue } // For other errors (5xx, network issues), continue retrying - logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) + logger.Error("websocket: Failed to connect: %v. 
Retrying in %v...", err, c.reconnectInterval) time.Sleep(c.reconnectInterval) continue } @@ -432,15 +492,25 @@ func (c *Client) connectWithRetry() { } func (c *Client) establishConnection() error { - // Get token for authentication - token, exitNodes, err := c.getToken() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } + // Get token for authentication - reuse cached token unless forced to get new one + c.tokenMux.Lock() + needNewToken := c.token == "" || c.forceNewToken + if needNewToken { + token, exitNodes, err := c.getToken() + if err != nil { + c.tokenMux.Unlock() + return fmt.Errorf("failed to get token: %w", err) + } + c.token = token + c.exitNodes = exitNodes + c.forceNewToken = false - if c.onTokenUpdate != nil { - c.onTokenUpdate(token, exitNodes) + if c.onTokenUpdate != nil { + c.onTokenUpdate(token, exitNodes) + } } + token := c.token + c.tokenMux.Unlock() // Parse the base URL to determine protocol and hostname baseURL, err := url.Parse(c.baseURL) @@ -475,7 +545,7 @@ func (c *Client) establishConnection() error { // Use new TLS configuration method if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { - logger.Info("Setting up TLS configuration for WebSocket connection") + logger.Info("websocket: Setting up TLS configuration for WebSocket connection") tlsConfig, err := c.setupTLS() if err != nil { return fmt.Errorf("failed to setup TLS configuration: %w", err) @@ -489,25 +559,38 @@ func (c *Client) establishConnection() error { dialer.TLSClientConfig = &tls.Config{} } dialer.TLSClientConfig.InsecureSkipVerify = true - logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } - conn, _, err := dialer.Dial(u.String(), nil) + conn, resp, err := dialer.Dial(u.String(), nil) if err != nil { + // Check if this is an unauthorized error (401) + if resp != nil && resp.StatusCode == http.StatusUnauthorized { + logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized") + // Force getting a new token on next reconnect attempt + c.tokenMux.Lock() + c.forceNewToken = true + c.tokenMux.Unlock() + return &AuthError{ + StatusCode: http.StatusUnauthorized, + Message: "WebSocket connection unauthorized", + } + } return fmt.Errorf("failed to connect to WebSocket: %w", err) } c.conn = conn c.setConnected(true) - // Start the ping monitor - go c.pingMonitor() + // Note: ping monitor is NOT started here - it will be started when + // StartPingMonitor() is called after registration completes + // Start the read pump with disconnect detection go c.readPumpWithDisconnectDetection() if c.onConnect != nil { if err := c.onConnect(); err != nil { - logger.Error("OnConnect callback failed: %v", err) + logger.Error("websocket: OnConnect callback failed: %v", err) } } @@ -520,9 +603,9 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Handle new separate certificate configuration if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" { - logger.Info("Loading separate certificate files for mTLS") - logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile) - logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile) + logger.Info("websocket: Loading separate certificate files for mTLS") + logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile) + logger.Debug("websocket: Client key: 
%s", c.tlsConfig.ClientKeyFile) // Load client certificate and key cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile) @@ -533,7 +616,7 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Load CA certificates for remote validation if specified if len(c.tlsConfig.CAFiles) > 0 { - logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles) + logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles) caCertPool := x509.NewCertPool() for _, caFile := range c.tlsConfig.CAFiles { caCert, err := os.ReadFile(caFile) @@ -559,13 +642,13 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Fallback to existing PKCS12 implementation for backward compatibility if c.tlsConfig.PKCS12File != "" { - logger.Info("Loading PKCS12 certificate for mTLS (deprecated)") + logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)") return c.setupPKCS12TLS() } // Legacy fallback using config.TlsClientCert if c.config.TlsClientCert != "" { - logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)") + logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)") return loadClientCertificate(c.config.TlsClientCert) } @@ -577,6 +660,59 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) { return loadClientCertificate(c.tlsConfig.PKCS12File) } +// sendPing sends a single ping message +func (c *Client) sendPing() { + if c.isDisconnected || c.conn == nil { + return + } + // Skip ping if a message is currently being processed + c.processingMux.RLock() + isProcessing := c.processingMessage + c.processingMux.RUnlock() + if isProcessing { + logger.Debug("websocket: Skipping ping, message is being processed") + return + } + // Send application-level ping with config version + c.configVersionMux.RLock() + configVersion := c.configVersion + c.configVersionMux.RUnlock() + + pingData := map[string]any{ + "timestamp": time.Now().Unix(), + "userToken": c.config.UserToken, + } + if c.getPingData != nil { + for k, v := range c.getPingData() { + pingData[k] = v + } + } + + pingMsg := WSMessage{ + Type: "olm/ping", + Data: pingData, + ConfigVersion: configVersion, + } + + logger.Debug("websocket: Sending ping: %+v", pingMsg) + + c.writeMux.Lock() + err := c.conn.WriteJSON(pingMsg) + c.writeMux.Unlock() + if err != nil { + // Check if we're shutting down before logging error and reconnecting + select { + case <-c.done: + // Expected during shutdown + return + default: + logger.Error("websocket: Ping failed: %v", err) + c.reconnect() + return + } + } +} + // pingMonitor sends pings at a short interval and triggers reconnect on failure func (c *Client) pingMonitor() { ticker := time.NewTicker(c.pingInterval) @@ -586,29 +722,65 @@ func (c *Client) pingMonitor() { select { case <-c.done: return + case <-c.pingDone: + return case <-ticker.C: - if c.conn == nil { - return - } - c.writeMux.Lock() - err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)) - c.writeMux.Unlock() - if err != nil { - // Check if we're shutting down before logging error and reconnecting - select { - case <-c.done: - // Expected during shutdown - return - default: - logger.Error("Ping failed: %v", err) - c.reconnect() - return - } - } + c.sendPing() } } } +// StartPingMonitor starts the ping monitor goroutine. +// This should be called after the client is registered and connected. +// It is safe to call multiple times - only the first call will start the monitor. 
+func (c *Client) StartPingMonitor() { + c.pingStartedMux.Lock() + defer c.pingStartedMux.Unlock() + + if c.pingStarted { + return + } + c.pingStarted = true + + // Create a new pingDone channel for this ping monitor instance + c.pingDone = make(chan struct{}) + + // Send an initial ping immediately + go func() { + c.sendPing() + c.pingMonitor() + }() +} + +// stopPingMonitor stops the ping monitor goroutine if it's running. +func (c *Client) stopPingMonitor() { + c.pingStartedMux.Lock() + defer c.pingStartedMux.Unlock() + + if !c.pingStarted { + return + } + + // Close the pingDone channel to stop the monitor + close(c.pingDone) + c.pingStarted = false +} + +// GetConfigVersion returns the current config version +func (c *Client) GetConfigVersion() int { + c.configVersionMux.RLock() + defer c.configVersionMux.RUnlock() + return c.configVersion +} + +// setConfigVersion updates the config version if the new version is higher +func (c *Client) setConfigVersion(version int) { + c.configVersionMux.Lock() + defer c.configVersionMux.Unlock() + if version <= c.configVersion { + return + } + logger.Debug("websocket: setting config version to %d", version) + c.configVersion = version +} + // readPumpWithDisconnectDetection reads messages and triggers reconnect on error func (c *Client) readPumpWithDisconnectDetection() { defer func() { @@ -633,26 +805,47 @@ func (c *Client) readPumpWithDisconnectDetection() { var msg WSMessage err := c.conn.ReadJSON(&msg) if err != nil { - // Check if we're shutting down before logging error + // Check if we're shutting down or explicitly disconnected before logging error select { case <-c.done: // Expected during shutdown, don't log as error - logger.Debug("WebSocket connection closed during shutdown") + logger.Debug("websocket: connection closed during shutdown") return default: + // Check if explicitly disconnected + if c.isDisconnected { + logger.Debug("websocket: connection closed: client was explicitly disconnected") + return + } + // Unexpected error during normal operation if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { - logger.Error("WebSocket read error: %v", err) + logger.Error("websocket: read error: %v", err) } else { - logger.Debug("WebSocket connection closed: %v", err) + logger.Debug("websocket: connection closed: %v", err) } return // triggers reconnect via defer } } + // Update config version from incoming message + c.setConfigVersion(msg.ConfigVersion) + c.handlersMux.RLock() if handler, ok := c.handlers[msg.Type]; ok { + // Mark that we're processing a message + c.processingMux.Lock() + c.processingMessage = true + c.processingMux.Unlock() + c.processingWg.Add(1) + + handler(msg) + + // Mark that we're done processing + c.processingWg.Done() + c.processingMux.Lock() + c.processingMessage = false + c.processingMux.Unlock() } c.handlersMux.RUnlock() } @@ -666,6 +859,12 @@ func (c *Client) reconnect() { c.conn = nil } + + // Don't reconnect if explicitly disconnected + if c.isDisconnected { + logger.Debug("websocket: Not reconnecting: client was explicitly disconnected") + return + } + // Only reconnect if we're not shutting down select { case <-c.done: @@ -683,7 +882,7 @@ func (c *Client) setConnected(status bool) { // LoadClientCertificate Helper method to load client certificates (PKCS12 format) func loadClientCertificate(p12Path string) (*tls.Config, error) { - logger.Info("Loading tls-client-cert %s", p12Path) + logger.Info("websocket: Loading tls-client-cert %s", p12Path) // Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path) if err != nil {