package main import ( "context" "encoding/json" "flag" "fmt" "net" "os" "os/signal" "runtime" "strconv" "strings" "syscall" "time" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" "github.com/fosrl/newt/websocket" "github.com/fosrl/olm/httpserver" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/wgtester" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) // Helper function to format endpoints correctly func formatEndpoint(endpoint string) string { if endpoint == "" { return "" } // Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080) _, _, err := net.SplitHostPort(endpoint) if err == nil { return endpoint // Already valid, no change needed } // If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it. lastColon := strings.LastIndex(endpoint, ":") if lastColon > 0 { // Ensure there is a colon and it's not the first character hostPart := endpoint[:lastColon] // Check if the host part is a literal IPv6 address if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil { // It is! Reformat it with brackets. portPart := endpoint[lastColon+1:] return fmt.Sprintf("[%s]:%s", hostPart, portPart) } } // If it's not the specific malformed case, return it as is. return endpoint } func main() { // Check if we're running as a Windows service if isWindowsService() { runService("OlmWireguardService", false, os.Args[1:]) fmt.Println("Running as Windows service") return } // Handle service management commands on Windows if runtime.GOOS == "windows" { var command string if len(os.Args) > 1 { command = os.Args[1] } else { command = "default" } switch command { case "install": err := installService() if err != nil { fmt.Printf("Failed to install service: %v\n", err) os.Exit(1) } fmt.Println("Service installed successfully") return case "remove", "uninstall": err := removeService() if err != nil { fmt.Printf("Failed to remove service: %v\n", err) os.Exit(1) } fmt.Println("Service removed successfully") return case "start": // Pass the remaining arguments (after "start") to the service serviceArgs := os.Args[2:] err := startService(serviceArgs) if err != nil { fmt.Printf("Failed to start service: %v\n", err) os.Exit(1) } fmt.Println("Service started successfully") return case "stop": err := stopService() if err != nil { fmt.Printf("Failed to stop service: %v\n", err) os.Exit(1) } fmt.Println("Service stopped successfully") return case "status": status, err := getServiceStatus() if err != nil { fmt.Printf("Failed to get service status: %v\n", err) os.Exit(1) } fmt.Printf("Service status: %s\n", status) return case "debug": // get the status and if it is Not Installed then install it first status, err := getServiceStatus() if err != nil { fmt.Printf("Failed to get service status: %v\n", err) os.Exit(1) } if status == "Not Installed" { err := installService() if err != nil { fmt.Printf("Failed to install service: %v\n", err) os.Exit(1) } fmt.Println("Service installed successfully, now running in debug mode") } // Pass the remaining arguments (after "debug") to the service serviceArgs := os.Args[2:] err = debugService(serviceArgs) if err != nil { fmt.Printf("Failed to debug service: %v\n", err) os.Exit(1) } return case "logs": err := watchLogFile(false) if err != nil { fmt.Printf("Failed to watch log file: %v\n", err) os.Exit(1) } return case "config": if runtime.GOOS == "windows" { showServiceConfig() } else { fmt.Println("Service configuration is only available on Windows") } return case "help", "--help", "-h": fmt.Println("Olm WireGuard VPN Client") fmt.Println("\nWindows Service Management:") fmt.Println(" install Install the service") fmt.Println(" remove Remove the service") fmt.Println(" start [args] Start the service with optional arguments") fmt.Println(" stop Stop the service") fmt.Println(" status Show service status") fmt.Println(" debug [args] Run service in debug mode with optional arguments") fmt.Println(" logs Tail the service log file") fmt.Println(" config Show current service configuration") fmt.Println("\nExamples:") fmt.Println(" olm start --enable-http --http-addr :9452") fmt.Println(" olm debug --endpoint https://example.com --id myid --secret mysecret") fmt.Println("\nFor console mode, run without arguments or with standard flags.") return default: // get the status and if it is Not Installed then install it first status, err := getServiceStatus() if err != nil { fmt.Printf("Failed to get service status: %v\n", err) os.Exit(1) } if status == "Not Installed" { err := installService() if err != nil { fmt.Printf("Failed to install service: %v\n", err) os.Exit(1) } fmt.Println("Service installed successfully, now running") } // Pass the remaining arguments (after "debug") to the service serviceArgs := os.Args[1:] err = debugService(serviceArgs) if err != nil { fmt.Printf("Failed to debug service: %v\n", err) os.Exit(1) } return } } // Run in console mode runOlmMain(context.Background()) } func runOlmMain(ctx context.Context) { runOlmMainWithArgs(ctx, os.Args[1:]) } func runOlmMainWithArgs(ctx context.Context, args []string) { // Log that we've entered the main function // fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args) // Create a new FlagSet for parsing service arguments serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError) var ( endpoint string id string secret string mtu string mtuInt int dns string privateKey wgtypes.Key err error logLevel string interfaceName string enableHTTP bool httpAddr string testMode bool // Add this var for the test flag testTarget string // Add this var for test target pingInterval time.Duration pingTimeout time.Duration doHolepunch bool connected bool ) stopHolepunch = make(chan struct{}) stopPing = make(chan struct{}) // if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values endpoint = os.Getenv("PANGOLIN_ENDPOINT") id = os.Getenv("OLM_ID") secret = os.Getenv("OLM_SECRET") mtu = os.Getenv("MTU") dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") interfaceName = os.Getenv("INTERFACE") httpAddr = os.Getenv("HTTP_ADDR") pingIntervalStr := os.Getenv("PING_INTERVAL") pingTimeoutStr := os.Getenv("PING_TIMEOUT") enableHTTPEnv := os.Getenv("ENABLE_HTTP") holepunchEnv := os.Getenv("HOLEPUNCH") enableHTTP = enableHTTPEnv == "true" doHolepunch = holepunchEnv == "true" if endpoint == "" { serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server") } if id == "" { serviceFlags.StringVar(&id, "id", "", "Olm ID") } if secret == "" { serviceFlags.StringVar(&secret, "secret", "", "Olm secret") } if mtu == "" { serviceFlags.StringVar(&mtu, "mtu", "1280", "MTU to use") } if dns == "" { serviceFlags.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use") } if logLevel == "" { serviceFlags.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") } if interfaceName == "" { serviceFlags.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface") } if httpAddr == "" { serviceFlags.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')") } if pingIntervalStr == "" { serviceFlags.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)") } if pingTimeoutStr == "" { serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)") } if enableHTTPEnv == "" { serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests") } if holepunchEnv == "" { serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)") } version := serviceFlags.Bool("version", false, "Print the version") // Parse the service arguments if err := serviceFlags.Parse(args); err != nil { fmt.Printf("Error parsing service arguments: %v\n", err) return } // Debug: Print final values after flag parsing // fmt.Printf("After flag parsing: endpoint='%s', id='%s', secret='%s'\n", endpoint, id, secret) // Parse ping intervals if pingIntervalStr != "" { pingInterval, err = time.ParseDuration(pingIntervalStr) if err != nil { fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr) pingInterval = 3 * time.Second } } else { pingInterval = 3 * time.Second } if pingTimeoutStr != "" { pingTimeout, err = time.ParseDuration(pingTimeoutStr) if err != nil { fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr) pingTimeout = 5 * time.Second } } else { pingTimeout = 5 * time.Second } // Setup Windows event logging if on Windows if runtime.GOOS == "windows" { setupWindowsEventLog() } else { // Initialize logger for non-Windows platforms logger.Init() } loggerLevel := parseLogLevel(logLevel) logger.GetLogger().SetLevel(parseLogLevel(logLevel)) olmVersion := "version_replaceme" if *version { fmt.Println("Olm version " + olmVersion) os.Exit(0) } else { logger.Info("Olm version " + olmVersion) } if err := updates.CheckForUpdate("fosrl", "olm", olmVersion); err != nil { logger.Debug("Failed to check for updates: %v", err) } // Log startup information logger.Debug("Olm service starting...") logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret) logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr) if doHolepunch { logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.") } // Handle test mode if testMode { if testTarget == "" { logger.Fatal("Test mode requires -test-target to be set to a server:port") } logger.Info("Running in test mode, connecting to %s", testTarget) // Create a new tester client tester, err := wgtester.NewClient(testTarget) if err != nil { logger.Fatal("Failed to create tester client: %v", err) } defer tester.Close() // Test connection with a 2-second timeout connected, rtt := tester.TestConnectionWithTimeout(2 * time.Second) if connected { logger.Info("Connection test successful! RTT: %v", rtt) fmt.Printf("Connection test successful! RTT: %v\n", rtt) os.Exit(0) } else { logger.Error("Connection test failed - no response received") fmt.Println("Connection test failed - no response received") os.Exit(1) } } var httpServer *httpserver.HTTPServer if enableHTTP { httpServer = httpserver.NewHTTPServer(httpAddr) httpServer.SetVersion(olmVersion) if err := httpServer.Start(); err != nil { logger.Fatal("Failed to start HTTP server: %v", err) } // Use a goroutine to handle connection requests go func() { for req := range httpServer.GetConnectionChannel() { logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) // Set the connection parameters id = req.ID secret = req.Secret endpoint = req.Endpoint } }() } // // Check if required parameters are missing and provide helpful guidance // missingParams := []string{} // if id == "" { // missingParams = append(missingParams, "id (use -id flag or OLM_ID env var)") // } // if secret == "" { // missingParams = append(missingParams, "secret (use -secret flag or OLM_SECRET env var)") // } // if endpoint == "" { // missingParams = append(missingParams, "endpoint (use -endpoint flag or PANGOLIN_ENDPOINT env var)") // } // if len(missingParams) > 0 { // logger.Error("Missing required parameters: %v", missingParams) // logger.Error("Either provide them as command line flags or set as environment variables") // fmt.Printf("ERROR: Missing required parameters: %v\n", missingParams) // fmt.Printf("Please provide them as command line flags or set as environment variables\n") // if !enableHTTP { // logger.Error("HTTP server is disabled, cannot receive parameters via API") // fmt.Printf("HTTP server is disabled, cannot receive parameters via API\n") // return // } // } // Create a new olm olm, err := websocket.NewClient( "olm", id, // CLI arg takes precedence secret, // CLI arg takes precedence endpoint, pingInterval, pingTimeout, ) if err != nil { logger.Fatal("Failed to create olm: %v", err) } endpoint = olm.GetConfig().Endpoint // Update endpoint from config id = olm.GetConfig().ID // Update ID from config secret = olm.GetConfig().Secret // Update secret from config // wait until we have a client id and secret and endpoint waitCount := 0 for id == "" || secret == "" || endpoint == "" { select { case <-ctx.Done(): logger.Info("Context cancelled while waiting for credentials") return default: missing := []string{} if id == "" { missing = append(missing, "id") } if secret == "" { missing = append(missing, "secret") } if endpoint == "" { missing = append(missing, "endpoint") } waitCount++ if waitCount%10 == 1 { // Log every 10 seconds instead of every second logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount) } time.Sleep(1 * time.Second) } } // parse the mtu string into an int mtuInt, err = strconv.Atoi(mtu) if err != nil { logger.Fatal("Failed to parse MTU: %v", err) } privateKey, err = wgtypes.GeneratePrivateKey() if err != nil { logger.Fatal("Failed to generate private key: %v", err) } // Create TUN device and network stack var dev *device.Device var wgData WgData var holePunchData HolePunchData var uapiListener net.Listener var tdev tun.Device sourcePort, err := FindAvailableUDPPort(49152, 65535) if err != nil { fmt.Printf("Error finding available port: %v\n", err) os.Exit(1) } olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) return } if err := json.Unmarshal(jsonData, &holePunchData); err != nil { logger.Info("Error unmarshaling target data: %v", err) return } // Create a new stopHolepunch channel for the new set of goroutines stopHolepunch = make(chan struct{}) // Start a single hole punch goroutine for all exit nodes logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes)) go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort) }) olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { // THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY logger.Debug("Received message: %v", msg.Data) type LegacyHolePunchData struct { ServerPubKey string `json:"serverPubKey"` Endpoint string `json:"endpoint"` } var legacyHolePunchData LegacyHolePunchData jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) return } if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil { logger.Info("Error unmarshaling target data: %v", err) return } // Stop any existing hole punch goroutines by closing the current channel select { case <-stopHolepunch: // Channel already closed default: close(stopHolepunch) } // Create a new stopHolepunch channel for the new set of goroutines stopHolepunch = make(chan struct{}) // Start hole punching for each exit node logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey) go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey) }) olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) if connected { logger.Info("Already connected. Ignoring new connection request.") return } if stopRegister != nil { stopRegister() stopRegister = nil } close(stopHolepunch) // wait 10 milliseconds to ensure the previous connection is closed logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed") time.Sleep(500 * time.Millisecond) // if there is an existing tunnel then close it if dev != nil { logger.Info("Got new message. Closing existing tunnel!") dev.Close() } jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) return } if err := json.Unmarshal(jsonData, &wgData); err != nil { logger.Info("Error unmarshaling target data: %v", err) return } tdev, err = func() (tun.Device, error) { if runtime.GOOS == "darwin" { interfaceName, err := findUnusedUTUN() if err != nil { return nil, err } return tun.CreateTUN(interfaceName, mtuInt) } if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" { return createTUNFromFD(tunFdStr, mtuInt) } return tun.CreateTUN(interfaceName, mtuInt) }() if err != nil { logger.Error("Failed to create TUN device: %v", err) return } if realInterfaceName, err2 := tdev.Name(); err2 == nil { interfaceName = realInterfaceName } fileUAPI, err := func() (*os.File, error) { if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" { fd, err := strconv.ParseUint(uapiFdStr, 10, 32) if err != nil { return nil, err } return os.NewFile(uintptr(fd), ""), nil } return uapiOpen(interfaceName) }() if err != nil { logger.Error("UAPI listen error: %v", err) os.Exit(1) return } dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) uapiListener, err = uapiListen(interfaceName, fileUAPI) if err != nil { logger.Error("Failed to listen on uapi socket: %v", err) os.Exit(1) } go func() { for { conn, err := uapiListener.Accept() if err != nil { return } go dev.IpcHandle(conn) } }() logger.Info("UAPI listener started") if err = dev.Up(); err != nil { logger.Error("Failed to bring up WireGuard device: %v", err) } if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) } if httpServer != nil { httpServer.SetTunnelIP(wgData.TunnelIP) } peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { if httpServer != nil { // Find the site config to get endpoint information var endpoint string var isRelay bool for _, site := range wgData.Sites { if site.SiteId == siteID { endpoint = site.Endpoint // TODO: We'll need to track relay status separately // For now, assume not using relay unless we get relay data isRelay = !doHolepunch break } } httpServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay) } if connected { logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) } else { logger.Warn("Peer %d is disconnected", siteID) } }, fixKey(privateKey.String()), olm, dev, doHolepunch, ) for i := range wgData.Sites { site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice if httpServer != nil { httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) } // Format the endpoint before configuring the peer. site.Endpoint = formatEndpoint(site.Endpoint) if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { logger.Error("Failed to configure peer: %v", err) return } if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for peer: %v", err) return } if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err) return } logger.Info("Configured peer %s", site.PublicKey) } peerMonitor.Start() connected = true logger.Info("WireGuard device created.") }) olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { logger.Debug("Received update-peer message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) return } var updateData UpdatePeerData if err := json.Unmarshal(jsonData, &updateData); err != nil { logger.Error("Error unmarshaling update data: %v", err) return } // Convert to SiteConfig siteConfig := SiteConfig{ SiteId: updateData.SiteId, Endpoint: updateData.Endpoint, PublicKey: updateData.PublicKey, ServerIP: updateData.ServerIP, ServerPort: updateData.ServerPort, RemoteSubnets: updateData.RemoteSubnets, } // Update the peer in WireGuard if dev != nil { // Find the existing peer to get old data var oldRemoteSubnets string var oldPublicKey string for _, site := range wgData.Sites { if site.SiteId == updateData.SiteId { oldRemoteSubnets = site.RemoteSubnets oldPublicKey = site.PublicKey break } } // If the public key has changed, remove the old peer first if oldPublicKey != "" && oldPublicKey != updateData.PublicKey { logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey) if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil { logger.Error("Failed to remove old peer: %v", err) return } } // Format the endpoint before updating the peer. siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to update peer: %v", err) return } // Remove old remote subnet routes if they changed if oldRemoteSubnets != siteConfig.RemoteSubnets { if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil { logger.Error("Failed to remove old remote subnet routes: %v", err) // Continue anyway to add new routes } // Add new remote subnet routes if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add new remote subnet routes: %v", err) return } } // Update successful logger.Info("Successfully updated peer for site %d", updateData.SiteId) for i := range wgData.Sites { if wgData.Sites[i].SiteId == updateData.SiteId { wgData.Sites[i] = siteConfig break } } } else { logger.Error("WireGuard device not initialized") } }) // Handler for adding a new peer olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { logger.Debug("Received add-peer message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) return } var addData AddPeerData if err := json.Unmarshal(jsonData, &addData); err != nil { logger.Error("Error unmarshaling add data: %v", err) return } // Convert to SiteConfig siteConfig := SiteConfig{ SiteId: addData.SiteId, Endpoint: addData.Endpoint, PublicKey: addData.PublicKey, ServerIP: addData.ServerIP, ServerPort: addData.ServerPort, RemoteSubnets: addData.RemoteSubnets, } // Add the peer to WireGuard if dev != nil { // Format the endpoint before adding the new peer. siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to add peer: %v", err) return } if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for new peer: %v", err) return } if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err) return } // Add successful logger.Info("Successfully added peer for site %d", addData.SiteId) // Update WgData with the new peer wgData.Sites = append(wgData.Sites, siteConfig) } else { logger.Error("WireGuard device not initialized") } }) // Handler for removing a peer olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { logger.Debug("Received remove-peer message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) return } var removeData RemovePeerData if err := json.Unmarshal(jsonData, &removeData); err != nil { logger.Error("Error unmarshaling remove data: %v", err) return } // Find the peer to remove var peerToRemove *SiteConfig var newSites []SiteConfig for _, site := range wgData.Sites { if site.SiteId == removeData.SiteId { peerToRemove = &site } else { newSites = append(newSites, site) } } if peerToRemove == nil { logger.Error("Peer with site ID %d not found", removeData.SiteId) return } // Remove the peer from WireGuard if dev != nil { if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { logger.Error("Failed to remove peer: %v", err) // Send error response if needed return } // Remove route for the peer err = removeRouteForServerIP(peerToRemove.ServerIP) if err != nil { logger.Error("Failed to remove route for peer: %v", err) return } // Remove routes for remote subnets if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil { logger.Error("Failed to remove routes for remote subnets: %v", err) return } // Remove successful logger.Info("Successfully removed peer for site %d", removeData.SiteId) // Update WgData to remove the peer wgData.Sites = newSites } else { logger.Error("WireGuard device not initialized") } }) olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { logger.Debug("Received relay-peer message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) return } var relayData RelayPeerData if err := json.Unmarshal(jsonData, &relayData); err != nil { logger.Error("Error unmarshaling relay data: %v", err) return } primaryRelay, err := resolveDomain(relayData.Endpoint) if err != nil { logger.Warn("Failed to resolve primary relay endpoint: %v", err) } // Update HTTP server to mark this peer as using relay if httpServer != nil { httpServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) } peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) }) olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) { logger.Info("Received no-sites message - no sites available for connection") // if stopRegister != nil { // stopRegister() // stopRegister = nil // } // select { // case <-stopHolepunch: // // Channel already closed, do nothing // default: // close(stopHolepunch) // } logger.Info("No sites available - stopped registration and holepunch processes") }) olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") olm.Close() }) olm.OnConnect(func() error { logger.Info("Websocket Connected") if httpServer != nil { httpServer.SetConnectionStatus(true) } if connected { logger.Debug("Already connected, skipping registration") return nil } publicKey := privateKey.PublicKey() logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ "publicKey": publicKey.String(), "relay": !doHolepunch, "olmVersion": olmVersion, }, 1*time.Second) go keepSendingPing(olm) logger.Info("Sent registration message") return nil }) olm.OnTokenUpdate(func(token string) { olmToken = token }) // Connect to the WebSocket server if err := olm.Connect(); err != nil { logger.Fatal("Failed to connect to server: %v", err) } defer olm.Close() // Wait for interrupt signal or context cancellation sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) select { case <-sigCh: logger.Info("Received interrupt signal") case <-ctx.Done(): logger.Info("Context cancelled") } select { case <-stopHolepunch: // Channel already closed, do nothing default: close(stopHolepunch) } if stopRegister != nil { stopRegister() stopRegister = nil } select { case <-stopPing: // Channel already closed default: close(stopPing) } if uapiListener != nil { uapiListener.Close() } if dev != nil { dev.Close() } logger.Info("runOlmMain() exiting") fmt.Printf("runOlmMain() exiting\n") }