diff --git a/linux.go b/linux.go index 99f5e2e..ed7846c 100644 --- a/linux.go +++ b/linux.go @@ -56,6 +56,12 @@ func setupClients(client *websocket.Client) { }) } +func setDownstreamTNetstack(tnet *netstack.Net) { + if wgService != nil { + wgService.SetOthertnet(tnet) + } +} + func closeClients() { if wgService != nil { wgService.Close(!keepInterface) diff --git a/main.go b/main.go index 7557d6e..8df5402 100644 --- a/main.go +++ b/main.go @@ -343,6 +343,8 @@ func main() { logger.Error("Failed to create TUN device: %v", err) } + setDownstreamTNetstack(tnet) + // Create WireGuard device dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger( mapToWireGuardLogLevel(loggerLevel), diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go index 333ebec..5c3410a 100644 --- a/wgnetstack/wgnetstack.go +++ b/wgnetstack/wgnetstack.go @@ -74,6 +74,7 @@ type WireGuardService struct { dns []netip.Addr // Callback for when netstack is ready onNetstackReady func(*netstack.Net) + othertnet *netstack.Net } // Add this type definition @@ -209,12 +210,19 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str return service, nil } +func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) { + s.othertnet = tnet +} + func (s *WireGuardService) Close(rm bool) { if s.stopGetConfig != nil { s.stopGetConfig() s.stopGetConfig = nil } + s.mu.Lock() + defer s.mu.Unlock() + // Close WireGuard device first - this will automatically close the TUN device if s.device != nil { s.device.Close() @@ -236,6 +244,9 @@ func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint) + // Create a new stop channel for this holepunch session + s.stopHolepunch = make(chan struct{}) + // start the UDP holepunch go s.keepSendingUDPHolePunch(s.holePunchEndpoint) } @@ -246,11 +257,15 @@ func (s *WireGuardService) SetToken(token string) { // GetNetstackNet returns the netstack network interface for use by other components func (s *WireGuardService) GetNetstackNet() *netstack.Net { + s.mu.Lock() + defer s.mu.Unlock() return s.tnet } // IsReady returns true if the WireGuard service is ready to use func (s *WireGuardService) IsReady() bool { + s.mu.Lock() + defer s.mu.Unlock() return s.device != nil && s.tnet != nil } @@ -310,15 +325,23 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { } func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { + s.mu.Lock() // split off the cidr from the IP address parts := strings.Split(wgconfig.IpAddress, "/") if len(parts) != 2 { + s.mu.Unlock() return fmt.Errorf("invalid IP address format: %s", wgconfig.IpAddress) } // Parse the IP address and CIDR mask tunnelIP := netip.MustParseAddr(parts[0]) + // stop the holepunch its a channel + if s.stopHolepunch != nil { + close(s.stopHolepunch) + s.stopHolepunch = nil + } + // Parse the IP address from the config // tunnelIP := netip.MustParseAddr(wgconfig.IpAddress) @@ -329,6 +352,7 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { s.dns, s.mtu) if err != nil { + s.mu.Unlock() return fmt.Errorf("failed to create TUN device: %v", err) } @@ -345,22 +369,32 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { err = s.device.IpcSet(config) if err != nil { + s.mu.Unlock() return fmt.Errorf("failed to configure WireGuard device: %v", err) } // Bring up the device err = s.device.Up() if err != nil { + s.mu.Unlock() return fmt.Errorf("failed to bring up WireGuard device: %v", err) } logger.Info("WireGuard netstack device created and configured") + // Store callback and tnet reference before releasing mutex + callback := s.onNetstackReady + tnet := s.tnet + + // Release the mutex before calling the callback + s.mu.Unlock() + // Call the callback if it's set to notify that netstack is ready - if s.onNetstackReady != nil { - s.onNetstackReady(s.tnet) + if callback != nil { + callback(tnet) } + // Note: we already unlocked above, so don't use defer unlock return nil } @@ -784,7 +818,7 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { // Create UDP connection bound to the same port as WireGuard conn, err := net.DialUDP("udp", localAddr, remoteAddr) if err != nil { - return fmt.Errorf("failed to create UDP connection: %v", err) + return fmt.Errorf("failed to create netstack UDP connection: %v", err) } defer conn.Close() @@ -815,13 +849,13 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { return fmt.Errorf("failed to marshal encrypted payload: %v", err) } - // Send the encrypted packet using the UDP connection + // Send the encrypted packet using the netstack UDP connection _, err = conn.Write(jsonData) if err != nil { return fmt.Errorf("failed to send UDP packet: %v", err) } - logger.Debug("Sent UDP hole punch to %s from port %d", remoteAddr.String(), s.Port) + logger.Debug("Sent UDP hole punch to %s via netstack", remoteAddr.String()) return nil } @@ -880,9 +914,11 @@ func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) { } func (s *WireGuardService) keepSendingUDPHolePunch(host string) { + logger.Info("Starting UDP hole punch routine to %s:21820", host) + // send initial hole punch if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Error("Failed to send initial UDP hole punch: %v", err) + logger.Debug("Failed to send initial UDP hole punch: %v", err) } ticker := time.NewTicker(3 * time.Second) @@ -895,7 +931,7 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) { return case <-ticker.C: if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) + logger.Debug("Failed to send UDP hole punch: %v", err) } } }