diff --git a/main.go b/main.go index 684754d..299ef05 100644 --- a/main.go +++ b/main.go @@ -6,20 +6,16 @@ import ( "encoding/json" "flag" "fmt" - "math/rand" "net" "os" "os/signal" "strconv" "strings" "syscall" - "time" "github.com/fosrl/newt/logger" "github.com/fosrl/olm/websocket" - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" @@ -65,135 +61,55 @@ const ( ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" ) -func ping(dst string) error { - logger.Info("Pinging %s over WireGuard tunnel", dst) +// func startPingCheck(serverIP string, stopChan chan struct{}) { +// ticker := time.NewTicker(10 * time.Second) +// defer ticker.Stop() - // Create a raw socket for ICMP - conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0") - if err != nil { - return fmt.Errorf("failed to create ICMP socket: %w", err) - } - defer conn.Close() +// go func() { +// for { +// select { +// case <-ticker.C: +// err := ping(serverIP) +// if err != nil { +// logger.Warn("Periodic ping failed: %v", err) +// logger.Warn("HINT: Check if the WireGuard tunnel is up and the server is reachable") +// } +// case <-stopChan: +// logger.Info("Stopping ping check") +// return +// } +// } +// }() +// } - // Parse destination IP - dstIP := net.ParseIP(dst) - if dstIP == nil { - return fmt.Errorf("invalid destination IP: %s", dst) - } +// func pingWithRetry(dst string) error { +// const ( +// maxAttempts = 5 +// retryDelay = 2 * time.Second +// ) - // Create ICMP message - requestPing := icmp.Echo{ - ID: os.Getpid() & 0xffff, - Seq: rand.Intn(1 << 16), - Data: []byte("wireguard ping"), - } +// var lastErr error +// for attempt := 1; attempt <= maxAttempts; attempt++ { +// logger.Info("Ping attempt %d of %d", attempt, maxAttempts) - msg := icmp.Message{ - Type: ipv4.ICMPTypeEcho, - Code: 0, - Body: &requestPing, - } +// if err := ping(dst); err != nil { +// lastErr = err +// logger.Warn("Ping attempt %d failed: %v", attempt, err) - // Marshal the message - icmpBytes, err := msg.Marshal(nil) - if err != nil { - return fmt.Errorf("failed to marshal ICMP message: %w", err) - } +// if attempt < maxAttempts { +// time.Sleep(retryDelay) +// continue +// } +// return fmt.Errorf("all ping attempts failed after %d tries, last error: %w", +// maxAttempts, lastErr) +// } - // Set read deadline - if err := conn.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { - return fmt.Errorf("failed to set read deadline: %w", err) - } +// // Successful ping +// return nil +// } - // Send the ping - start := time.Now() - _, err = conn.WriteTo(icmpBytes, &net.IPAddr{IP: dstIP}) - if err != nil { - return fmt.Errorf("failed to write ICMP packet: %w", err) - } - - // Wait for reply - reply := make([]byte, 1500) - n, peer, err := conn.ReadFrom(reply) - if err != nil { - return fmt.Errorf("failed to read ICMP packet: %w", err) - } - - // Parse reply - replyMsg, err := icmp.ParseMessage(1, reply[:n]) - if err != nil { - return fmt.Errorf("failed to parse ICMP reply: %w", err) - } - - // Verify reply - switch replyMsg.Type { - case ipv4.ICMPTypeEchoReply: - replyEcho, ok := replyMsg.Body.(*icmp.Echo) - if !ok { - return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyMsg.Body) - } - if replyEcho.ID != requestPing.ID || replyEcho.Seq != requestPing.Seq { - return fmt.Errorf("invalid echo reply: got id=%d seq=%d, want id=%d seq=%d", - replyEcho.ID, replyEcho.Seq, requestPing.ID, requestPing.Seq) - } - default: - return fmt.Errorf("unexpected ICMP message type: %+v", replyMsg) - } - - duration := time.Since(start) - logger.Info("Ping reply from %v: time=%v", peer, duration) - return nil -} - -func startPingCheck(serverIP string, stopChan chan struct{}) { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - go func() { - for { - select { - case <-ticker.C: - err := ping(serverIP) - if err != nil { - logger.Warn("Periodic ping failed: %v", err) - logger.Warn("HINT: Check if the WireGuard tunnel is up and the server is reachable") - } - case <-stopChan: - logger.Info("Stopping ping check") - return - } - } - }() -} - -func pingWithRetry(dst string) error { - const ( - maxAttempts = 5 - retryDelay = 2 * time.Second - ) - - var lastErr error - for attempt := 1; attempt <= maxAttempts; attempt++ { - logger.Info("Ping attempt %d of %d", attempt, maxAttempts) - - if err := ping(dst); err != nil { - lastErr = err - logger.Warn("Ping attempt %d failed: %v", attempt, err) - - if attempt < maxAttempts { - time.Sleep(retryDelay) - continue - } - return fmt.Errorf("all ping attempts failed after %d tries, last error: %w", - maxAttempts, lastErr) - } - - // Successful ping - return nil - } - - return fmt.Errorf("unexpected error: all ping attempts failed") -} +// return fmt.Errorf("unexpected error: all ping attempts failed") +// } func parseLogLevel(level string) logger.LogLevel { switch strings.ToUpper(level) { @@ -321,7 +237,7 @@ func main() { flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") } if interfaceName == "" { - flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface") + flag.StringVar(&interfaceName, "interface", "wg2", "Name of the WireGuard interface") } if generateAndSaveKeyTo == "" { flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") @@ -367,7 +283,7 @@ func main() { // Create TUN device and network stack var dev *device.Device - var connected bool + // var connected bool var wgData WgData olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { @@ -382,16 +298,16 @@ func main() { olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { logger.Info("Received registration message") - if connected { - logger.Info("Already connected! But I will send a ping anyway...") - err := pingWithRetry(wgData.ServerIP) - if err != nil { - // Handle complete failure after all retries - logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err) - logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?") - } - return - } + // if connected { + // logger.Info("Already connected! But I will send a ping anyway...") + // err := pingWithRetry(wgData.ServerIP) + // if err != nil { + // // Handle complete failure after all retries + // logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err) + // logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?") + // } + // return + // } jsonData, err := json.Marshal(msg.Data) if err != nil { @@ -461,20 +377,20 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Error("Failed to bring up WireGuard device: %v", err) } - logger.Info("WireGuard device created. Lets ping the server now...") + logger.Info("WireGuard device created.") // Ping to bring the tunnel up on the server side quickly // ping(tnet, wgData.ServerIP) - err = pingWithRetry(wgData.ServerIP) - if err != nil { - // Handle complete failure after all retries - logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) - } + // err = pingWithRetry(wgData.ServerIP) + // if err != nil { + // // Handle complete failure after all retries + // logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) + // } - if !connected { - logger.Info("Starting ping check") - startPingCheck(wgData.ServerIP, pingStopChan) - } - connected = true + // if !connected { + // logger.Info("Starting ping check") + // startPingCheck(wgData.ServerIP, pingStopChan) + // } + // connected = true }) olm.OnConnect(func() error {