mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
911 lines
24 KiB
Go
911 lines
24 KiB
Go
package olm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"runtime"
|
|
"time"
|
|
|
|
"github.com/fosrl/newt/bind"
|
|
"github.com/fosrl/newt/holepunch"
|
|
"github.com/fosrl/newt/logger"
|
|
"github.com/fosrl/newt/util"
|
|
"github.com/fosrl/olm/api"
|
|
"github.com/fosrl/olm/network"
|
|
"github.com/fosrl/olm/peermonitor"
|
|
"github.com/fosrl/olm/websocket"
|
|
"golang.zx2c4.com/wireguard/device"
|
|
"golang.zx2c4.com/wireguard/tun"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
type GlobalConfig struct {
|
|
// Logging
|
|
LogLevel string
|
|
|
|
// HTTP server
|
|
EnableAPI bool
|
|
HTTPAddr string
|
|
SocketPath string
|
|
Version string
|
|
|
|
// Source tracking (not in JSON)
|
|
sources map[string]string
|
|
}
|
|
|
|
type TunnelConfig struct {
|
|
// Connection settings
|
|
Endpoint string
|
|
ID string
|
|
Secret string
|
|
UserToken string
|
|
|
|
// Network settings
|
|
MTU int
|
|
DNS string
|
|
InterfaceName string
|
|
|
|
// Advanced
|
|
Holepunch bool
|
|
TlsClientCert string
|
|
|
|
// Parsed values (not in JSON)
|
|
PingIntervalDuration time.Duration
|
|
PingTimeoutDuration time.Duration
|
|
|
|
OrgID string
|
|
// DoNotCreateNewClient bool
|
|
|
|
FileDescriptorTun uint32
|
|
FileDescriptorUAPI uint32
|
|
}
|
|
|
|
var (
|
|
privateKey wgtypes.Key
|
|
connected bool
|
|
dev *device.Device
|
|
wgData WgData
|
|
holePunchData HolePunchData
|
|
uapiListener net.Listener
|
|
tdev tun.Device
|
|
apiServer *api.API
|
|
olmClient *websocket.Client
|
|
tunnelCancel context.CancelFunc
|
|
tunnelRunning bool
|
|
sharedBind *bind.SharedBind
|
|
holePunchManager *holepunch.Manager
|
|
peerMonitor *peermonitor.PeerMonitor
|
|
globalConfig GlobalConfig
|
|
globalCtx context.Context
|
|
stopRegister func()
|
|
stopPing chan struct{}
|
|
)
|
|
|
|
func Init(ctx context.Context, config GlobalConfig) {
|
|
globalConfig = config
|
|
globalCtx = ctx
|
|
|
|
// Create a cancellable context for internal shutdown control
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
|
|
|
if config.HTTPAddr != "" {
|
|
apiServer = api.NewAPI(config.HTTPAddr)
|
|
} else if config.SocketPath != "" {
|
|
apiServer = api.NewAPISocket(config.SocketPath)
|
|
}
|
|
|
|
apiServer.SetVersion(config.Version)
|
|
|
|
if err := apiServer.Start(); err != nil {
|
|
logger.Fatal("Failed to start HTTP server: %v", err)
|
|
}
|
|
|
|
// Set up API handlers
|
|
apiServer.SetHandlers(
|
|
// onConnect
|
|
func(req api.ConnectionRequest) error {
|
|
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
|
|
|
// Stop any existing tunnel before starting a new one
|
|
if olmClient != nil {
|
|
logger.Info("Stopping existing tunnel before starting new connection")
|
|
StopTunnel()
|
|
}
|
|
|
|
tunnelConfig := TunnelConfig{
|
|
Endpoint: req.Endpoint,
|
|
ID: req.ID,
|
|
Secret: req.Secret,
|
|
UserToken: req.UserToken,
|
|
MTU: req.MTU,
|
|
DNS: req.DNS,
|
|
InterfaceName: req.InterfaceName,
|
|
Holepunch: req.Holepunch,
|
|
TlsClientCert: req.TlsClientCert,
|
|
OrgID: req.OrgID,
|
|
}
|
|
|
|
var err error
|
|
// Parse ping interval
|
|
if req.PingInterval != "" {
|
|
tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval)
|
|
if err != nil {
|
|
logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval)
|
|
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
|
}
|
|
} else {
|
|
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
|
}
|
|
// Parse ping timeout
|
|
if req.PingTimeout != "" {
|
|
tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout)
|
|
if err != nil {
|
|
logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout)
|
|
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
|
}
|
|
} else {
|
|
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
|
}
|
|
if req.MTU == 0 {
|
|
tunnelConfig.MTU = 1420
|
|
}
|
|
if req.DNS == "" {
|
|
tunnelConfig.DNS = "9.9.9.9"
|
|
}
|
|
if req.InterfaceName == "" {
|
|
tunnelConfig.InterfaceName = "olm"
|
|
}
|
|
|
|
// Start the tunnel process with the new credentials
|
|
if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" {
|
|
logger.Info("Starting tunnel with new credentials")
|
|
go StartTunnel(tunnelConfig)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
// onSwitchOrg
|
|
func(req api.SwitchOrgRequest) error {
|
|
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
|
|
|
// Ensure we have an active olmClient
|
|
if olmClient == nil {
|
|
return fmt.Errorf("no active connection to switch organizations")
|
|
}
|
|
|
|
// Update the orgID in the API server
|
|
apiServer.SetOrgID(req.OrgID)
|
|
|
|
// Mark as not connected to trigger re-registration
|
|
connected = false
|
|
|
|
Close()
|
|
|
|
// Clear peer statuses in API
|
|
apiServer.SetRegistered(false)
|
|
apiServer.SetTunnelIP("")
|
|
|
|
// Trigger re-registration with new orgId
|
|
logger.Info("Re-registering with new orgId: %s", req.OrgID)
|
|
publicKey := privateKey.PublicKey()
|
|
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
"publicKey": publicKey.String(),
|
|
"relay": true, // Default to relay mode for org switch
|
|
"olmVersion": globalConfig.Version,
|
|
"orgId": req.OrgID,
|
|
}, 1*time.Second)
|
|
|
|
return nil
|
|
},
|
|
// onDisconnect
|
|
func() error {
|
|
logger.Info("Processing disconnect request via API")
|
|
StopTunnel()
|
|
return nil
|
|
},
|
|
// onExit
|
|
func() error {
|
|
logger.Info("Processing shutdown request via API")
|
|
cancel()
|
|
return nil
|
|
},
|
|
)
|
|
}
|
|
|
|
func StartTunnel(config TunnelConfig) {
|
|
if tunnelRunning {
|
|
logger.Info("Tunnel already running")
|
|
return
|
|
}
|
|
|
|
tunnelRunning = true // Also set it here in case it is called externally
|
|
|
|
if config.Holepunch {
|
|
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
|
}
|
|
|
|
// Create a cancellable context for this tunnel process
|
|
tunnelCtx, cancel := context.WithCancel(globalCtx)
|
|
tunnelCancel = cancel
|
|
defer func() {
|
|
tunnelCancel = nil
|
|
}()
|
|
|
|
// Recreate channels for this tunnel session
|
|
stopPing = make(chan struct{})
|
|
|
|
var (
|
|
interfaceName = config.InterfaceName
|
|
id = config.ID
|
|
secret = config.Secret
|
|
endpoint = config.Endpoint
|
|
userToken = config.UserToken
|
|
)
|
|
|
|
apiServer.SetOrgID(config.OrgID)
|
|
|
|
// Create a new olm client using the provided credentials
|
|
olm, err := websocket.NewClient(
|
|
id, // Use provided ID
|
|
secret, // Use provided secret
|
|
userToken, // Use provided user token OPTIONAL
|
|
endpoint, // Use provided endpoint
|
|
config.PingIntervalDuration,
|
|
config.PingTimeoutDuration,
|
|
)
|
|
if err != nil {
|
|
logger.Error("Failed to create olm: %v", err)
|
|
return
|
|
}
|
|
|
|
// Store the client reference globally
|
|
olmClient = olm
|
|
|
|
privateKey, err = wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
logger.Error("Failed to generate private key: %v", err)
|
|
return
|
|
}
|
|
|
|
// Create shared UDP socket for both holepunch and WireGuard
|
|
if sharedBind == nil {
|
|
sourcePort, err := util.FindAvailableUDPPort(49152, 65535)
|
|
if err != nil {
|
|
logger.Error("Error finding available port: %v", err)
|
|
return
|
|
}
|
|
|
|
localAddr := &net.UDPAddr{
|
|
Port: int(sourcePort),
|
|
IP: net.IPv4zero,
|
|
}
|
|
|
|
udpConn, err := net.ListenUDP("udp", localAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to create shared UDP socket: %v", err)
|
|
return
|
|
}
|
|
|
|
sharedBind, err = bind.New(udpConn)
|
|
if err != nil {
|
|
logger.Error("Failed to create shared bind: %v", err)
|
|
udpConn.Close()
|
|
return
|
|
}
|
|
|
|
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
|
|
sharedBind.AddRef()
|
|
|
|
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
|
}
|
|
|
|
// Create the holepunch manager
|
|
if holePunchManager == nil {
|
|
holePunchManager = holepunch.NewManager(sharedBind, id, "olm")
|
|
}
|
|
|
|
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
|
logger.Debug("Received message: %v", msg.Data)
|
|
|
|
jsonData, err := json.Marshal(msg.Data)
|
|
if err != nil {
|
|
logger.Info("Error marshaling data: %v", err)
|
|
return
|
|
}
|
|
|
|
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
|
|
logger.Info("Error unmarshaling target data: %v", err)
|
|
return
|
|
}
|
|
|
|
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
|
|
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
|
|
for i, node := range holePunchData.ExitNodes {
|
|
exitNodes[i] = holepunch.ExitNode{
|
|
Endpoint: node.Endpoint,
|
|
PublicKey: node.PublicKey,
|
|
}
|
|
}
|
|
|
|
// Start hole punching using the manager
|
|
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
|
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
|
|
logger.Warn("Failed to start hole punch: %v", err)
|
|
}
|
|
})
|
|
|
|
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
|
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
|
|
logger.Debug("Received message: %v", msg.Data)
|
|
|
|
type LegacyHolePunchData struct {
|
|
ServerPubKey string `json:"serverPubKey"`
|
|
Endpoint string `json:"endpoint"`
|
|
}
|
|
|
|
var legacyHolePunchData LegacyHolePunchData
|
|
|
|
jsonData, err := json.Marshal(msg.Data)
|
|
if err != nil {
|
|
logger.Info("Error marshaling data: %v", err)
|
|
return
|
|
}
|
|
|
|
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
|
|
logger.Info("Error unmarshaling target data: %v", err)
|
|
return
|
|
}
|
|
|
|
// Stop any existing hole punch operations
|
|
if holePunchManager != nil {
|
|
holePunchManager.Stop()
|
|
}
|
|
|
|
// Start hole punching for the exit node
|
|
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
|
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
|
|
logger.Warn("Failed to start hole punch: %v", err)
|
|
}
|
|
})
|
|
|
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
|
logger.Debug("Received message: %v", msg.Data)
|
|
|
|
if connected {
|
|
logger.Info("Already connected. Ignoring new connection request.")
|
|
return
|
|
}
|
|
|
|
if stopRegister != nil {
|
|
stopRegister()
|
|
stopRegister = nil
|
|
}
|
|
|
|
// wait 10 milliseconds to ensure the previous connection is closed
|
|
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
// if there is an existing tunnel then close it
|
|
if dev != nil {
|
|
logger.Info("Got new message. Closing existing tunnel!")
|
|
dev.Close()
|
|
}
|
|
|
|
jsonData, err := json.Marshal(msg.Data)
|
|
if err != nil {
|
|
logger.Info("Error marshaling data: %v", err)
|
|
return
|
|
}
|
|
|
|
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
|
logger.Info("Error unmarshaling target data: %v", err)
|
|
return
|
|
}
|
|
|
|
tdev, err = func() (tun.Device, error) {
|
|
if config.FileDescriptorTun != 0 {
|
|
return createTUNFromFD(config.FileDescriptorTun, config.MTU)
|
|
}
|
|
var ifName = interfaceName
|
|
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
|
|
ifName, err = findUnusedUTUN()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return tun.CreateTUN(ifName, config.MTU)
|
|
}()
|
|
|
|
if err != nil {
|
|
logger.Error("Failed to create TUN device: %v", err)
|
|
return
|
|
}
|
|
|
|
if config.FileDescriptorTun == 0 {
|
|
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
|
interfaceName = realInterfaceName
|
|
}
|
|
}
|
|
|
|
// fileUAPI, err := func() (*os.File, error) {
|
|
// if config.FileDescriptorUAPI != 0 {
|
|
// fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32)
|
|
// if err != nil {
|
|
// return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
|
|
// }
|
|
// return os.NewFile(uintptr(fd), ""), nil
|
|
// }
|
|
// return uapiOpen(interfaceName)
|
|
// }()
|
|
// if err != nil {
|
|
// logger.Error("UAPI listen error: %v", err)
|
|
// os.Exit(1)
|
|
// return
|
|
// }
|
|
|
|
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
|
dev = device.NewDevice(tdev, sharedBind, (*device.Logger)(wgLogger))
|
|
|
|
// uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
|
// if err != nil {
|
|
// logger.Error("Failed to listen on uapi socket: %v", err)
|
|
// os.Exit(1)
|
|
// }
|
|
|
|
// go func() {
|
|
// for {
|
|
// conn, err := uapiListener.Accept()
|
|
// if err != nil {
|
|
|
|
// return
|
|
// }
|
|
// go dev.IpcHandle(conn)
|
|
// }
|
|
// }()
|
|
// logger.Info("UAPI listener started")
|
|
|
|
if err = dev.Up(); err != nil {
|
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
|
}
|
|
|
|
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
|
|
logger.Error("Failed to configure interface: %v", err)
|
|
}
|
|
|
|
peerMonitor = peermonitor.NewPeerMonitor(
|
|
func(siteID int, connected bool, rtt time.Duration) {
|
|
// Find the site config to get endpoint information
|
|
var endpoint string
|
|
var isRelay bool
|
|
for _, site := range wgData.Sites {
|
|
if site.SiteId == siteID {
|
|
endpoint = site.Endpoint
|
|
// TODO: We'll need to track relay status separately
|
|
// For now, assume not using relay unless we get relay data
|
|
isRelay = !config.Holepunch
|
|
break
|
|
}
|
|
}
|
|
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
|
|
if connected {
|
|
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
|
|
} else {
|
|
logger.Warn("Peer %d is disconnected", siteID)
|
|
}
|
|
},
|
|
util.FixKey(privateKey.String()),
|
|
olm,
|
|
dev,
|
|
config.Holepunch,
|
|
)
|
|
|
|
for i := range wgData.Sites {
|
|
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
|
|
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
|
|
|
|
// Format the endpoint before configuring the peer.
|
|
site.Endpoint = formatEndpoint(site.Endpoint)
|
|
|
|
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
|
logger.Error("Failed to configure peer: %v", err)
|
|
return
|
|
}
|
|
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
|
|
logger.Error("Failed to add route for peer: %v", err)
|
|
return
|
|
}
|
|
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
|
return
|
|
}
|
|
|
|
logger.Info("Configured peer %s", site.PublicKey)
|
|
}
|
|
|
|
peerMonitor.Start()
|
|
|
|
apiServer.SetRegistered(true)
|
|
|
|
connected = true
|
|
|
|
logger.Info("WireGuard device created.")
|
|
})
|
|
|
|
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
|
|
logger.Debug("Received update-peer message: %v", msg.Data)
|
|
|
|
jsonData, err := json.Marshal(msg.Data)
|
|
if err != nil {
|
|
logger.Error("Error marshaling data: %v", err)
|
|
return
|
|
}
|
|
|
|
var updateData UpdatePeerData
|
|
if err := json.Unmarshal(jsonData, &updateData); err != nil {
|
|
logger.Error("Error unmarshaling update data: %v", err)
|
|
return
|
|
}
|
|
|
|
// Convert to SiteConfig
|
|
siteConfig := SiteConfig{
|
|
SiteId: updateData.SiteId,
|
|
Endpoint: updateData.Endpoint,
|
|
PublicKey: updateData.PublicKey,
|
|
ServerIP: updateData.ServerIP,
|
|
ServerPort: updateData.ServerPort,
|
|
RemoteSubnets: updateData.RemoteSubnets,
|
|
}
|
|
|
|
// Update the peer in WireGuard
|
|
if dev != nil {
|
|
// Find the existing peer to get old data
|
|
var oldRemoteSubnets string
|
|
var oldPublicKey string
|
|
for _, site := range wgData.Sites {
|
|
if site.SiteId == updateData.SiteId {
|
|
oldRemoteSubnets = site.RemoteSubnets
|
|
oldPublicKey = site.PublicKey
|
|
break
|
|
}
|
|
}
|
|
|
|
// If the public key has changed, remove the old peer first
|
|
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
|
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
|
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
|
logger.Error("Failed to remove old peer: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Format the endpoint before updating the peer.
|
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
|
|
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
|
logger.Error("Failed to update peer: %v", err)
|
|
return
|
|
}
|
|
|
|
// Remove old remote subnet routes if they changed
|
|
if oldRemoteSubnets != siteConfig.RemoteSubnets {
|
|
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
|
|
logger.Error("Failed to remove old remote subnet routes: %v", err)
|
|
// Continue anyway to add new routes
|
|
}
|
|
|
|
// Add new remote subnet routes
|
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
|
logger.Error("Failed to add new remote subnet routes: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Update successful
|
|
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
|
for i := range wgData.Sites {
|
|
if wgData.Sites[i].SiteId == updateData.SiteId {
|
|
wgData.Sites[i] = siteConfig
|
|
break
|
|
}
|
|
}
|
|
} else {
|
|
logger.Error("WireGuard device not initialized")
|
|
}
|
|
})
|
|
|
|
// Handler for adding a new peer
|
|
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
|
|
logger.Debug("Received add-peer message: %v", msg.Data)
|
|
|
|
jsonData, err := json.Marshal(msg.Data)
|
|
if err != nil {
|
|
logger.Error("Error marshaling data: %v", err)
|
|
return
|
|
}
|
|
|
|
var addData AddPeerData
|
|
if err := json.Unmarshal(jsonData, &addData); err != nil {
|
|
logger.Error("Error unmarshaling add data: %v", err)
|
|
return
|
|
}
|
|
|
|
// Convert to SiteConfig
|
|
siteConfig := SiteConfig{
|
|
SiteId: addData.SiteId,
|
|
Endpoint: addData.Endpoint,
|
|
PublicKey: addData.PublicKey,
|
|
ServerIP: addData.ServerIP,
|
|
ServerPort: addData.ServerPort,
|
|
RemoteSubnets: addData.RemoteSubnets,
|
|
}
|
|
|
|
// Add the peer to WireGuard
|
|
if dev != nil {
|
|
// Format the endpoint before adding the new peer.
|
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
|
|
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
|
logger.Error("Failed to add peer: %v", err)
|
|
return
|
|
}
|
|
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
|
|
logger.Error("Failed to add route for new peer: %v", err)
|
|
return
|
|
}
|
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
|
return
|
|
}
|
|
|
|
// Add successful
|
|
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
|
|
|
// Update WgData with the new peer
|
|
wgData.Sites = append(wgData.Sites, siteConfig)
|
|
} else {
|
|
logger.Error("WireGuard device not initialized")
|
|
}
|
|
})
|
|
|
|
// Handler for removing a peer
|
|
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
|
|
logger.Debug("Received remove-peer message: %v", msg.Data)
|
|
|
|
jsonData, err := json.Marshal(msg.Data)
|
|
if err != nil {
|
|
logger.Error("Error marshaling data: %v", err)
|
|
return
|
|
}
|
|
|
|
var removeData RemovePeerData
|
|
if err := json.Unmarshal(jsonData, &removeData); err != nil {
|
|
logger.Error("Error unmarshaling remove data: %v", err)
|
|
return
|
|
}
|
|
|
|
// Find the peer to remove
|
|
var peerToRemove *SiteConfig
|
|
var newSites []SiteConfig
|
|
|
|
for _, site := range wgData.Sites {
|
|
if site.SiteId == removeData.SiteId {
|
|
peerToRemove = &site
|
|
} else {
|
|
newSites = append(newSites, site)
|
|
}
|
|
}
|
|
|
|
if peerToRemove == nil {
|
|
logger.Error("Peer with site ID %d not found", removeData.SiteId)
|
|
return
|
|
}
|
|
|
|
// Remove the peer from WireGuard
|
|
if dev != nil {
|
|
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
|
logger.Error("Failed to remove peer: %v", err)
|
|
// Send error response if needed
|
|
return
|
|
}
|
|
|
|
// Remove route for the peer
|
|
err = removeRouteForServerIP(peerToRemove.ServerIP, interfaceName)
|
|
if err != nil {
|
|
logger.Error("Failed to remove route for peer: %v", err)
|
|
return
|
|
}
|
|
|
|
// Remove routes for remote subnets
|
|
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
|
|
logger.Error("Failed to remove routes for remote subnets: %v", err)
|
|
return
|
|
}
|
|
|
|
// Remove successful
|
|
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
|
|
|
|
// Update WgData to remove the peer
|
|
wgData.Sites = newSites
|
|
} else {
|
|
logger.Error("WireGuard device not initialized")
|
|
}
|
|
})
|
|
|
|
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
|
|
logger.Debug("Received relay-peer message: %v", msg.Data)
|
|
|
|
jsonData, err := json.Marshal(msg.Data)
|
|
if err != nil {
|
|
logger.Error("Error marshaling data: %v", err)
|
|
return
|
|
}
|
|
|
|
var relayData RelayPeerData
|
|
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
|
logger.Error("Error unmarshaling relay data: %v", err)
|
|
return
|
|
}
|
|
|
|
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
|
|
if err != nil {
|
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
|
}
|
|
|
|
// Update HTTP server to mark this peer as using relay
|
|
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
|
|
|
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
|
})
|
|
|
|
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
|
logger.Info("Received no-sites message - no sites available for connection")
|
|
|
|
if stopRegister != nil {
|
|
stopRegister()
|
|
stopRegister = nil
|
|
}
|
|
|
|
logger.Info("No sites available - stopped registration and holepunch processes")
|
|
})
|
|
|
|
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
|
logger.Info("Received terminate message")
|
|
olm.Close()
|
|
})
|
|
|
|
olm.OnConnect(func() error {
|
|
logger.Info("Websocket Connected")
|
|
|
|
apiServer.SetConnectionStatus(true)
|
|
|
|
if connected {
|
|
logger.Debug("Already connected, skipping registration")
|
|
return nil
|
|
}
|
|
|
|
publicKey := privateKey.PublicKey()
|
|
|
|
if stopRegister == nil {
|
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
|
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
"publicKey": publicKey.String(),
|
|
"relay": !config.Holepunch,
|
|
"olmVersion": globalConfig.Version,
|
|
"orgId": config.OrgID,
|
|
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
|
}, 1*time.Second)
|
|
}
|
|
|
|
go keepSendingPing(olm)
|
|
|
|
return nil
|
|
})
|
|
|
|
olm.OnTokenUpdate(func(token string) {
|
|
if holePunchManager != nil {
|
|
holePunchManager.SetToken(token)
|
|
}
|
|
})
|
|
|
|
// Connect to the WebSocket server
|
|
if err := olm.Connect(); err != nil {
|
|
logger.Error("Failed to connect to server: %v", err)
|
|
return
|
|
}
|
|
defer olm.Close()
|
|
|
|
// Wait for context cancellation
|
|
<-tunnelCtx.Done()
|
|
logger.Info("Tunnel process context cancelled, cleaning up")
|
|
}
|
|
|
|
func Close() {
|
|
// Stop hole punch manager
|
|
if holePunchManager != nil {
|
|
holePunchManager.Stop()
|
|
}
|
|
|
|
if stopPing != nil {
|
|
select {
|
|
case <-stopPing:
|
|
// Channel already closed
|
|
default:
|
|
close(stopPing)
|
|
}
|
|
}
|
|
|
|
if stopRegister != nil {
|
|
stopRegister()
|
|
stopRegister = nil
|
|
}
|
|
|
|
if peerMonitor != nil {
|
|
peerMonitor.Stop()
|
|
peerMonitor = nil
|
|
}
|
|
|
|
if uapiListener != nil {
|
|
uapiListener.Close()
|
|
uapiListener = nil
|
|
}
|
|
if dev != nil {
|
|
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
|
|
dev = nil
|
|
}
|
|
// Close TUN device
|
|
if tdev != nil {
|
|
tdev.Close()
|
|
tdev = nil
|
|
}
|
|
|
|
// Release the hole punch reference to the shared bind
|
|
if sharedBind != nil {
|
|
// Release hole punch reference (WireGuard already released its reference via dev.Close())
|
|
logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount())
|
|
sharedBind.Release()
|
|
sharedBind = nil
|
|
logger.Info("Released shared UDP bind")
|
|
}
|
|
|
|
logger.Info("Olm service stopped")
|
|
}
|
|
|
|
// StopTunnel stops just the tunnel process and websocket connection
|
|
// without shutting down the entire application
|
|
func StopTunnel() {
|
|
logger.Info("Stopping tunnel process")
|
|
|
|
// Cancel the tunnel context if it exists
|
|
if tunnelCancel != nil {
|
|
tunnelCancel()
|
|
// Give it a moment to clean up
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
|
|
|
|
// Close the websocket connection
|
|
if olmClient != nil {
|
|
olmClient.Close()
|
|
olmClient = nil
|
|
}
|
|
|
|
Close()
|
|
|
|
// Reset the connected state
|
|
connected = false
|
|
tunnelRunning = false
|
|
|
|
// Update API server status
|
|
apiServer.SetConnectionStatus(false)
|
|
apiServer.SetRegistered(false)
|
|
apiServer.SetTunnelIP("")
|
|
|
|
network.ClearNetworkSettings()
|
|
|
|
logger.Info("Tunnel process stopped")
|
|
}
|