Just hp like olm

This commit is contained in:
Owen
2025-07-25 11:42:36 -07:00
parent f17dbe1fef
commit ccb7008579
3 changed files with 51 additions and 7 deletions

View File

@@ -56,6 +56,12 @@ func setupClients(client *websocket.Client) {
}) })
} }
func setDownstreamTNetstack(tnet *netstack.Net) {
if wgService != nil {
wgService.SetOthertnet(tnet)
}
}
func closeClients() { func closeClients() {
if wgService != nil { if wgService != nil {
wgService.Close(!keepInterface) wgService.Close(!keepInterface)

View File

@@ -343,6 +343,8 @@ func main() {
logger.Error("Failed to create TUN device: %v", err) logger.Error("Failed to create TUN device: %v", err)
} }
setDownstreamTNetstack(tnet)
// Create WireGuard device // Create WireGuard device
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger( dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
mapToWireGuardLogLevel(loggerLevel), mapToWireGuardLogLevel(loggerLevel),

View File

@@ -74,6 +74,7 @@ type WireGuardService struct {
dns []netip.Addr dns []netip.Addr
// Callback for when netstack is ready // Callback for when netstack is ready
onNetstackReady func(*netstack.Net) onNetstackReady func(*netstack.Net)
othertnet *netstack.Net
} }
// Add this type definition // Add this type definition
@@ -209,12 +210,19 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
return service, nil return service, nil
} }
func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) {
s.othertnet = tnet
}
func (s *WireGuardService) Close(rm bool) { func (s *WireGuardService) Close(rm bool) {
if s.stopGetConfig != nil { if s.stopGetConfig != nil {
s.stopGetConfig() s.stopGetConfig()
s.stopGetConfig = nil s.stopGetConfig = nil
} }
s.mu.Lock()
defer s.mu.Unlock()
// Close WireGuard device first - this will automatically close the TUN device // Close WireGuard device first - this will automatically close the TUN device
if s.device != nil { if s.device != nil {
s.device.Close() s.device.Close()
@@ -236,6 +244,9 @@ func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string)
logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint) logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint)
// Create a new stop channel for this holepunch session
s.stopHolepunch = make(chan struct{})
// start the UDP holepunch // start the UDP holepunch
go s.keepSendingUDPHolePunch(s.holePunchEndpoint) go s.keepSendingUDPHolePunch(s.holePunchEndpoint)
} }
@@ -246,11 +257,15 @@ func (s *WireGuardService) SetToken(token string) {
// GetNetstackNet returns the netstack network interface for use by other components // GetNetstackNet returns the netstack network interface for use by other components
func (s *WireGuardService) GetNetstackNet() *netstack.Net { func (s *WireGuardService) GetNetstackNet() *netstack.Net {
s.mu.Lock()
defer s.mu.Unlock()
return s.tnet return s.tnet
} }
// IsReady returns true if the WireGuard service is ready to use // IsReady returns true if the WireGuard service is ready to use
func (s *WireGuardService) IsReady() bool { func (s *WireGuardService) IsReady() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.device != nil && s.tnet != nil return s.device != nil && s.tnet != nil
} }
@@ -310,15 +325,23 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
} }
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.mu.Lock()
// split off the cidr from the IP address // split off the cidr from the IP address
parts := strings.Split(wgconfig.IpAddress, "/") parts := strings.Split(wgconfig.IpAddress, "/")
if len(parts) != 2 { if len(parts) != 2 {
s.mu.Unlock()
return fmt.Errorf("invalid IP address format: %s", wgconfig.IpAddress) return fmt.Errorf("invalid IP address format: %s", wgconfig.IpAddress)
} }
// Parse the IP address and CIDR mask // Parse the IP address and CIDR mask
tunnelIP := netip.MustParseAddr(parts[0]) tunnelIP := netip.MustParseAddr(parts[0])
// stop the holepunch its a channel
if s.stopHolepunch != nil {
close(s.stopHolepunch)
s.stopHolepunch = nil
}
// Parse the IP address from the config // Parse the IP address from the config
// tunnelIP := netip.MustParseAddr(wgconfig.IpAddress) // tunnelIP := netip.MustParseAddr(wgconfig.IpAddress)
@@ -329,6 +352,7 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.dns, s.dns,
s.mtu) s.mtu)
if err != nil { if err != nil {
s.mu.Unlock()
return fmt.Errorf("failed to create TUN device: %v", err) return fmt.Errorf("failed to create TUN device: %v", err)
} }
@@ -345,22 +369,32 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
err = s.device.IpcSet(config) err = s.device.IpcSet(config)
if err != nil { if err != nil {
s.mu.Unlock()
return fmt.Errorf("failed to configure WireGuard device: %v", err) return fmt.Errorf("failed to configure WireGuard device: %v", err)
} }
// Bring up the device // Bring up the device
err = s.device.Up() err = s.device.Up()
if err != nil { if err != nil {
s.mu.Unlock()
return fmt.Errorf("failed to bring up WireGuard device: %v", err) return fmt.Errorf("failed to bring up WireGuard device: %v", err)
} }
logger.Info("WireGuard netstack device created and configured") logger.Info("WireGuard netstack device created and configured")
// Store callback and tnet reference before releasing mutex
callback := s.onNetstackReady
tnet := s.tnet
// Release the mutex before calling the callback
s.mu.Unlock()
// Call the callback if it's set to notify that netstack is ready // Call the callback if it's set to notify that netstack is ready
if s.onNetstackReady != nil { if callback != nil {
s.onNetstackReady(s.tnet) callback(tnet)
} }
// Note: we already unlocked above, so don't use defer unlock
return nil return nil
} }
@@ -784,7 +818,7 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
// Create UDP connection bound to the same port as WireGuard // Create UDP connection bound to the same port as WireGuard
conn, err := net.DialUDP("udp", localAddr, remoteAddr) conn, err := net.DialUDP("udp", localAddr, remoteAddr)
if err != nil { if err != nil {
return fmt.Errorf("failed to create UDP connection: %v", err) return fmt.Errorf("failed to create netstack UDP connection: %v", err)
} }
defer conn.Close() defer conn.Close()
@@ -815,13 +849,13 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
return fmt.Errorf("failed to marshal encrypted payload: %v", err) return fmt.Errorf("failed to marshal encrypted payload: %v", err)
} }
// Send the encrypted packet using the UDP connection // Send the encrypted packet using the netstack UDP connection
_, err = conn.Write(jsonData) _, err = conn.Write(jsonData)
if err != nil { if err != nil {
return fmt.Errorf("failed to send UDP packet: %v", err) return fmt.Errorf("failed to send UDP packet: %v", err)
} }
logger.Debug("Sent UDP hole punch to %s from port %d", remoteAddr.String(), s.Port) logger.Debug("Sent UDP hole punch to %s via netstack", remoteAddr.String())
return nil return nil
} }
@@ -880,9 +914,11 @@ func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
} }
func (s *WireGuardService) keepSendingUDPHolePunch(host string) { func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
logger.Info("Starting UDP hole punch routine to %s:21820", host)
// send initial hole punch // send initial hole punch
if err := s.sendUDPHolePunch(host + ":21820"); err != nil { if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
logger.Error("Failed to send initial UDP hole punch: %v", err) logger.Debug("Failed to send initial UDP hole punch: %v", err)
} }
ticker := time.NewTicker(3 * time.Second) ticker := time.NewTicker(3 * time.Second)
@@ -895,7 +931,7 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
return return
case <-ticker.C: case <-ticker.C:
if err := s.sendUDPHolePunch(host + ":21820"); err != nil { if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err) logger.Debug("Failed to send UDP hole punch: %v", err)
} }
} }
} }