Add ping retry and handle ping exception

This commit is contained in:
Owen Schwartz
2024-12-22 21:00:40 -05:00
parent b9a3632a1e
commit 934a235c1e

79
main.go
View File

@@ -59,39 +59,86 @@ func fixKey(key string) string {
return hex.EncodeToString(decoded) return hex.EncodeToString(decoded)
} }
func ping(tnet *netstack.Net, dst string) { func ping(tnet *netstack.Net, dst string) error {
logger.Info("Pinging %s", dst) logger.Info("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst) socket, err := tnet.Dial("ping4", dst)
if err != nil { if err != nil {
logger.Error("Failed to create ICMP socket: %v", err) return fmt.Errorf("failed to create ICMP socket: %w", err)
} }
defer socket.Close()
requestPing := icmp.Echo{ requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16), Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"), Data: []byte("gopher burrow"),
} }
icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
socket.SetReadDeadline(time.Now().Add(time.Second * 10)) icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
if err != nil {
return fmt.Errorf("failed to marshal ICMP message: %w", err)
}
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
return fmt.Errorf("failed to set read deadline: %w", err)
}
start := time.Now() start := time.Now()
_, err = socket.Write(icmpBytes) _, err = socket.Write(icmpBytes)
if err != nil { if err != nil {
logger.Error("Failed to write ICMP packet: %v", err) return fmt.Errorf("failed to write ICMP packet: %w", err)
} }
n, err := socket.Read(icmpBytes[:]) n, err := socket.Read(icmpBytes[:])
if err != nil { if err != nil {
logger.Error("Failed to read ICMP packet: %v", err) return fmt.Errorf("failed to read ICMP packet: %w", err)
} }
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
if err != nil { if err != nil {
logger.Error("Failed to parse ICMP packet: %v", err) return fmt.Errorf("failed to parse ICMP packet: %w", err)
} }
replyPing, ok := replyPacket.Body.(*icmp.Echo) replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok { if !ok {
logger.Error("invalid reply type: %v", replyPacket) return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
} }
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
logger.Error("invalid ping reply: %v", replyPing) return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
} }
logger.Info("Ping latency: %v", time.Since(start)) logger.Info("Ping latency: %v", time.Since(start))
return nil
}
func pingWithRetry(tnet *netstack.Net, dst string) error {
const (
maxAttempts = 5
retryDelay = 2 * time.Second
)
var lastErr error
for attempt := 1; attempt <= maxAttempts; attempt++ {
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
if err := ping(tnet, dst); err != nil {
lastErr = err
logger.Warn("Ping attempt %d failed: %v", attempt, err)
if attempt < maxAttempts {
time.Sleep(retryDelay)
continue
}
return fmt.Errorf("all ping attempts failed after %d tries, last error: %w",
maxAttempts, lastErr)
}
// Successful ping
return nil
}
// This shouldn't be reached due to the return in the loop, but added for completeness
return fmt.Errorf("unexpected error: all ping attempts failed")
} }
func parseLogLevel(level string) logger.LogLevel { func parseLogLevel(level string) logger.LogLevel {
@@ -237,7 +284,12 @@ func main() {
if connected { if connected {
logger.Info("Already connected! Put I will send a ping anyway...") logger.Info("Already connected! Put I will send a ping anyway...")
ping(tnet, wgData.ServerIP) // ping(tnet, wgData.ServerIP)
err = pingWithRetry(tnet, wgData.ServerIP)
if err != nil {
// Handle complete failure after all retries
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
}
return return
} }
@@ -293,7 +345,12 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
logger.Info("WireGuard device created. Lets ping the server now...") logger.Info("WireGuard device created. Lets ping the server now...")
// Ping to bring the tunnel up on the server side quickly // Ping to bring the tunnel up on the server side quickly
ping(tnet, wgData.ServerIP) // ping(tnet, wgData.ServerIP)
err = pingWithRetry(tnet, wgData.ServerIP)
if err != nil {
// Handle complete failure after all retries
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
}
// Create proxy manager // Create proxy manager
pm = proxy.NewProxyManager(tnet) pm = proxy.NewProxyManager(tnet)