diff --git a/common.go b/common.go
index 6e777d3..a87cb8e 100644
--- a/common.go
+++ b/common.go
@@ -26,10 +26,11 @@ type WgData struct {
 }
 
 type SiteConfig struct {
-	SiteId    int    `json:"siteId"`
-	Endpoint  string `json:"endpoint"`
-	PublicKey string `json:"publicKey"`
-	ServerIP  string `json:"serverIP"`
+	SiteId     int    `json:"siteId"`
+	Endpoint   string `json:"endpoint"`
+	PublicKey  string `json:"publicKey"`
+	ServerIP   string `json:"serverIP"`
+	ServerPort uint16 `json:"serverPort"`
 }
 
 type TargetsByType struct {
@@ -61,7 +62,6 @@ var (
 	stopRegister       chan struct{}
 	olmToken           string
 	gerbilServerPubKey string
-	peerStatusMap      map[int]bool
 )
 
 const (
@@ -319,18 +319,6 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) {
 	}
 }
 
-func sendRelay(olm *websocket.Client) error {
-	err := olm.SendMessage("olm/wg/relay", map[string]interface{}{
-		"doIt": "now",
-	})
-	if err != nil {
-		logger.Error("Failed to send registration message: %v", err)
-		return err
-	}
-	logger.Info("Sent relay message")
-	return nil
-}
-
 func sendRegistration(olm *websocket.Client, publicKey string) error {
 	err := olm.SendMessage("olm/wg/register", map[string]interface{}{
 		"publicKey": publicKey,
@@ -395,34 +383,3 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
 
 	return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
 }
-
-func handlePeerStatusChange(siteID int, connected bool, rtt time.Duration) {
-	// Check if status has changed
-	prevStatus, exists := peerStatusMap[siteID]
-	if !exists || prevStatus != connected {
-		if connected {
-			logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
-			// Add any actions you want to take when a peer connects
-
-			// Example: try to send a relay message if this is the first peer to connect
-			if !prevStatus && !exists {
-				// This is a new connection, not just a status update
-				go func() {
-					// Give wireguard a moment to establish properly
-					// time.Sleep(500 * time.Millisecond)
-					// if olm != nil {
-					// 	if err := sendRelay(olm); err != nil {
-					// 		logger.Error("Failed to send relay message: %v", err)
-					// 	}
-					// }
-				}()
-			}
-		} else {
-			logger.Warn("Peer %d is disconnected", siteID)
-			// Add any actions you want to take when a peer disconnects
-		}
-
-		// Update status map
-		peerStatusMap[siteID] = connected
-	}
-}
diff --git a/main.go b/main.go
index 489b831..89fa99c 100644
--- a/main.go
+++ b/main.go
@@ -135,16 +135,6 @@ func main() {
 
 	stopHolepunch = make(chan struct{})
 	stopRegister = make(chan struct{})
-	peerStatusMap = make(map[int]bool)
-
-	// Initialize the peer monitor
-	peerMonitor = peermonitor.NewPeerMonitor(handlePeerStatusChange)
-	defer peerMonitor.Close()
-
-	// Set custom monitoring parameters if needed
-	peerMonitor.SetInterval(5 * time.Second)
-	peerMonitor.SetTimeout(500 * time.Millisecond)
-	peerMonitor.SetMaxAttempts(3)
 
 	// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
 	endpoint = os.Getenv("PANGOLIN_ENDPOINT")
@@ -212,12 +202,6 @@ func main() {
 		logger.Fatal("Failed to create olm: %v", err)
 	}
 
-	sourcePort, err := FindAvailableUDPPort(49152, 65535)
-	if err != nil {
-		fmt.Printf("Error finding available port: %v\n", err)
-		os.Exit(1)
-	}
-
 	// Create TUN device and network stack
 	var dev *device.Device
 	var wgData WgData
@@ -225,6 +209,12 @@ func main() {
 	var uapi *os.File
 	var tdev tun.Device
 
+	sourcePort, err := FindAvailableUDPPort(49152, 65535)
+	if err != nil {
+		fmt.Printf("Error finding available port: %v\n", err)
+		os.Exit(1)
+	}
+
 	olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
 		logger.Info("Received terminate message")
 		olm.Close()
@@ -366,6 +356,24 @@ func main() {
 
 		logger.Info("UAPI listener started")
 
+		primaryRelay, err := resolveDomain(endpoint)
+		if err != nil {
+			logger.Warn("Failed to resolve endpoint: %v", err)
+		}
+
+		peerMonitor = peermonitor.NewPeerMonitor(
+			func(siteID int, connected bool, rtt time.Duration) {
+				if connected {
+					logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
+				} else {
+					logger.Warn("Peer %d is disconnected", siteID)
+				}
+			},
+			fixKey(privateKey.String()),
+			olm,
+			dev,
+		)
+
 		// Configure WireGuard with all sites as peers
 		var configBuilder strings.Builder
 
@@ -395,11 +403,24 @@ func main() {
 			configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
 			configBuilder.WriteString("persistent_keepalive_interval=1\n")
 
-			err = peerMonitor.AddPeer(site.SiteId, siteHost)
+			// Monitor the site at the host part of its ServerIP (CIDR suffix stripped) on its ServerPort
+			monitorAddress := strings.Split(site.ServerIP, "/")[0]
+
+			monitorPeer := fmt.Sprintf("%s:%d", monitorAddress, site.ServerPort)
+
+			wgConfig := &peermonitor.WireGuardConfig{
+				SiteID:       site.SiteId,
+				PublicKey:    fixKey(site.PublicKey),
+				ServerIP:     monitorAddress,
+				Endpoint:     site.Endpoint,
+				PrimaryRelay: primaryRelay, // Use the main endpoint as relay
+			}
+
+			err = peerMonitor.AddPeer(site.SiteId, monitorPeer, wgConfig)
 			if err != nil {
 				logger.Warn("Failed to setup monitoring for site %d: %v", site.SiteId, err)
 			} else {
-				logger.Info("Started monitoring for site %d at %s", site.SiteId, siteHost)
+				logger.Info("Started monitoring for site %d at %s", site.SiteId, monitorPeer)
 			}
 		}
 
diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go
index 665b303..be17717 100644
--- a/peermonitor/peermonitor.go
+++ b/peermonitor/peermonitor.go
@@ -3,35 +3,54 @@ package peermonitor
 import (
 	"context"
 	"fmt"
-	"net"
 	"sync"
 	"time"
 
+	"github.com/fosrl/newt/logger"
+	"github.com/fosrl/olm/websocket"
 	"github.com/fosrl/olm/wgtester"
+	"golang.zx2c4.com/wireguard/device"
 )
 
 // PeerMonitorCallback is the function type for connection status change callbacks
 type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
 
+// WireGuardConfig holds the per-peer WireGuard settings used to repoint a peer at the relay
+type WireGuardConfig struct {
+	SiteID       int    // Site this configuration belongs to
+	PublicKey    string // Written to the peer's public_key on failover
+	ServerIP     string // Written (with /32) to the peer's allowed_ip on failover
+	Endpoint     string // Direct endpoint of the site as supplied by the caller
+	PrimaryRelay string // Relay host used as the peer endpoint (port 21820) on failover
+}
+
 // PeerMonitor handles monitoring the connection status to multiple WireGuard peers
 type PeerMonitor struct {
 	monitors    map[int]*wgtester.Client
+	configs     map[int]*WireGuardConfig
 	callback    PeerMonitorCallback
 	mutex       sync.Mutex
 	running     bool
 	interval    time.Duration
 	timeout     time.Duration
 	maxAttempts int
+	privateKey  string
+	wsClient    *websocket.Client
+	device      *device.Device
 }
 
 // NewPeerMonitor creates a new peer monitor with the given callback
-func NewPeerMonitor(callback PeerMonitorCallback) *PeerMonitor {
+func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device) *PeerMonitor {
 	return &PeerMonitor{
 		monitors:    make(map[int]*wgtester.Client),
+		configs:     make(map[int]*WireGuardConfig),
 		callback:    callback,
-		interval:    5 * time.Second, // Default check interval
+		interval:    1 * time.Second, // Default check interval
 		timeout:     500 * time.Millisecond,
 		maxAttempts: 3,
+		privateKey:  privateKey,
+		wsClient:    wsClient,
+		device:      device,
 	}
 }
 
@@ -75,7 +94,7 @@ func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
 }
 
 // AddPeer adds a new peer to monitor
-func (pm *PeerMonitor) AddPeer(siteID int, endpoint string) error {
+func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
 	pm.mutex.Lock()
 	defer pm.mutex.Unlock()
 
@@ -85,11 +104,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string) error {
 		pm.RemovePeer(siteID)
 	}
 
-	// Add UDP port if not present, assuming default WireGuard port
-	if _, _, err := net.SplitHostPort(endpoint); err != nil {
-		endpoint = endpoint + ":51820" // Default WireGuard port
-	}
-
 	client, err := wgtester.NewClient(endpoint)
 	if err != nil {
 		return err
@@ -100,14 +114,15 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string) error {
 	client.SetTimeout(pm.timeout)
 	client.SetMaxAttempts(pm.maxAttempts)
 
-	// Store the client
+	// Store the client and config
 	pm.monitors[siteID] = client
+	pm.configs[siteID] = wgConfig
 
 	// If monitor is already running, start monitoring this peer
 	if pm.running {
 		siteIDCopy := siteID // Create a copy for the closure
 		err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
-			pm.callback(siteIDCopy, status.Connected, status.RTT)
+			pm.handleConnectionStatusChange(siteIDCopy, status)
 		})
 	}
 
@@ -127,6 +142,7 @@ func (pm *PeerMonitor) RemovePeer(siteID int) {
 	client.StopMonitor()
 	client.Close()
 	delete(pm.monitors, siteID)
+	delete(pm.configs, siteID)
 }
 
 // Start begins monitoring all peers
@@ -144,11 +160,92 @@ func (pm *PeerMonitor) Start() {
 	for siteID, client := range pm.monitors {
 		siteIDCopy := siteID // Create a copy for the closure
 		client.StartMonitor(func(status wgtester.ConnectionStatus) {
-			pm.callback(siteIDCopy, status.Connected, status.RTT)
+			pm.handleConnectionStatusChange(siteIDCopy, status)
 		})
 	}
 }
 
+// handleConnectionStatusChange receives each status report for siteID from its
+// wgtester client. It forwards the report to the user callback (when one is
+// set) and, if the peer is reported as disconnected, attempts relay failover.
+func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) {
+	// Call the user-provided callback first
+	if pm.callback != nil {
+		pm.callback(siteID, status.Connected, status.RTT)
+	}
+
+	// If disconnected, handle failover
+	if !status.Connected {
+		pm.handleFailover(siteID)
+	}
+}
+
+// handleFailover repoints the WireGuard peer for siteID at its primary relay
+// (port 21820) and then sends an "olm/wg/relay" message to the server. It
+// returns without touching the device when no usable config is stored for
+// siteID, and skips the server message when the device update fails.
+func (pm *PeerMonitor) handleFailover(siteID int) {
+	pm.mutex.Lock()
+	config, exists := pm.configs[siteID]
+	pm.mutex.Unlock()
+
+	// AddPeer stores the pointer it is given as-is, so the entry may be nil
+	if !exists || config == nil {
+		return
+	}
+
+	if pm.device == nil {
+		logger.Error("Cannot fail over peer %d: WireGuard device is nil", siteID)
+		return
+	}
+
+	// An empty relay host would yield the unusable endpoint ":21820"
+	if config.PrimaryRelay == "" {
+		logger.Error("Cannot fail over peer %d: no primary relay configured", siteID)
+		return
+	}
+
+	// Configure WireGuard to use the relay
+	wgConfig := fmt.Sprintf(`private_key=%s
+public_key=%s
+allowed_ip=%s/32
+endpoint=%s:21820
+persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, config.PrimaryRelay)
+
+	if err := pm.device.IpcSet(wgConfig); err != nil {
+		logger.Error("Failed to configure WireGuard device: %v", err)
+		return
+	}
+
+	logger.Info("Adjusted peer %d to point to relay!", siteID)
+
+	// Send relay message to the server
+	if pm.wsClient == nil {
+		return
+	}
+	if err := pm.sendRelay(siteID); err != nil {
+		logger.Error("Failed to send relay message: %v", err)
+	}
+}
+
+// sendRelay sends an "olm/wg/relay" message carrying siteID over the websocket
+// client. It returns an error when the client is nil or the send fails and
+// leaves logging of that error to the caller.
+func (pm *PeerMonitor) sendRelay(siteID int) error {
+	if pm.wsClient == nil {
+		return fmt.Errorf("websocket client is nil")
+	}
+
+	err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
+		"siteId": siteID,
+	})
+	if err != nil {
+		return fmt.Errorf("sending relay message for site %d: %w", siteID, err)
+	}
+	logger.Info("Sent relay message")
+	return nil
+}
+
 // Stop stops monitoring all peers
 func (pm *PeerMonitor) Stop() {
 	pm.mutex.Lock()