diff --git a/api/api.go b/api/api.go index d74e9c9..ffe9594 100644 --- a/api/api.go +++ b/api/api.go @@ -37,13 +37,14 @@ type SwitchOrgRequest struct { // PeerStatus represents the status of a peer connection type PeerStatus struct { - SiteID int `json:"siteId"` - Connected bool `json:"connected"` - RTT time.Duration `json:"rtt"` - LastSeen time.Time `json:"lastSeen"` - Endpoint string `json:"endpoint,omitempty"` - IsRelay bool `json:"isRelay"` - PeerIP string `json:"peerAddress,omitempty"` + SiteID int `json:"siteId"` + Connected bool `json:"connected"` + RTT time.Duration `json:"rtt"` + LastSeen time.Time `json:"lastSeen"` + Endpoint string `json:"endpoint,omitempty"` + IsRelay bool `json:"isRelay"` + PeerIP string `json:"peerAddress,omitempty"` + HolepunchConnected bool `json:"holepunchConnected"` } // StatusResponse is returned by the status endpoint @@ -252,6 +253,22 @@ func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) { status.IsRelay = isRelay } +// UpdatePeerHolepunchStatus updates the holepunch connection status of a peer +func (s *API) UpdatePeerHolepunchStatus(siteID int, holepunchConnected bool) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + status, exists := s.peerStatuses[siteID] + if !exists { + status = &PeerStatus{ + SiteID: siteID, + } + s.peerStatuses[siteID] = status + } + + status.HolepunchConnected = holepunchConnected +} + // handleConnect handles the /connect endpoint func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { diff --git a/main.go b/main.go index 572886f..a652749 100644 --- a/main.go +++ b/main.go @@ -155,14 +155,18 @@ func main() { } // Create a context that will be cancelled on interrupt signals - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + signalCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() + // Create a separate context for programmatic shutdown 
(e.g., via API exit) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Run in console mode - runOlmMainWithArgs(ctx, os.Args[1:]) + runOlmMainWithArgs(ctx, cancel, signalCtx, os.Args[1:]) } -func runOlmMainWithArgs(ctx context.Context, args []string) { +func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCtx context.Context, args []string) { // Setup Windows event logging if on Windows if runtime.GOOS == "windows" { setupWindowsEventLog() @@ -211,6 +215,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { HTTPAddr: config.HTTPAddr, SocketPath: config.SocketPath, Version: config.Version, + OnExit: cancel, // Pass cancel function directly to trigger shutdown } olm.Init(ctx, olmConfig) @@ -242,9 +247,13 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { logger.Info("Incomplete tunnel configuration, not starting tunnel") } - // Wait for context cancellation (from signals or API shutdown) - <-ctx.Done() - logger.Info("Shutdown signal received, cleaning up...") + // Wait for either signal or programmatic shutdown + select { + case <-signalCtx.Done(): + logger.Info("Shutdown signal received, cleaning up...") + case <-ctx.Done(): + logger.Info("Shutdown requested via API, cleaning up...") + } // Clean up resources olm.Close() diff --git a/olm/olm.go b/olm/olm.go index 3035cbd..6c06032 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -97,10 +97,6 @@ func Init(ctx context.Context, config GlobalConfig) { globalConfig = config globalCtx = ctx - // Create a cancellable context for internal shutdown controconfiguration GlobalConfigl - ctx, cancel := context.WithCancel(ctx) - defer cancel() - logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) if config.HTTPAddr != "" { @@ -194,7 +190,10 @@ func Init(ctx context.Context, config GlobalConfig) { // onExit func() error { logger.Info("Processing shutdown request via API") - cancel() + Close() + if globalConfig.OnExit != nil { + 
globalConfig.OnExit() + } return nil }, ) @@ -419,6 +418,7 @@ func StartTunnel(config TunnelConfig) { } else { siteEndpoint = site.Endpoint } + apiServer.UpdatePeerStatus(site.SiteId, false, 0, siteEndpoint, false) if err := peerManager.AddPeer(site, siteEndpoint); err != nil { @@ -483,6 +483,9 @@ func StartTunnel(config TunnelConfig) { if updateData.Endpoint != "" { siteConfig.Endpoint = updateData.Endpoint } + if updateData.RelayEndpoint != "" { + siteConfig.RelayEndpoint = updateData.RelayEndpoint + } if updateData.PublicKey != "" { siteConfig.PublicKey = updateData.PublicKey } @@ -674,6 +677,12 @@ func StartTunnel(config TunnelConfig) { olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { logger.Debug("Received relay-peer message: %v", msg.Data) + // Check if peerManager is still valid (may be nil during shutdown) + if peerManager == nil { + logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") + return + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) @@ -700,6 +709,12 @@ func StartTunnel(config TunnelConfig) { olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) { logger.Debug("Received unrelay-peer message: %v", msg.Data) + // Check if peerManager is still valid (may be nil during shutdown) + if peerManager == nil { + logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") + return + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) diff --git a/olm/types.go b/olm/types.go index 8504b77..8330f8d 100644 --- a/olm/types.go +++ b/olm/types.go @@ -27,6 +27,7 @@ type GlobalConfig struct { OnConnected func() OnTerminated func() OnAuthError func(statusCode int, message string) // Called when auth fails (401/403) + OnExit func() // Called when exit is requested via API } type TunnelConfig struct { diff --git a/peers/manager.go b/peers/manager.go index 
fe71a19..3c4a3a5 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -71,6 +71,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager { config.MiddleDev, config.LocalIP, config.SharedBind, + config.APIServer, ) return pm @@ -233,6 +234,16 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId) } + // Determine which endpoint to use based on relay state + // If the peer is currently relayed, use the relay endpoint; otherwise use the direct endpoint + actualEndpoint := endpoint + if pm.peerMonitor != nil && pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId) { + if oldPeer.RelayEndpoint != "" { + actualEndpoint = oldPeer.RelayEndpoint + logger.Info("Peer %d is relayed, using relay endpoint: %s", siteConfig.SiteId, actualEndpoint) + } + } + // If public key changed, remove old peer first if siteConfig.PublicKey != oldPeer.PublicKey { if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil { @@ -284,7 +295,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, actualEndpoint); err != nil { return err } @@ -359,6 +370,11 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint) + // Preserve the relay endpoint if the peer is relayed + if pm.peerMonitor != nil && pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId) && oldPeer.RelayEndpoint != "" { + siteConfig.RelayEndpoint = oldPeer.RelayEndpoint + } + pm.peers[siteConfig.SiteId] = siteConfig return nil } diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 95a34ac..d2e1094 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ 
-12,6 +12,7 @@ import ( "github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/util" + "github.com/fosrl/olm/api" middleDevice "github.com/fosrl/olm/device" "github.com/fosrl/olm/websocket" "gvisor.dev/gvisor/pkg/buffer" @@ -59,16 +60,22 @@ type PeerMonitor struct { relayedPeers map[int]bool // siteID -> whether the peer is currently relayed holepunchMaxAttempts int // max consecutive failures before triggering relay holepunchFailures map[int]int // siteID -> consecutive failure count + + // API server for status updates + apiServer *api.API + + // WG connection status tracking + wgConnectionStatus map[int]bool // siteID -> WG connected status } // NewPeerMonitor creates a new peer monitor with the given callback -func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor { +func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor { ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), - interval: 1 * time.Second, // Default check interval + interval: 3 * time.Second, // Default check interval timeout: 5 * time.Second, - maxAttempts: 5, + maxAttempts: 3, wsClient: wsClient, middleDev: middleDev, localIP: localIP, @@ -76,13 +83,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe nsCtx: ctx, nsCancel: cancel, sharedBind: sharedBind, - holepunchInterval: 5 * time.Second, // Check holepunch every 5 seconds + holepunchInterval: 3 * time.Second, // Check holepunch every 3 seconds holepunchTimeout: 5 * time.Second, holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), relayedPeers: make(map[int]bool), - holepunchMaxAttempts: 5, // Trigger relay after 5 consecutive failures + holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive 
failures holepunchFailures: make(map[int]int), + apiServer: apiServer, + wgConnectionStatus: make(map[int]bool), } if err := pm.initNetstack(); err != nil { @@ -235,6 +244,26 @@ func (pm *PeerMonitor) Start() { // handleConnectionStatusChange is called when a peer's connection status changes func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) { + pm.mutex.Lock() + previousStatus, exists := pm.wgConnectionStatus[siteID] + pm.wgConnectionStatus[siteID] = status.Connected + isRelayed := pm.relayedPeers[siteID] + endpoint := pm.holepunchEndpoints[siteID] + pm.mutex.Unlock() + + // Log status changes + if !exists || previousStatus != status.Connected { + if status.Connected { + logger.Info("WireGuard connection to site %d is CONNECTED (RTT: %v)", siteID, status.RTT) + } else { + logger.Warn("WireGuard connection to site %d is DISCONNECTED", siteID) + } + } + + // Update API with connection status + if pm.apiServer != nil { + pm.apiServer.UpdatePeerStatus(siteID, status.Connected, status.RTT, endpoint, isRelayed) + } } // sendRelay sends a relay message to the server @@ -302,6 +331,13 @@ func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) { } } +// IsPeerRelayed returns whether a peer is currently using relay +func (pm *PeerMonitor) IsPeerRelayed(siteID int) bool { + pm.mutex.Lock() + defer pm.mutex.Unlock() + return pm.relayedPeers[siteID] +} + // startHolepunchMonitor starts the holepunch connection monitoring // Note: This function assumes the mutex is already held by the caller (called from Start()) func (pm *PeerMonitor) startHolepunchMonitor() error { @@ -364,6 +400,11 @@ func (pm *PeerMonitor) runHolepunchMonitor() { // checkHolepunchEndpoints tests all holepunch endpoints func (pm *PeerMonitor) checkHolepunchEndpoints() { pm.mutex.Lock() + // Check if we're still running before doing any work + if !pm.running { + pm.mutex.Unlock() + return + } endpoints := make(map[int]string, len(pm.holepunchEndpoints)) for 
siteID, endpoint := range pm.holepunchEndpoints { endpoints[siteID] = endpoint @@ -402,7 +443,30 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { } } + // Update API with holepunch status + if pm.apiServer != nil { + // Update holepunch connection status + pm.apiServer.UpdatePeerHolepunchStatus(siteID, result.Success) + + // Get the current WG connection status for this peer + pm.mutex.Lock() + wgConnected := pm.wgConnectionStatus[siteID] + pm.mutex.Unlock() + + // Update API - use holepunch endpoint and relay status + pm.apiServer.UpdatePeerStatus(siteID, wgConnected, result.RTT, endpoint, isRelayed) + } + // Handle relay logic based on holepunch status + // Check if we're still running before sending relay messages + pm.mutex.Lock() + stillRunning := pm.running + pm.mutex.Unlock() + + if !stillRunning { + return // Stop processing if shutdown is in progress + } + if !result.Success && !isRelayed && failureCount >= maxAttempts { // Holepunch failed and we're not relayed - trigger relay logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount)