diff --git a/main.go b/main.go index 0ab4a33..a3d8df1 100644 --- a/main.go +++ b/main.go @@ -59,39 +59,86 @@ func fixKey(key string) string { return hex.EncodeToString(decoded) } -func ping(tnet *netstack.Net, dst string) { +func ping(tnet *netstack.Net, dst string) error { logger.Info("Pinging %s", dst) socket, err := tnet.Dial("ping4", dst) if err != nil { - logger.Error("Failed to create ICMP socket: %v", err) + return fmt.Errorf("failed to create ICMP socket: %w", err) } + defer socket.Close() + requestPing := icmp.Echo{ Seq: rand.Intn(1 << 16), Data: []byte("gopher burrow"), } - icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) - socket.SetReadDeadline(time.Now().Add(time.Second * 10)) + + icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + if err != nil { + return fmt.Errorf("failed to marshal ICMP message: %w", err) + } + + if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { + return fmt.Errorf("failed to set read deadline: %w", err) + } + start := time.Now() _, err = socket.Write(icmpBytes) if err != nil { - logger.Error("Failed to write ICMP packet: %v", err) + return fmt.Errorf("failed to write ICMP packet: %w", err) } + n, err := socket.Read(icmpBytes[:]) if err != nil { - logger.Error("Failed to read ICMP packet: %v", err) + return fmt.Errorf("failed to read ICMP packet: %w", err) } + replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) if err != nil { - logger.Error("Failed to parse ICMP packet: %v", err) + return fmt.Errorf("failed to parse ICMP packet: %w", err) } + replyPing, ok := replyPacket.Body.(*icmp.Echo) if !ok { - logger.Error("invalid reply type: %v", replyPacket) + return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body) } + if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { - logger.Error("invalid ping reply: %v", replyPing) + return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q", + replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data) } + logger.Info("Ping latency: %v", time.Since(start)) + return nil +} + +func pingWithRetry(tnet *netstack.Net, dst string) error { + const ( + maxAttempts = 5 + retryDelay = 2 * time.Second + ) + + var lastErr error + for attempt := 1; attempt <= maxAttempts; attempt++ { + logger.Info("Ping attempt %d of %d", attempt, maxAttempts) + + if err := ping(tnet, dst); err != nil { + lastErr = err + logger.Warn("Ping attempt %d failed: %v", attempt, err) + + if attempt < maxAttempts { + time.Sleep(retryDelay) + continue + } + return fmt.Errorf("all ping attempts failed after %d tries, last error: %w", + maxAttempts, lastErr) + } + + // Successful ping + return nil + } + + // This shouldn't be reached due to the return in the loop, but added for completeness + return fmt.Errorf("unexpected error: all ping attempts failed") } func parseLogLevel(level string) logger.LogLevel { @@ -237,7 +284,12 @@ func main() { if connected { logger.Info("Already connected! Put I will send a ping anyway...") - ping(tnet, wgData.ServerIP) + // ping(tnet, wgData.ServerIP) + err = pingWithRetry(tnet, wgData.ServerIP) + if err != nil { + // Handle complete failure after all retries + logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) + } return } @@ -293,7 +345,12 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( logger.Info("WireGuard device created. Lets ping the server now...") // Ping to bring the tunnel up on the server side quickly - ping(tnet, wgData.ServerIP) + // ping(tnet, wgData.ServerIP) + err = pingWithRetry(tnet, wgData.ServerIP) + if err != nil { + // Handle complete failure after all retries + logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) + } // Create proxy manager pm = proxy.NewProxyManager(tnet)