Merge branch 'main' into dev

Former-commit-id: 42bd8d5d4c
This commit is contained in:
Owen
2025-09-26 09:38:40 -07:00
6 changed files with 84 additions and 140 deletions

View File

@@ -31,7 +31,7 @@ jobs:
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v6
with: with:
go-version: 1.25 go-version: 1.25

View File

@@ -14,7 +14,7 @@ jobs:
- uses: actions/checkout@v5 - uses: actions/checkout@v5
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v6
with: with:
go-version: 1.25 go-version: 1.25

6
go.mod
View File

@@ -5,9 +5,9 @@ go 1.25
require ( require (
github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.41.0 golang.org/x/crypto v0.42.0
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
golang.org/x/sys v0.35.0 golang.org/x/sys v0.36.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
) )
@@ -15,7 +15,7 @@ require (
require ( require (
github.com/gorilla/websocket v1.5.3 // indirect github.com/gorilla/websocket v1.5.3 // indirect
github.com/vishvananda/netns v0.0.5 // indirect github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/net v0.42.0 // indirect golang.org/x/net v0.43.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect

12
go.sum
View File

@@ -10,16 +10,16 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4= golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=

193
main.go
View File

@@ -10,6 +10,7 @@ import (
"os/signal" "os/signal"
"runtime" "runtime"
"strconv" "strconv"
"strings"
"syscall" "syscall"
"time" "time"
@@ -25,6 +26,34 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
return ""
}
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint // Already valid, no change needed
}
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
lastColon := strings.LastIndex(endpoint, ":")
if lastColon > 0 { // Ensure there is a colon and it's not the first character
hostPart := endpoint[:lastColon]
// Check if the host part is a literal IPv6 address
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
// It is! Reformat it with brackets.
portPart := endpoint[lastColon+1:]
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
}
}
// If it's not the specific malformed case, return it as is.
return endpoint
}
func main() { func main() {
// Check if we're running as a Windows service // Check if we're running as a Windows service
if isWindowsService() { if isWindowsService() {
@@ -507,29 +536,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey) go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
}) })
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})
// Start a single hole punch goroutine for all exit nodes
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort)
})
// Register handlers for different message types
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data) logger.Debug("Received message: %v", msg.Data)
@@ -567,9 +573,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
} }
tdev, err = func() (tun.Device, error) { tdev, err = func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
// if on macOS, call findUnusedUTUN to get a new utun device
if runtime.GOOS == "darwin" { if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN() interfaceName, err := findUnusedUTUN()
if err != nil { if err != nil {
@@ -577,12 +580,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
} }
return tun.CreateTUN(interfaceName, mtuInt) return tun.CreateTUN(interfaceName, mtuInt)
} }
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
if tunFdStr == "" { return createTUNFromFD(tunFdStr, mtuInt)
return tun.CreateTUN(interfaceName, mtuInt)
} }
return tun.CreateTUN(interfaceName, mtuInt)
return createTUNFromFD(tunFdStr, mtuInt)
}() }()
if err != nil { if err != nil {
@@ -590,75 +591,37 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
return return
} }
realInterfaceName, err2 := tdev.Name() if realInterfaceName, err2 := tdev.Name(); err2 == nil {
if err2 == nil {
interfaceName = realInterfaceName interfaceName = realInterfaceName
} }
// open UAPI file (or use supplied fd)
fileUAPI, err := func() (*os.File, error) { fileUAPI, err := func() (*os.File, error) {
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
if uapiFdStr == "" { fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
return uapiOpen(interfaceName) if err != nil { return nil, err }
return os.NewFile(uintptr(fd), ""), nil
} }
return uapiOpen(interfaceName)
// use supplied fd
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}() }()
if err != nil { if err != nil { logger.Error("UAPI listen error: %v", err); os.Exit(1); return }
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(
mapToWireGuardLogLevel(loggerLevel),
"wireguard: ",
))
errs := make(chan error)
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
uapiListener, err = uapiListen(interfaceName, fileUAPI) uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil { if err != nil { logger.Error("Failed to listen on uapi socket: %v", err); os.Exit(1) }
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() { go func() {
for { for {
conn, err := uapiListener.Accept() conn, err := uapiListener.Accept()
if err != nil { if err != nil { return }
errs <- err
return
}
go dev.IpcHandle(conn) go dev.IpcHandle(conn)
} }
}() }()
logger.Info("UAPI listener started") logger.Info("UAPI listener started")
// Bring up the device if err = dev.Up(); err != nil { logger.Error("Failed to bring up WireGuard device: %v", err) }
err = dev.Up() if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) }
if err != nil { if httpServer != nil { httpServer.SetTunnelIP(wgData.TunnelIP) }
logger.Error("Failed to bring up WireGuard device: %v", err)
}
// configure the interface
err = ConfigureInterface(realInterfaceName, wgData)
if err != nil {
logger.Error("Failed to configure interface: %v", err)
}
// Set tunnel IP in HTTP server
if httpServer != nil {
httpServer.SetTunnelIP(wgData.TunnelIP)
}
peerMonitor = peermonitor.NewPeerMonitor( peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) { func(siteID int, connected bool, rtt time.Duration) {
@@ -689,28 +652,18 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
doHolepunch, doHolepunch,
) )
// loop over the sites and call ConfigurePeer for each one for i := range wgData.Sites {
for _, site := range wgData.Sites { site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if httpServer != nil { if httpServer != nil {
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
} }
err = ConfigurePeer(dev, site, privateKey, endpoint)
if err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
err = addRouteForServerIP(site.ServerIP, interfaceName) // Format the endpoint before configuring the peer.
if err != nil { site.Endpoint = formatEndpoint(site.Endpoint)
logger.Error("Failed to add route for peer: %v", err)
return
}
// Add routes for remote subnets if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { logger.Error("Failed to configure peer: %v", err); return }
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for peer: %v", err); return }
logger.Error("Failed to add routes for remote subnets: %v", err) if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return }
return
}
logger.Info("Configured peer %s", site.PublicKey) logger.Info("Configured peer %s", site.PublicKey)
} }
@@ -757,12 +710,11 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
break break
} }
} }
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { // Format the endpoint before updating the peer.
logger.Error("Failed to update peer: %v", err) siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
// Send error response if needed
return if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to update peer: %v", err); return }
}
// Remove old remote subnet routes if they changed // Remove old remote subnet routes if they changed
if oldRemoteSubnets != siteConfig.RemoteSubnets { if oldRemoteSubnets != siteConfig.RemoteSubnets {
@@ -780,12 +732,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
// Update successful // Update successful
logger.Info("Successfully updated peer for site %d", updateData.SiteId) logger.Info("Successfully updated peer for site %d", updateData.SiteId)
// If this is part of a WgData structure, update it for i := range wgData.Sites {
for i, site := range wgData.Sites { if wgData.Sites[i].SiteId == updateData.SiteId { wgData.Sites[i] = siteConfig; break }
if site.SiteId == updateData.SiteId {
wgData.Sites[i] = siteConfig
break
}
} }
} else { } else {
logger.Error("WireGuard device not initialized") logger.Error("WireGuard device not initialized")
@@ -820,23 +768,12 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
// Add the peer to WireGuard // Add the peer to WireGuard
if dev != nil { if dev != nil {
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { // Format the endpoint before adding the new peer.
logger.Error("Failed to add peer: %v", err) siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
return
}
// Add route for the new peer if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to add peer: %v", err); return }
err = addRouteForServerIP(siteConfig.ServerIP, interfaceName) if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for new peer: %v", err); return }
if err != nil { if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return }
logger.Error("Failed to add route for new peer: %v", err)
return
}
// Add routes for remote subnets
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
// Add successful // Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId) logger.Info("Successfully added peer for site %d", addData.SiteId)

View File

@@ -3,6 +3,7 @@ package peermonitor
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync" "sync"
"time" "time"
@@ -204,12 +205,18 @@ func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
return return
} }
// Check for IPv6 and format the endpoint correctly
formattedEndpoint := relayEndpoint
if strings.Contains(relayEndpoint, ":") {
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
}
// Configure WireGuard to use the relay // Configure WireGuard to use the relay
wgConfig := fmt.Sprintf(`private_key=%s wgConfig := fmt.Sprintf(`private_key=%s
public_key=%s public_key=%s
allowed_ip=%s/32 allowed_ip=%s/32
endpoint=%s:21820 endpoint=%s:21820
persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, relayEndpoint) persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, formattedEndpoint)
err := pm.device.IpcSet(wgConfig) err := pm.device.IpcSet(wgConfig)
if err != nil { if err != nil {