diff --git a/main.go b/main.go index 17fa1b6..c3891ca 100644 --- a/main.go +++ b/main.go @@ -392,20 +392,99 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) } +func monitorConnection(dev *device.Device, onTimeout func()) { + const ( + checkInterval = 100 * time.Millisecond // Check every 0.1 seconds + timeout = 500 * time.Millisecond // Total timeout of 1.5 seconds + ) + + go func() { + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + timeoutTimer := time.NewTimer(timeout) + defer timeoutTimer.Stop() + + // var lastSent uint64 + + for { + select { + case <-ticker.C: + // Get the current device statistics + deviceInfo, err := dev.IpcGet() + if err != nil { + logger.Error("Failed to get device statistics: %v", err) + continue + } + + // Parse the statistics from the IPC output + stats := parseStatistics(deviceInfo) + + logger.Info("Received: %d, Sent: %d", stats.received, stats.sent) + + // Check if we've received any new bytes + if stats.received > 0 { + // Connection is successful, we received data + logger.Info("Connection established - received bytes detected") + return + } + + // Update the last known values + // lastSent = stats.sent + + case <-timeoutTimer.C: + // We've hit our timeout without seeing any received bytes + logger.Warn("Connection timeout - no data received within %v", timeout) + onTimeout() + return + } + } + }() +} + +// statistics holds the parsed byte counts from the device +type statistics struct { + received uint64 + sent uint64 +} + +// parseStatistics extracts the received and sent byte counts from the device info string +func parseStatistics(info string) statistics { + var stats statistics + + // Split the device info into lines + lines := strings.Split(info, "\n") + + // Look for the transfer_receive and transfer_send lines + for _, line := range lines { + if strings.HasPrefix(line, "rx_bytes=") { + valueStr := strings.TrimPrefix(line, "rx_bytes=") + if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil { + stats.received = value + } + } else if strings.HasPrefix(line, "tx_bytes=") { + valueStr := strings.TrimPrefix(line, "tx_bytes=") + if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil { + stats.sent = value + } + } + } + + return stats +} + func main() { var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string - interfaceName string - generateAndSaveKeyTo string - reachableAt string + endpoint string + id string + secret string + mtu string + mtuInt int + dns string + privateKey wgtypes.Key + err error + logLevel string + interfaceName string ) stopHolepunch = make(chan struct{}) @@ -419,8 +498,6 @@ func main() { dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") interfaceName = os.Getenv("INTERFACE") - generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") - reachableAt = os.Getenv("REACHABLE_AT") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server") @@ -441,13 +518,7 @@ func main() { flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") } if interfaceName == "" { - flag.StringVar(&interfaceName, "interface", "wg2", "Name of the WireGuard interface") - } - if generateAndSaveKeyTo == "" { - flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") - } - if reachableAt == "" { - flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about") + flag.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface") } // do a --version check @@ -495,18 +566,56 @@ func main() { var dev *device.Device var wgData WgData var uapi *os.File + var tdev tun.Device olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") olm.Close() }) + olm.RegisterHandler("olm/wg/update", func(msg websocket.WSMessage) { + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &wgData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + endpoint, err := resolveDomain(wgData.Endpoint) + if err != nil { + logger.Error("Failed to resolve endpoint: %v", err) + return + } + + // Configure WireGuard + config := fmt.Sprintf(`private_key=%s + public_key=%s + allowed_ip=%s/32 + endpoint=%s + persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) + + err = dev.IpcSet(config) + if err != nil { + logger.Error("Failed to configure WireGuard device: %v", err) + } + }) + // Register handlers for different message types olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { logger.Info("Received message: %v", msg.Data) close(stopRegister) + // if there is an existing tunnel then close it + if dev != nil { + logger.Info("Got new message. Closing existing tunnel!") + dev.Close() + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) @@ -519,7 +628,7 @@ func main() { } // NEED TO DETERMINE AVAILABLE TUN DEVICE HERE - tdev, err := func() (tun.Device, error) { + tdev, err = func() (tun.Device, error) { tunFdStr := os.Getenv(ENV_WG_TUN_FD) // if on macOS, call findUnusedUTUN to get a new utun device @@ -610,24 +719,18 @@ func main() { logger.Info("UAPI listener started") - // endpoint, err := resolveDomain(wgData.Endpoint) - // if err != nil { - // logger.Error("Failed to resolve endpoint: %v", err) - // return - // } + host, err := resolveDomain(wgData.Endpoint) + if err != nil { + logger.Error("Failed to resolve endpoint: %v", err) + return + } // Configure WireGuard - // config := fmt.Sprintf(`private_key=%s - // public_key=%s - // allowed_ip=%s/32 - // endpoint=%s - // persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) - config := fmt.Sprintf(`private_key=%s public_key=%s allowed_ip=%s/32 -endpoint=18.212.58.121:21820 -persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP) +endpoint=%s +persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host) err = dev.IpcSet(config) if err != nil { @@ -647,6 +750,30 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub } close(stopHolepunch) + + // Monitor the connection for activity + monitorConnection(dev, func() { + host, err := resolveDomain(endpoint) + if err != nil { + logger.Error("Failed to resolve endpoint: %v", err) + return + } + + // Configure WireGuard + config := fmt.Sprintf(`private_key=%s +public_key=%s +allowed_ip=%s/32 +endpoint=%s:21820 +persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host) + + err = dev.IpcSet(config) + if err != nil { + logger.Error("Failed to configure WireGuard device: %v", err) + } + + logger.Info("Adjusted to point to relay!") + }) + logger.Info("WireGuard device created.") })