diff --git a/linux.go b/linux.go index 6ee3e72..bba88a7 100644 --- a/linux.go +++ b/linux.go @@ -55,11 +55,20 @@ func closeClients() { } } -func clientsHandleNewtConnection(publicKey string) { +func clientsHandleNewtConnection(publicKey string, endpoint string) { if wgService == nil { return } - wgService.SetServerPubKey(publicKey) + + // split off the port from the endpoint + parts := strings.Split(endpoint, ":") + if len(parts) < 2 { + logger.Error("Invalid endpoint format: %s", endpoint) + return + } + endpoint = strings.Join(parts[:len(parts)-1], ":") + + wgService.StartHolepunch(publicKey, endpoint) } func clientsOnConnect() { diff --git a/main.go b/main.go index 95d3122..b7e2555 100644 --- a/main.go +++ b/main.go @@ -334,8 +334,6 @@ func main() { return } - clientsHandleNewtConnection(wgData.PublicKey) - logger.Debug("Received: %+v", msg) tun, tnet, err = netstack.CreateNetTUN( []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)}, @@ -365,6 +363,8 @@ func main() { return } + clientsHandleNewtConnection(wgData.PublicKey, endpoint) + // Configure WireGuard config := fmt.Sprintf(`private_key=%s public_key=%s @@ -578,30 +578,34 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub WasPreviouslyConnected: res.Node.WasPreviouslyConnected, }) } + // If we were previously connected and there is at least one other good node, - // exclude the previously connected node from pingResults sent to the cloud. - var filteredPingResults []ExitNodePingResult - previouslyConnectedNodeIdx := -1 - for i, res := range pingResults { - if res.WasPreviouslyConnected { - previouslyConnectedNodeIdx = i - } - } - // Count good nodes (latency > 0, no error, not previously connected) - goodNodeCount := 0 - for i, res := range pingResults { - if i != previouslyConnectedNodeIdx && res.LatencyMs > 0 && res.Error == "" { - goodNodeCount++ - } - } - if previouslyConnectedNodeIdx != -1 && goodNodeCount > 0 { + // exclude the previously connected node from pingResults sent to the cloud so we don't try to reconnect to it + // This is to avoid issues where the previously connected node might be down or unreachable + if connected { + var filteredPingResults []ExitNodePingResult + previouslyConnectedNodeIdx := -1 for i, res := range pingResults { - if i != previouslyConnectedNodeIdx { - filteredPingResults = append(filteredPingResults, res) + if res.WasPreviouslyConnected { + previouslyConnectedNodeIdx = i } } - pingResults = filteredPingResults - logger.Info("Excluding previously connected exit node from ping results due to other available nodes") + // Count good nodes (latency > 0, no error, not previously connected) + goodNodeCount := 0 + for i, res := range pingResults { + if i != previouslyConnectedNodeIdx && res.LatencyMs > 0 && res.Error == "" { + goodNodeCount++ + } + } + if previouslyConnectedNodeIdx != -1 && goodNodeCount > 0 { + for i, res := range pingResults { + if i != previouslyConnectedNodeIdx { + filteredPingResults = append(filteredPingResults, res) + } + } + pingResults = filteredPingResults + logger.Info("Excluding previously connected exit node from ping results due to other available nodes") + } } // Send the ping results to the cloud for selection diff --git a/wg/wg.go b/wg/wg.go index 5bf68e7..a09cab8 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -49,21 +49,22 @@ type PeerReading struct { } type WireGuardService struct { - interfaceName string - mtu int - client *websocket.Client - wgClient *wgctrl.Client - config WgConfig - key wgtypes.Key - newtId string - lastReadings map[string]PeerReading - mu sync.Mutex - Port uint16 - stopHolepunch chan struct{} - host string - serverPubKey string - token string - stopGetConfig func() + interfaceName string + mtu int + client *websocket.Client + wgClient *wgctrl.Client + config WgConfig + key wgtypes.Key + newtId string + lastReadings map[string]PeerReading + mu sync.Mutex + Port uint16 + stopHolepunch chan struct{} + host string + serverPubKey string + holePunchEndpoint string + token string + stopGetConfig func() } // Add this type definition @@ -211,13 +212,6 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer) - if err := service.sendUDPHolePunch(service.host + ":21820"); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - } - - // start the UDP holepunch - go service.keepSendingUDPHolePunch(service.host) - return service, nil } @@ -241,8 +235,14 @@ func (s *WireGuardService) Close(rm bool) { } } -func (s *WireGuardService) SetServerPubKey(serverPubKey string) { +func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) { s.serverPubKey = serverPubKey + s.holePunchEndpoint = endpoint + + logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint) + + // start the UDP holepunch + go s.keepSendingUDPHolePunch(s.holePunchEndpoint) } func (s *WireGuardService) SetToken(token string) { @@ -926,6 +926,11 @@ func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) { } func (s *WireGuardService) keepSendingUDPHolePunch(host string) { + // send initial hole punch + if err := s.sendUDPHolePunch(host + ":21820"); err != nil { + logger.Error("Failed to send initial UDP hole punch: %v", err) + } + ticker := time.NewTicker(3 * time.Second) defer ticker.Stop()