Owen
2025-12-02 21:23:28 -05:00
parent 0b87070e31
commit a2334cc5af
6 changed files with 146 additions and 24 deletions

View File

@@ -37,13 +37,14 @@ type SwitchOrgRequest struct {
 // PeerStatus represents the status of a peer connection
 type PeerStatus struct {
-	SiteID    int           `json:"siteId"`
-	Connected bool          `json:"connected"`
-	RTT       time.Duration `json:"rtt"`
-	LastSeen  time.Time     `json:"lastSeen"`
-	Endpoint  string        `json:"endpoint,omitempty"`
-	IsRelay   bool          `json:"isRelay"`
-	PeerIP    string        `json:"peerAddress,omitempty"`
+	SiteID             int           `json:"siteId"`
+	Connected          bool          `json:"connected"`
+	RTT                time.Duration `json:"rtt"`
+	LastSeen           time.Time     `json:"lastSeen"`
+	Endpoint           string        `json:"endpoint,omitempty"`
+	IsRelay            bool          `json:"isRelay"`
+	PeerIP             string        `json:"peerAddress,omitempty"`
+	HolepunchConnected bool          `json:"holepunchConnected"`
 }

 // StatusResponse is returned by the status endpoint
@@ -252,6 +253,22 @@ func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
 	status.IsRelay = isRelay
 }

+// UpdatePeerHolepunchStatus updates the holepunch connection status of a peer
+func (s *API) UpdatePeerHolepunchStatus(siteID int, holepunchConnected bool) {
+	s.statusMu.Lock()
+	defer s.statusMu.Unlock()
+
+	status, exists := s.peerStatuses[siteID]
+	if !exists {
+		status = &PeerStatus{
+			SiteID: siteID,
+		}
+		s.peerStatuses[siteID] = status
+	}
+	status.HolepunchConnected = holepunchConnected
+}
+
 // handleConnect handles the /connect endpoint
 func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
 	if r.Method != http.MethodPost {
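A note on the wire format: with the new field, a peer entry from the status endpoint marshals as shown by this standalone sketch. The struct is re-declared (and abridged) here purely for illustration; the values are made up:

package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// PeerStatus re-declared for illustration only; tags match the diff above.
type PeerStatus struct {
	SiteID             int           `json:"siteId"`
	Connected          bool          `json:"connected"`
	RTT                time.Duration `json:"rtt"`
	HolepunchConnected bool          `json:"holepunchConnected"`
}

func main() {
	b, _ := json.MarshalIndent(PeerStatus{SiteID: 1, Connected: true, RTT: 12 * time.Millisecond, HolepunchConnected: true}, "", "  ")
	fmt.Println(string(b))
	// Prints "rtt": 12000000, since time.Duration marshals as integer nanoseconds.
}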

main.go
View File

@@ -155,14 +155,18 @@ func main() {
 	}

 	// Create a context that will be cancelled on interrupt signals
-	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+	signalCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
 	defer stop()

+	// Create a separate context for programmatic shutdown (e.g., via API exit)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	// Run in console mode
-	runOlmMainWithArgs(ctx, os.Args[1:])
+	runOlmMainWithArgs(ctx, cancel, signalCtx, os.Args[1:])
 }

-func runOlmMainWithArgs(ctx context.Context, args []string) {
+func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCtx context.Context, args []string) {
 	// Setup Windows event logging if on Windows
 	if runtime.GOOS == "windows" {
 		setupWindowsEventLog()
@@ -211,6 +215,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
 		HTTPAddr:   config.HTTPAddr,
 		SocketPath: config.SocketPath,
 		Version:    config.Version,
+		OnExit:     cancel, // Pass cancel function directly to trigger shutdown
 	}

 	olm.Init(ctx, olmConfig)
@@ -242,9 +247,13 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
logger.Info("Incomplete tunnel configuration, not starting tunnel")
}
// Wait for context cancellation (from signals or API shutdown)
<-ctx.Done()
logger.Info("Shutdown signal received, cleaning up...")
// Wait for either signal or programmatic shutdown
select {
case <-signalCtx.Done():
logger.Info("Shutdown signal received, cleaning up...")
case <-ctx.Done():
logger.Info("Shutdown requested via API, cleaning up...")
}
// Clean up resources
olm.Close()
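The split above is a common Go idiom: one context owned by the OS signal handler, a second one cancellable from anywhere in the program, and a select over both. A minimal, self-contained sketch of the same shape (names are illustrative, not olm's):

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"
)

func main() {
	// Context cancelled by SIGINT/SIGTERM.
	signalCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	// Context cancelled programmatically (e.g., by an API "exit" handler).
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Simulate a programmatic shutdown request after two seconds.
	go func() {
		time.Sleep(2 * time.Second)
		cancel()
	}()

	select {
	case <-signalCtx.Done():
		fmt.Println("shutdown signal received")
	case <-ctx.Done():
		fmt.Println("shutdown requested programmatically")
	}
}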

View File

@@ -97,10 +97,6 @@ func Init(ctx context.Context, config GlobalConfig) {
 	globalConfig = config
 	globalCtx = ctx

-	// Create a cancellable context for internal shutdown control
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
 	logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))

 	if config.HTTPAddr != "" {
@@ -194,7 +190,10 @@ func Init(ctx context.Context, config GlobalConfig) {
 		// onExit
 		func() error {
 			logger.Info("Processing shutdown request via API")
-			cancel()
+			Close()
+			if globalConfig.OnExit != nil {
+				globalConfig.OnExit()
+			}
 			return nil
 		},
 	)
@@ -419,6 +418,7 @@ func StartTunnel(config TunnelConfig) {
 	} else {
 		siteEndpoint = site.Endpoint
 	}

+	apiServer.UpdatePeerStatus(site.SiteId, false, 0, siteEndpoint, false)
 	if err := peerManager.AddPeer(site, siteEndpoint); err != nil {
@@ -483,6 +483,9 @@ func StartTunnel(config TunnelConfig) {
 		if updateData.Endpoint != "" {
 			siteConfig.Endpoint = updateData.Endpoint
 		}
+		if updateData.RelayEndpoint != "" {
+			siteConfig.RelayEndpoint = updateData.RelayEndpoint
+		}
 		if updateData.PublicKey != "" {
 			siteConfig.PublicKey = updateData.PublicKey
 		}
@@ -674,6 +677,12 @@ func StartTunnel(config TunnelConfig) {
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if peerManager == nil {
logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
@@ -700,6 +709,12 @@ func StartTunnel(config TunnelConfig) {
olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) {
logger.Debug("Received unrelay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if peerManager == nil {
logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)

View File

@@ -27,6 +27,7 @@ type GlobalConfig struct {
 	OnConnected  func()
 	OnTerminated func()
 	OnAuthError  func(statusCode int, message string) // Called when auth fails (401/403)
+	OnExit       func()                               // Called when exit is requested via API
 }

 type TunnelConfig struct {
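Given the GlobalConfig surface above, an embedder could wire a cancel function into OnExit roughly like this. This is a sketch: the import path and field values are assumptions, not taken from the repo; only Init(ctx, GlobalConfig), LogLevel, and OnExit appear in the diff:

package main

import (
	"context"

	"github.com/fosrl/olm/olm" // import path assumed; adjust to the real module layout
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	olm.Init(ctx, olm.GlobalConfig{
		LogLevel: "info",
		OnExit:   cancel, // an API-side "exit" request now cancels ctx
	})

	<-ctx.Done() // unblocks when the API requests shutdown
}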

View File

@@ -71,6 +71,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
 		config.MiddleDev,
 		config.LocalIP,
 		config.SharedBind,
+		config.APIServer,
 	)

 	return pm
@@ -233,6 +234,16 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
 		return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId)
 	}

+	// Determine which endpoint to use based on relay state
+	// If the peer is currently relayed, use the relay endpoint; otherwise use the direct endpoint
+	actualEndpoint := endpoint
+	if pm.peerMonitor != nil && pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId) {
+		if oldPeer.RelayEndpoint != "" {
+			actualEndpoint = oldPeer.RelayEndpoint
+			logger.Info("Peer %d is relayed, using relay endpoint: %s", siteConfig.SiteId, actualEndpoint)
+		}
+	}
+
 	// If public key changed, remove old peer first
 	if siteConfig.PublicKey != oldPeer.PublicKey {
 		if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil {
@@ -284,7 +295,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
 	wgConfig := siteConfig
 	wgConfig.AllowedIps = ownedIPs

-	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint); err != nil {
+	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, actualEndpoint); err != nil {
 		return err
 	}
@@ -359,6 +370,11 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
 	pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)

+	// Preserve the relay endpoint if the peer is relayed
+	if pm.peerMonitor != nil && pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId) && oldPeer.RelayEndpoint != "" {
+		siteConfig.RelayEndpoint = oldPeer.RelayEndpoint
+	}
+
 	pm.peers[siteConfig.SiteId] = siteConfig

 	return nil
 }
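The rule UpdatePeer now applies twice (once when calling ConfigurePeer, once when storing siteConfig) boils down to: while a peer is relayed, never overwrite its relay endpoint with the direct one, or traffic would flap back to a path that already failed. Distilled into a hypothetical helper for clarity; this function does not exist in the diff:

package main

import "fmt"

// pickEndpoint distills the selection rule above. Hypothetical helper.
func pickEndpoint(direct, relay string, relayed bool) string {
	if relayed && relay != "" {
		return relay
	}
	return direct
}

func main() {
	fmt.Println(pickEndpoint("198.51.100.7:51820", "relay.example.com:51820", true)) // relay wins while relayed
	fmt.Println(pickEndpoint("198.51.100.7:51820", "", true))                        // no relay known: fall back to direct
}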

View File

@@ -12,6 +12,7 @@ import (
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/api"
middleDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/websocket"
"gvisor.dev/gvisor/pkg/buffer"
@@ -59,16 +60,22 @@ type PeerMonitor struct {
 	relayedPeers         map[int]bool // siteID -> whether the peer is currently relayed
 	holepunchMaxAttempts int          // max consecutive failures before triggering relay
 	holepunchFailures    map[int]int  // siteID -> consecutive failure count

+	// API server for status updates
+	apiServer *api.API
+
+	// WG connection status tracking
+	wgConnectionStatus map[int]bool // siteID -> WG connected status
 }

 // NewPeerMonitor creates a new peer monitor with the given callback
-func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
+func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
 	ctx, cancel := context.WithCancel(context.Background())

 	pm := &PeerMonitor{
 		monitors:    make(map[int]*Client),
-		interval:    1 * time.Second, // Default check interval
+		interval:    3 * time.Second, // Default check interval
 		timeout:     5 * time.Second,
-		maxAttempts: 5,
+		maxAttempts: 3,
 		wsClient:    wsClient,
 		middleDev:   middleDev,
 		localIP:     localIP,
@@ -76,13 +83,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 		nsCtx:      ctx,
 		nsCancel:   cancel,
 		sharedBind: sharedBind,
-		holepunchInterval: 5 * time.Second, // Check holepunch every 5 seconds
+		holepunchInterval: 3 * time.Second, // Check holepunch every 3 seconds
 		holepunchTimeout:     5 * time.Second,
 		holepunchEndpoints:   make(map[int]string),
 		holepunchStatus:      make(map[int]bool),
 		relayedPeers:         make(map[int]bool),
-		holepunchMaxAttempts: 5, // Trigger relay after 5 consecutive failures
+		holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
 		holepunchFailures:    make(map[int]int),
+		apiServer:            apiServer,
+		wgConnectionStatus:   make(map[int]bool),
 	}

 	if err := pm.initNetstack(); err != nil {
@@ -235,6 +244,26 @@ func (pm *PeerMonitor) Start() {
 // handleConnectionStatusChange is called when a peer's connection status changes
 func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) {
 	pm.mutex.Lock()
+	previousStatus, exists := pm.wgConnectionStatus[siteID]
+	pm.wgConnectionStatus[siteID] = status.Connected
+	isRelayed := pm.relayedPeers[siteID]
+	endpoint := pm.holepunchEndpoints[siteID]
+	pm.mutex.Unlock()
+
+	// Log status changes
+	if !exists || previousStatus != status.Connected {
+		if status.Connected {
+			logger.Info("WireGuard connection to site %d is CONNECTED (RTT: %v)", siteID, status.RTT)
+		} else {
+			logger.Warn("WireGuard connection to site %d is DISCONNECTED", siteID)
+		}
+	}
+
+	// Update API with connection status
+	if pm.apiServer != nil {
+		pm.apiServer.UpdatePeerStatus(siteID, status.Connected, status.RTT, endpoint, isRelayed)
+	}
 }

 // sendRelay sends a relay message to the server
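Worth noting in handleConnectionStatusChange: all shared state is copied under pm.mutex, the lock is released, and only then does the function log and call into the API server (which takes its own lock). A generic sketch of that snapshot-then-release pattern, with illustrative names:

package main

import "sync"

// monitor sketches the pattern: read shared state under the lock, drop the
// lock, then do I/O (logging, API calls) on the local copies.
type monitor struct {
	mu     sync.Mutex
	status map[int]bool
}

func (m *monitor) onChange(siteID int, connected bool, notify func(int, bool)) {
	m.mu.Lock()
	prev, seen := m.status[siteID]
	m.status[siteID] = connected
	m.mu.Unlock() // released before any callback, avoiding lock-order issues

	if !seen || prev != connected {
		notify(siteID, connected) // safe: no locks held here
	}
}

func main() {
	m := &monitor{status: make(map[int]bool)}
	m.onChange(1, true, func(id int, up bool) { _, _ = id, up })
}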
@@ -302,6 +331,13 @@ func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) {
 	}
 }

+// IsPeerRelayed returns whether a peer is currently using relay
+func (pm *PeerMonitor) IsPeerRelayed(siteID int) bool {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+	return pm.relayedPeers[siteID]
+}
+
 // startHolepunchMonitor starts the holepunch connection monitoring
 // Note: This function assumes the mutex is already held by the caller (called from Start())
 func (pm *PeerMonitor) startHolepunchMonitor() error {
@@ -364,6 +400,11 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
 // checkHolepunchEndpoints tests all holepunch endpoints
 func (pm *PeerMonitor) checkHolepunchEndpoints() {
 	pm.mutex.Lock()
+	// Check if we're still running before doing any work
+	if !pm.running {
+		pm.mutex.Unlock()
+		return
+	}
 	endpoints := make(map[int]string, len(pm.holepunchEndpoints))
 	for siteID, endpoint := range pm.holepunchEndpoints {
 		endpoints[siteID] = endpoint
@@ -402,7 +443,30 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
 	}
 	}

+	// Update API with holepunch status
+	if pm.apiServer != nil {
+		// Update holepunch connection status
+		pm.apiServer.UpdatePeerHolepunchStatus(siteID, result.Success)
+
+		// Get the current WG connection status for this peer
+		pm.mutex.Lock()
+		wgConnected := pm.wgConnectionStatus[siteID]
+		pm.mutex.Unlock()
+
+		// Update API - use holepunch endpoint and relay status
+		pm.apiServer.UpdatePeerStatus(siteID, wgConnected, result.RTT, endpoint, isRelayed)
+	}
+
+	// Handle relay logic based on holepunch status
+	// Check if we're still running before sending relay messages
+	pm.mutex.Lock()
+	stillRunning := pm.running
+	pm.mutex.Unlock()
+
+	if !stillRunning {
+		return // Stop processing if shutdown is in progress
+	}
+
 	if !result.Success && !isRelayed && failureCount >= maxAttempts {
 		// Holepunch failed and we're not relayed - trigger relay
 		logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount)