Fix exit
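Summary: an exit requested through the local API previously never stopped the process. olm.Init derived an internal context and immediately defer-cancelled it, so the API's onExit handler was cancelling a context that was already dead, and nothing ever cancelled the context main was blocked on. This commit gives main.go two contexts (one bound to OS signals, one cancelled programmatically through a new GlobalConfig.OnExit hook), makes the exit handler call Close() and the caller-supplied OnExit, threads the API server into the peer manager and monitor so per-peer WireGuard and holepunch status (including a new holepunchConnected field) is reported, preserves relay endpoints across peer updates, and retunes the monitor intervals.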
api/api.go | 31
@@ -37,13 +37,14 @@ type SwitchOrgRequest struct {
 
 // PeerStatus represents the status of a peer connection
 type PeerStatus struct {
-	SiteID    int           `json:"siteId"`
-	Connected bool          `json:"connected"`
-	RTT       time.Duration `json:"rtt"`
-	LastSeen  time.Time     `json:"lastSeen"`
-	Endpoint  string        `json:"endpoint,omitempty"`
-	IsRelay   bool          `json:"isRelay"`
-	PeerIP    string        `json:"peerAddress,omitempty"`
+	SiteID             int           `json:"siteId"`
+	Connected          bool          `json:"connected"`
+	RTT                time.Duration `json:"rtt"`
+	LastSeen           time.Time     `json:"lastSeen"`
+	Endpoint           string        `json:"endpoint,omitempty"`
+	IsRelay            bool          `json:"isRelay"`
+	PeerIP             string        `json:"peerAddress,omitempty"`
+	HolepunchConnected bool          `json:"holepunchConnected"`
 }
 
 // StatusResponse is returned by the status endpoint
@@ -252,6 +253,22 @@ func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
 	status.IsRelay = isRelay
 }
 
+// UpdatePeerHolepunchStatus updates the holepunch connection status of a peer
+func (s *API) UpdatePeerHolepunchStatus(siteID int, holepunchConnected bool) {
+	s.statusMu.Lock()
+	defer s.statusMu.Unlock()
+
+	status, exists := s.peerStatuses[siteID]
+	if !exists {
+		status = &PeerStatus{
+			SiteID: siteID,
+		}
+		s.peerStatuses[siteID] = status
+	}
+
+	status.HolepunchConnected = holepunchConnected
+}
+
 // handleConnect handles the /connect endpoint
 func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
 	if r.Method != http.MethodPost {
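For consumers of the HTTP API, the new field simply shows up in the per-peer status JSON. A minimal client sketch follows; note the listen address, the /status path, and the "peers" field name are assumptions based on the StatusResponse comment above, not confirmed by this diff:

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// Mirrors the PeerStatus struct from api/api.go above.
type PeerStatus struct {
	SiteID             int           `json:"siteId"`
	Connected          bool          `json:"connected"`
	RTT                time.Duration `json:"rtt"`
	Endpoint           string        `json:"endpoint,omitempty"`
	IsRelay            bool          `json:"isRelay"`
	HolepunchConnected bool          `json:"holepunchConnected"`
}

func main() {
	resp, err := http.Get("http://127.0.0.1:9452/status") // address and path are assumptions
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var status struct {
		Peers []PeerStatus `json:"peers"` // field name assumed
	}
	if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
		panic(err)
	}
	for _, p := range status.Peers {
		fmt.Printf("site %d: wg=%v holepunch=%v relay=%v\n",
			p.SiteID, p.Connected, p.HolepunchConnected, p.IsRelay)
	}
}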
main.go | 21
@@ -155,14 +155,18 @@ func main() {
 	}
 
 	// Create a context that will be cancelled on interrupt signals
-	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+	signalCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
 	defer stop()
 
+	// Create a separate context for programmatic shutdown (e.g., via API exit)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	// Run in console mode
-	runOlmMainWithArgs(ctx, os.Args[1:])
+	runOlmMainWithArgs(ctx, cancel, signalCtx, os.Args[1:])
 }
 
-func runOlmMainWithArgs(ctx context.Context, args []string) {
+func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCtx context.Context, args []string) {
 	// Setup Windows event logging if on Windows
 	if runtime.GOOS == "windows" {
 		setupWindowsEventLog()
@@ -211,6 +215,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
 		HTTPAddr:   config.HTTPAddr,
 		SocketPath: config.SocketPath,
 		Version:    config.Version,
+		OnExit:     cancel, // Pass cancel function directly to trigger shutdown
 	}
 
 	olm.Init(ctx, olmConfig)
@@ -242,9 +247,13 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
 		logger.Info("Incomplete tunnel configuration, not starting tunnel")
 	}
 
-	// Wait for context cancellation (from signals or API shutdown)
-	<-ctx.Done()
-	logger.Info("Shutdown signal received, cleaning up...")
+	// Wait for either signal or programmatic shutdown
+	select {
+	case <-signalCtx.Done():
+		logger.Info("Shutdown signal received, cleaning up...")
+	case <-ctx.Done():
+		logger.Info("Shutdown requested via API, cleaning up...")
+	}
 
 	// Clean up resources
 	olm.Close()
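The main.go change boils down to one pattern: keep a signal-bound context and a programmatic-shutdown context separate, and select on both. A minimal standalone version of the pattern (sketch, not project code):

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"
)

func main() {
	// Context cancelled by OS signals (Ctrl-C, SIGTERM).
	signalCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	// Independent context for programmatic shutdown, e.g. an API handler calling cancel().
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go cancel() // stand-in for an exit request arriving via the API

	select {
	case <-signalCtx.Done():
		fmt.Println("shutdown: OS signal")
	case <-ctx.Done():
		fmt.Println("shutdown: requested programmatically")
	}
}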
olm/olm.go | 25
@@ -97,10 +97,6 @@ func Init(ctx context.Context, config GlobalConfig) {
 	globalConfig = config
 	globalCtx = ctx
 
-	// Create a cancellable context for internal shutdown control
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
 	logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
 
 	if config.HTTPAddr != "" {
@@ -194,7 +190,10 @@ func Init(ctx context.Context, config GlobalConfig) {
 		// onExit
 		func() error {
 			logger.Info("Processing shutdown request via API")
-			cancel()
+			Close()
+			if globalConfig.OnExit != nil {
+				globalConfig.OnExit()
+			}
 			return nil
 		},
 	)
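Taken together with the previous hunk, this is the actual exit fix: the old Init derived a child context and immediately deferred its cancel, so the internal context died the moment Init returned, and the onExit handler's cancel() had nothing left to stop. The handler now closes resources and forwards to the caller's OnExit instead. A minimal illustration of the defer-cancel pitfall (sketch, not project code):

package main

import (
	"context"
	"fmt"
	"time"
)

func initService(parent context.Context) context.Context {
	ctx, cancel := context.WithCancel(parent)
	defer cancel() // BUG: runs when initService returns, not at shutdown
	return ctx
}

func main() {
	ctx := initService(context.Background())
	select {
	case <-ctx.Done():
		fmt.Println("already cancelled:", ctx.Err()) // this branch wins immediately
	case <-time.After(time.Second):
		fmt.Println("still alive")
	}
}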
@@ -419,6 +418,7 @@ func StartTunnel(config TunnelConfig) {
 	} else {
 		siteEndpoint = site.Endpoint
 	}
 
+	apiServer.UpdatePeerStatus(site.SiteId, false, 0, siteEndpoint, false)
 
 	if err := peerManager.AddPeer(site, siteEndpoint); err != nil {
@@ -483,6 +483,9 @@ func StartTunnel(config TunnelConfig) {
 			if updateData.Endpoint != "" {
 				siteConfig.Endpoint = updateData.Endpoint
 			}
+			if updateData.RelayEndpoint != "" {
+				siteConfig.RelayEndpoint = updateData.RelayEndpoint
+			}
 			if updateData.PublicKey != "" {
 				siteConfig.PublicKey = updateData.PublicKey
 			}
@@ -674,6 +677,12 @@ func StartTunnel(config TunnelConfig) {
 	olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
 		logger.Debug("Received relay-peer message: %v", msg.Data)
 
+		// Check if peerManager is still valid (may be nil during shutdown)
+		if peerManager == nil {
+			logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
+			return
+		}
+
 		jsonData, err := json.Marshal(msg.Data)
 		if err != nil {
 			logger.Error("Error marshaling data: %v", err)
@@ -700,6 +709,12 @@ func StartTunnel(config TunnelConfig) {
 	olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) {
 		logger.Debug("Received unrelay-peer message: %v", msg.Data)
 
+		// Check if peerManager is still valid (may be nil during shutdown)
+		if peerManager == nil {
+			logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
+			return
+		}
+
 		jsonData, err := json.Marshal(msg.Data)
 		if err != nil {
 			logger.Error("Error marshaling data: %v", err)
@@ -27,6 +27,7 @@ type GlobalConfig struct {
 	OnConnected  func()
 	OnTerminated func()
 	OnAuthError  func(statusCode int, message string) // Called when auth fails (401/403)
+	OnExit       func() // Called when exit is requested via API
 }
 
 type TunnelConfig struct {
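With the new hook in place, a host application wires API-driven exit the same way main.go does above. A sketch reusing names from this diff (the import path is an assumption; other GlobalConfig fields elided):

package main

import (
	"context"

	"github.com/fosrl/olm/olm" // assumed import path for this package
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	olm.Init(ctx, olm.GlobalConfig{
		// ...other GlobalConfig fields as in main.go above...
		OnExit: cancel, // an API exit request cancels the program context
	})

	<-ctx.Done() // unblocks when OnExit fires (or on any other cancellation)
}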
@@ -71,6 +71,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
 		config.MiddleDev,
 		config.LocalIP,
 		config.SharedBind,
+		config.APIServer,
 	)
 
 	return pm
@@ -233,6 +234,16 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
 		return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId)
 	}
 
+	// Determine which endpoint to use based on relay state:
+	// if the peer is currently relayed, use the relay endpoint; otherwise use the direct endpoint
+	actualEndpoint := endpoint
+	if pm.peerMonitor != nil && pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId) {
+		if oldPeer.RelayEndpoint != "" {
+			actualEndpoint = oldPeer.RelayEndpoint
+			logger.Info("Peer %d is relayed, using relay endpoint: %s", siteConfig.SiteId, actualEndpoint)
+		}
+	}
+
 	// If public key changed, remove old peer first
 	if siteConfig.PublicKey != oldPeer.PublicKey {
 		if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil {
@@ -284,7 +295,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
 	wgConfig := siteConfig
 	wgConfig.AllowedIps = ownedIPs
 
-	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint); err != nil {
+	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, actualEndpoint); err != nil {
 		return err
 	}
 
@@ -359,6 +370,11 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
 
 	pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
 
+	// Preserve the relay endpoint if the peer is relayed
+	if pm.peerMonitor != nil && pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId) && oldPeer.RelayEndpoint != "" {
+		siteConfig.RelayEndpoint = oldPeer.RelayEndpoint
+	}
+
 	pm.peers[siteConfig.SiteId] = siteConfig
 	return nil
 }
@@ -12,6 +12,7 @@ import (
 	"github.com/fosrl/newt/holepunch"
 	"github.com/fosrl/newt/logger"
 	"github.com/fosrl/newt/util"
+	"github.com/fosrl/olm/api"
 	middleDevice "github.com/fosrl/olm/device"
 	"github.com/fosrl/olm/websocket"
 	"gvisor.dev/gvisor/pkg/buffer"
@@ -59,16 +60,22 @@ type PeerMonitor struct {
 	relayedPeers         map[int]bool // siteID -> whether the peer is currently relayed
 	holepunchMaxAttempts int          // max consecutive failures before triggering relay
 	holepunchFailures    map[int]int  // siteID -> consecutive failure count
+
+	// API server for status updates
+	apiServer *api.API
+
+	// WG connection status tracking
+	wgConnectionStatus map[int]bool // siteID -> WG connected status
 }
 
 // NewPeerMonitor creates a new peer monitor with the given callback
-func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
+func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
 	ctx, cancel := context.WithCancel(context.Background())
 	pm := &PeerMonitor{
 		monitors:    make(map[int]*Client),
-		interval:    1 * time.Second, // Default check interval
+		interval:    3 * time.Second, // Default check interval
 		timeout:     5 * time.Second,
-		maxAttempts: 5,
+		maxAttempts: 3,
 		wsClient:    wsClient,
 		middleDev:   middleDev,
 		localIP:     localIP,
@@ -76,13 +83,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
 		nsCtx:      ctx,
 		nsCancel:   cancel,
 		sharedBind: sharedBind,
-		holepunchInterval:  5 * time.Second, // Check holepunch every 5 seconds
+		holepunchInterval:  3 * time.Second, // Check holepunch every 3 seconds
 		holepunchTimeout:   5 * time.Second,
 		holepunchEndpoints: make(map[int]string),
 		holepunchStatus:    make(map[int]bool),
 		relayedPeers:       make(map[int]bool),
-		holepunchMaxAttempts: 5, // Trigger relay after 5 consecutive failures
+		holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
 		holepunchFailures:  make(map[int]int),
+		apiServer:          apiServer,
+		wgConnectionStatus: make(map[int]bool),
 	}
 
 	if err := pm.initNetstack(); err != nil {
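Net effect of the retuned constants: the WireGuard check interval goes from 1s to 3s with maxAttempts dropping from 5 to 3, holepunch probes run every 3s instead of every 5s, and the relay fallback now fires after 3 consecutive holepunch failures instead of 5. Back-of-the-envelope, a dead direct path triggers relay after roughly 3 × 3s ≈ 9s of probing rather than 5 × 5s ≈ 25s (ignoring the 5s per-probe timeout).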
@@ -235,6 +244,26 @@ func (pm *PeerMonitor) Start() {
 
 // handleConnectionStatusChange is called when a peer's connection status changes
 func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) {
 	pm.mutex.Lock()
+	previousStatus, exists := pm.wgConnectionStatus[siteID]
+	pm.wgConnectionStatus[siteID] = status.Connected
+	isRelayed := pm.relayedPeers[siteID]
+	endpoint := pm.holepunchEndpoints[siteID]
+	pm.mutex.Unlock()
+
+	// Log status changes
+	if !exists || previousStatus != status.Connected {
+		if status.Connected {
+			logger.Info("WireGuard connection to site %d is CONNECTED (RTT: %v)", siteID, status.RTT)
+		} else {
+			logger.Warn("WireGuard connection to site %d is DISCONNECTED", siteID)
+		}
+	}
+
+	// Update API with connection status
+	if pm.apiServer != nil {
+		pm.apiServer.UpdatePeerStatus(siteID, status.Connected, status.RTT, endpoint, isRelayed)
+	}
+}
+
 // sendRelay sends a relay message to the server
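The handler above follows a copy-then-release discipline worth noting: shared maps are read and written only while holding pm.mutex, and the lock is dropped before any logging or API call. A compact standalone version of the pattern (sketch):

package main

import (
	"fmt"
	"sync"
)

var (
	mu    sync.Mutex
	state = map[int]bool{} // siteID -> last known connected state
)

// Snapshot shared state under the lock, release it, then log or call
// out without holding the lock, so callbacks cannot deadlock back
// into the monitor's mutex.
func setConnected(siteID int, connected bool) {
	mu.Lock()
	prev, seen := state[siteID]
	state[siteID] = connected
	mu.Unlock()

	if !seen || prev != connected {
		fmt.Printf("site %d connected=%v\n", siteID, connected)
	}
}

func main() { setConnected(1, true) }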
@@ -302,6 +331,13 @@ func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) {
 	}
 }
 
+// IsPeerRelayed returns whether a peer is currently using relay
+func (pm *PeerMonitor) IsPeerRelayed(siteID int) bool {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+	return pm.relayedPeers[siteID]
+}
+
 // startHolepunchMonitor starts the holepunch connection monitoring
 // Note: This function assumes the mutex is already held by the caller (called from Start())
 func (pm *PeerMonitor) startHolepunchMonitor() error {
@@ -364,6 +400,11 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
 // checkHolepunchEndpoints tests all holepunch endpoints
 func (pm *PeerMonitor) checkHolepunchEndpoints() {
 	pm.mutex.Lock()
+	// Check if we're still running before doing any work
+	if !pm.running {
+		pm.mutex.Unlock()
+		return
+	}
 	endpoints := make(map[int]string, len(pm.holepunchEndpoints))
 	for siteID, endpoint := range pm.holepunchEndpoints {
 		endpoints[siteID] = endpoint
@@ -402,7 +443,30 @@
 		}
 	}
 
+	// Update API with holepunch status
+	if pm.apiServer != nil {
+		// Update holepunch connection status
+		pm.apiServer.UpdatePeerHolepunchStatus(siteID, result.Success)
+
+		// Get the current WG connection status for this peer
+		pm.mutex.Lock()
+		wgConnected := pm.wgConnectionStatus[siteID]
+		pm.mutex.Unlock()
+
+		// Update API - use holepunch endpoint and relay status
+		pm.apiServer.UpdatePeerStatus(siteID, wgConnected, result.RTT, endpoint, isRelayed)
+	}
+
 	// Handle relay logic based on holepunch status
+	// Check if we're still running before sending relay messages
+	pm.mutex.Lock()
+	stillRunning := pm.running
+	pm.mutex.Unlock()
+
+	if !stillRunning {
+		return // Stop processing if shutdown is in progress
+	}
+
 	if !result.Success && !isRelayed && failureCount >= maxAttempts {
 		// Holepunch failed and we're not relayed - trigger relay
 		logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount)
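The relay decision above, reduced to its core. This is a sketch: the consecutive-failure bookkeeping follows the holepunchFailures map and the failureCount >= maxAttempts guard shown in the diff, while the reset-on-success behavior is an assumption:

package main

import "fmt"

const maxAttempts = 3 // matches holepunchMaxAttempts above

var (
	failures  = map[int]int{}  // siteID -> consecutive holepunch failures
	isRelayed = map[int]bool{} // siteID -> already using relay
)

func onHolepunchResult(siteID int, success bool) {
	if success {
		failures[siteID] = 0 // assumed: a success resets the streak
		return
	}
	failures[siteID]++
	if !isRelayed[siteID] && failures[siteID] >= maxAttempts {
		fmt.Printf("holepunch to site %d failed %d times, triggering relay\n", siteID, failures[siteID])
		isRelayed[siteID] = true // the real code sends a relay message to the server
	}
}

func main() {
	for i := 0; i < 4; i++ {
		onHolepunchResult(7, false) // relay triggers on the 3rd failure
	}
}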