mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Merge pull request #28 from kevin-gillet/126-stop-litteral-ipv6-from-being-resolved
fix: holepunch to only active peers and stop litteral ipv6 from being name resolved
Former-commit-id: 5d42fac1d1
This commit is contained in:
193
main.go
193
main.go
@@ -10,6 +10,7 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -25,6 +26,34 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Helper function to format endpoints correctly
|
||||||
|
func formatEndpoint(endpoint string) string {
|
||||||
|
if endpoint == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
|
||||||
|
_, _, err := net.SplitHostPort(endpoint)
|
||||||
|
if err == nil {
|
||||||
|
return endpoint // Already valid, no change needed
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
|
||||||
|
lastColon := strings.LastIndex(endpoint, ":")
|
||||||
|
if lastColon > 0 { // Ensure there is a colon and it's not the first character
|
||||||
|
hostPart := endpoint[:lastColon]
|
||||||
|
// Check if the host part is a literal IPv6 address
|
||||||
|
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
|
||||||
|
// It is! Reformat it with brackets.
|
||||||
|
portPart := endpoint[lastColon+1:]
|
||||||
|
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's not the specific malformed case, return it as is.
|
||||||
|
return endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Check if we're running as a Windows service
|
// Check if we're running as a Windows service
|
||||||
if isWindowsService() {
|
if isWindowsService() {
|
||||||
@@ -498,29 +527,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
|
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
|
||||||
})
|
})
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
|
||||||
logger.Debug("Received message: %v", msg.Data)
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new stopHolepunch channel for the new set of goroutines
|
|
||||||
stopHolepunch = make(chan struct{})
|
|
||||||
|
|
||||||
// Start a single hole punch goroutine for all exit nodes
|
|
||||||
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
|
|
||||||
go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Register handlers for different message types
|
|
||||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received message: %v", msg.Data)
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
@@ -558,9 +564,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tdev, err = func() (tun.Device, error) {
|
tdev, err = func() (tun.Device, error) {
|
||||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
|
||||||
|
|
||||||
// if on macOS, call findUnusedUTUN to get a new utun device
|
|
||||||
if runtime.GOOS == "darwin" {
|
if runtime.GOOS == "darwin" {
|
||||||
interfaceName, err := findUnusedUTUN()
|
interfaceName, err := findUnusedUTUN()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -568,12 +571,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
}
|
}
|
||||||
return tun.CreateTUN(interfaceName, mtuInt)
|
return tun.CreateTUN(interfaceName, mtuInt)
|
||||||
}
|
}
|
||||||
|
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
|
||||||
if tunFdStr == "" {
|
return createTUNFromFD(tunFdStr, mtuInt)
|
||||||
return tun.CreateTUN(interfaceName, mtuInt)
|
|
||||||
}
|
}
|
||||||
|
return tun.CreateTUN(interfaceName, mtuInt)
|
||||||
return createTUNFromFD(tunFdStr, mtuInt)
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -581,75 +582,37 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
realInterfaceName, err2 := tdev.Name()
|
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||||
if err2 == nil {
|
|
||||||
interfaceName = realInterfaceName
|
interfaceName = realInterfaceName
|
||||||
}
|
}
|
||||||
|
|
||||||
// open UAPI file (or use supplied fd)
|
|
||||||
fileUAPI, err := func() (*os.File, error) {
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
|
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
|
||||||
if uapiFdStr == "" {
|
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
||||||
return uapiOpen(interfaceName)
|
if err != nil { return nil, err }
|
||||||
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
}
|
}
|
||||||
|
return uapiOpen(interfaceName)
|
||||||
// use supplied fd
|
|
||||||
|
|
||||||
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return os.NewFile(uintptr(fd), ""), nil
|
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil { logger.Error("UAPI listen error: %v", err); os.Exit(1); return }
|
||||||
logger.Error("UAPI listen error: %v", err)
|
|
||||||
os.Exit(1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(
|
|
||||||
mapToWireGuardLogLevel(loggerLevel),
|
|
||||||
"wireguard: ",
|
|
||||||
))
|
|
||||||
|
|
||||||
errs := make(chan error)
|
|
||||||
|
|
||||||
|
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
|
|
||||||
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
if err != nil {
|
if err != nil { logger.Error("Failed to listen on uapi socket: %v", err); os.Exit(1) }
|
||||||
logger.Error("Failed to listen on uapi socket: %v", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
conn, err := uapiListener.Accept()
|
conn, err := uapiListener.Accept()
|
||||||
if err != nil {
|
if err != nil { return }
|
||||||
errs <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go dev.IpcHandle(conn)
|
go dev.IpcHandle(conn)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logger.Info("UAPI listener started")
|
logger.Info("UAPI listener started")
|
||||||
|
|
||||||
// Bring up the device
|
if err = dev.Up(); err != nil { logger.Error("Failed to bring up WireGuard device: %v", err) }
|
||||||
err = dev.Up()
|
if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) }
|
||||||
if err != nil {
|
if httpServer != nil { httpServer.SetTunnelIP(wgData.TunnelIP) }
|
||||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// configure the interface
|
|
||||||
err = ConfigureInterface(realInterfaceName, wgData)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to configure interface: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set tunnel IP in HTTP server
|
|
||||||
if httpServer != nil {
|
|
||||||
httpServer.SetTunnelIP(wgData.TunnelIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerMonitor = peermonitor.NewPeerMonitor(
|
peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
func(siteID int, connected bool, rtt time.Duration) {
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
@@ -680,28 +643,18 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
doHolepunch,
|
doHolepunch,
|
||||||
)
|
)
|
||||||
|
|
||||||
// loop over the sites and call ConfigurePeer for each one
|
for i := range wgData.Sites {
|
||||||
for _, site := range wgData.Sites {
|
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
|
||||||
if httpServer != nil {
|
if httpServer != nil {
|
||||||
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
|
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
|
||||||
}
|
}
|
||||||
err = ConfigurePeer(dev, site, privateKey, endpoint)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to configure peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = addRouteForServerIP(site.ServerIP, interfaceName)
|
// Format the endpoint before configuring the peer.
|
||||||
if err != nil {
|
site.Endpoint = formatEndpoint(site.Endpoint)
|
||||||
logger.Error("Failed to add route for peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add routes for remote subnets
|
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { logger.Error("Failed to configure peer: %v", err); return }
|
||||||
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for peer: %v", err); return }
|
||||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return }
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Configured peer %s", site.PublicKey)
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
}
|
}
|
||||||
@@ -748,12 +701,11 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
// Format the endpoint before updating the peer.
|
||||||
logger.Error("Failed to update peer: %v", err)
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
// Send error response if needed
|
|
||||||
return
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to update peer: %v", err); return }
|
||||||
}
|
|
||||||
|
|
||||||
// Remove old remote subnet routes if they changed
|
// Remove old remote subnet routes if they changed
|
||||||
if oldRemoteSubnets != siteConfig.RemoteSubnets {
|
if oldRemoteSubnets != siteConfig.RemoteSubnets {
|
||||||
@@ -771,12 +723,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
|
|
||||||
// Update successful
|
// Update successful
|
||||||
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||||
// If this is part of a WgData structure, update it
|
for i := range wgData.Sites {
|
||||||
for i, site := range wgData.Sites {
|
if wgData.Sites[i].SiteId == updateData.SiteId { wgData.Sites[i] = siteConfig; break }
|
||||||
if site.SiteId == updateData.SiteId {
|
|
||||||
wgData.Sites[i] = siteConfig
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.Error("WireGuard device not initialized")
|
logger.Error("WireGuard device not initialized")
|
||||||
@@ -811,23 +759,12 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
|
|
||||||
// Add the peer to WireGuard
|
// Add the peer to WireGuard
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
// Format the endpoint before adding the new peer.
|
||||||
logger.Error("Failed to add peer: %v", err)
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add route for the new peer
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to add peer: %v", err); return }
|
||||||
err = addRouteForServerIP(siteConfig.ServerIP, interfaceName)
|
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for new peer: %v", err); return }
|
||||||
if err != nil {
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return }
|
||||||
logger.Error("Failed to add route for new peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add routes for remote subnets
|
|
||||||
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
|
||||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add successful
|
// Add successful
|
||||||
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
||||||
|
|||||||
Reference in New Issue
Block a user