From 15e96a779cc3ff109e4c4cf6a46ae0cdbd359ec9 Mon Sep 17 00:00:00 2001 From: Varun Narravula Date: Mon, 5 Jan 2026 01:41:54 -0800 Subject: [PATCH] refactor(olm): convert global state into an olm instance Former-commit-id: b755f77d95ecc9d645806fa33a11d91261cfd059 --- api/api.go | 29 +- main.go | 12 +- olm/connect.go | 223 +++++++++ olm/data.go | 197 ++++++++ olm/olm.go | 976 ++++++++------------------------------- olm/peer.go | 195 ++++++++ olm/{util.go => ping.go} | 6 +- olm/types.go | 2 +- 8 files changed, 841 insertions(+), 799 deletions(-) create mode 100644 olm/connect.go create mode 100644 olm/data.go create mode 100644 olm/peer.go rename olm/{util.go => ping.go} (89%) diff --git a/api/api.go b/api/api.go index 91d9f37..a6ac9cd 100644 --- a/api/api.go +++ b/api/api.go @@ -63,23 +63,26 @@ type StatusResponse struct { // API represents the HTTP server and its state type API struct { - addr string - socketPath string - listener net.Listener - server *http.Server + addr string + socketPath string + listener net.Listener + server *http.Server + onConnect func(ConnectionRequest) error onSwitchOrg func(SwitchOrgRequest) error onDisconnect func() error onExit func() error + statusMu sync.RWMutex peerStatuses map[int]*PeerStatus connectedAt time.Time isConnected bool isRegistered bool isTerminated bool - version string - agent string - orgID string + + version string + agent string + orgID string } // NewAPI creates a new HTTP server that listens on a TCP address @@ -173,7 +176,7 @@ func (s *API) Stop() error { // Close the server first, which will also close the listener gracefully if s.server != nil { - s.server.Close() + _ = s.server.Close() } // Clean up socket file if using Unix socket @@ -358,7 +361,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusAccepted) - json.NewEncoder(w).Encode(map[string]string{ + _ = 
json.NewEncoder(w).Encode(map[string]string{ "status": "connection request accepted", }) } @@ -406,7 +409,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "ok", }) } @@ -423,7 +426,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { // Return a success response first w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "shutdown initiated", }) @@ -472,7 +475,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "org switch request accepted", }) } @@ -506,7 +509,7 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "disconnect initiated", }) } diff --git a/main.go b/main.go index 5b6c15e..2bf8dcd 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" - "github.com/fosrl/olm/olm" + olmpkg "github.com/fosrl/olm/olm" ) func main() { @@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt } // Create a new olm.Config struct and copy values from the main config - olmConfig := olm.GlobalConfig{ + olmConfig := olmpkg.OlmConfig{ LogLevel: config.LogLevel, EnableAPI: config.EnableAPI, HTTPAddr: config.HTTPAddr, @@ -222,13 +222,17 
@@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE } - olm.Init(ctx, olmConfig) + olm, err := olmpkg.Init(ctx, olmConfig) + if err != nil { + logger.Fatal("Failed to initialize olm: %v", err) + } + if err := olm.StartApi(); err != nil { logger.Fatal("Failed to start API server: %v", err) } if config.ID != "" && config.Secret != "" && config.Endpoint != "" { - tunnelConfig := olm.TunnelConfig{ + tunnelConfig := olmpkg.TunnelConfig{ Endpoint: config.Endpoint, ID: config.ID, Secret: config.Secret, diff --git a/olm/connect.go b/olm/connect.go new file mode 100644 index 0000000..568c731 --- /dev/null +++ b/olm/connect.go @@ -0,0 +1,223 @@ +package olm + +import ( + "encoding/json" + "fmt" + "os" + "runtime" + "strconv" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + olmDevice "github.com/fosrl/olm/device" + "github.com/fosrl/olm/dns" + dnsOverride "github.com/fosrl/olm/dns/override" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" +) + +func (o *Olm) handleConnect(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + + var wgData WgData + + if o.connected { + logger.Info("Already connected. Ignoring new connection request.") + return + } + + if o.stopRegister != nil { + o.stopRegister() + o.stopRegister = nil + } + + if o.updateRegister != nil { + o.updateRegister = nil + } + + // if there is an existing tunnel then close it + if o.dev != nil { + logger.Info("Got new message. 
Closing existing tunnel!") + o.dev.Close() + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &wgData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + o.tdev, err = func() (tun.Device, error) { + if o.tunnelConfig.FileDescriptorTun != 0 { + return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU) + } + ifName := o.tunnelConfig.InterfaceName + if runtime.GOOS == "darwin" { // this is if we dont pass a fd + ifName, err = network.FindUnusedUTUN() + if err != nil { + return nil, err + } + } + return tun.CreateTUN(ifName, o.tunnelConfig.MTU) + }() + if err != nil { + logger.Error("Failed to create TUN device: %v", err) + return + } + + // if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? + o.tunnelConfig.InterfaceName = realInterfaceName + } + // } + + // Wrap TUN device with packet filter for DNS proxy + o.middleDev = olmDevice.NewMiddleDevice(o.tdev) + + wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") + // Use filtered device instead of raw TUN device + o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger)) + + if o.tunnelConfig.EnableUAPI { + fileUAPI, err := func() (*os.File, error) { + if o.tunnelConfig.FileDescriptorUAPI != 0 { + fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) + } + return os.NewFile(uintptr(fd), ""), nil + } + return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName) + }() + if err != nil { + logger.Error("UAPI listen error: %v", err) + os.Exit(1) + return + } + + o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI) + if err != nil { + logger.Error("Failed 
to listen on uapi socket: %v", err) + os.Exit(1) + } + + go func() { + for { + conn, err := o.uapiListener.Accept() + if err != nil { + return + } + go o.dev.IpcHandle(conn) + } + }() + logger.Info("UAPI listener started") + } + + if err = o.dev.Up(); err != nil { + logger.Error("Failed to bring up WireGuard device: %v", err) + } + + // Extract interface IP (strip CIDR notation if present) + interfaceIP := wgData.TunnelIP + if strings.Contains(interfaceIP, "/") { + interfaceIP = strings.Split(interfaceIP, "/")[0] + } + + // Create and start DNS proxy + o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP) + if err != nil { + logger.Error("Failed to create DNS proxy: %v", err) + } + + if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil { + logger.Error("Failed to o.tunnelConfigure interface: %v", err) + } + + if network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet + logger.Error("Failed to add route for utility subnet: %v", err) + } + + // Create peer manager with integrated peer monitoring + o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ + Device: o.dev, + DNSProxy: o.dnsProxy, + InterfaceName: o.tunnelConfig.InterfaceName, + PrivateKey: o.privateKey, + MiddleDev: o.middleDev, + LocalIP: interfaceIP, + SharedBind: o.sharedBind, + WSClient: o.olmClient, + APIServer: o.apiServer, + }) + + for i := range wgData.Sites { + site := wgData.Sites[i] + var siteEndpoint string + // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer + if site.RelayEndpoint != "" { + siteEndpoint = site.RelayEndpoint + } else { + siteEndpoint = site.Endpoint + } + + o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) + + if err := o.peerManager.AddPeer(site); err 
!= nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Configured peer %s", site.PublicKey) + } + + o.peerManager.Start() + + if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime + logger.Error("Failed to start DNS proxy: %v", err) + } + + if o.tunnelConfig.OverrideDNS { + // Set up DNS override to use our DNS proxy + if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { + logger.Error("Failed to setup DNS override: %v", err) + return + } + + network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()}) + } + + o.apiServer.SetRegistered(true) + + o.connected = true + + // Invoke onConnected callback if configured + if o.olmConfig.OnConnected != nil { + go o.olmConfig.OnConnected() + } + + logger.Info("WireGuard device created.") +} + +func (o *Olm) handleTerminate(msg websocket.WSMessage) { + logger.Info("Received terminate message") + o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearPeerStatuses() + + network.ClearNetworkSettings() + + o.Close() + + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() + } +} diff --git a/olm/data.go b/olm/data.go new file mode 100644 index 0000000..9c8d33f --- /dev/null +++ b/olm/data.go @@ -0,0 +1,197 @@ +package olm + +import ( + "encoding/json" + "time" + + "github.com/fosrl/newt/holepunch" + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" +) + +func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) { + logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var addSubnetsData peers.PeerAdd + if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { + logger.Error("Error unmarshaling add-remote-subnets data: 
%v", err) + return + } + + if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId) + return + } + + // Add new subnets + for _, subnet := range addSubnetsData.RemoteSubnets { + if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases + for _, alias := range addSubnetsData.Aliases { + if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } +} + +func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) { + logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeSubnetsData peers.RemovePeerData + if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { + logger.Error("Error unmarshaling remove-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) + return + } + + // Remove subnets + for _, subnet := range removeSubnetsData.RemoteSubnets { + if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Remove aliases + for _, alias := range removeSubnetsData.Aliases { + if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } +} + +func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { + logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := 
json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateSubnetsData peers.UpdatePeerData + if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { + logger.Error("Error unmarshaling update-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId) + return + } + + // Add new subnets BEFORE removing old ones to preserve shared subnets + // This ensures that if an old and new subnet are the same on different peers, + // the route won't be temporarily removed + for _, subnet := range updateSubnetsData.NewRemoteSubnets { + if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Remove old subnets after new ones are added + for _, subnet := range updateSubnetsData.OldRemoteSubnets { + if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases BEFORE removing old ones to preserve shared IP addresses + // This ensures that if an old and new alias share the same IP, the IP won't be + // temporarily removed from the allowed IPs list + for _, alias := range updateSubnetsData.NewAliases { + if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } + + // Remove old aliases after new ones are added + for _, alias := range updateSubnetsData.OldAliases { + if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } + + logger.Info("Successfully updated remote subnets and aliases for peer %d", 
updateSubnetsData.SiteId) +} + +func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { + logger.Debug("Received peer-handshake message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling handshake data: %v", err) + return + } + + var handshakeData struct { + SiteId int `json:"siteId"` + ExitNode struct { + PublicKey string `json:"publicKey"` + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + } `json:"exitNode"` + } + + if err := json.Unmarshal(jsonData, &handshakeData); err != nil { + logger.Error("Error unmarshaling handshake data: %v", err) + return + } + + // Get existing peer from PeerManager + _, exists := o.peerManager.GetPeer(handshakeData.SiteId) + if exists { + logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) + return + } + + relayPort := handshakeData.ExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + siteId := handshakeData.SiteId + exitNode := holepunch.ExitNode{ + Endpoint: handshakeData.ExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: handshakeData.ExitNode.PublicKey, + SiteIds: []int{siteId}, + } + + added := o.holePunchManager.AddExitNode(exitNode) + if added { + logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) + } else { + logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) + } + + o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt + o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud + + // Send handshake acknowledgment back to server with retry + o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + }, 1*time.Second) + + logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) +} diff 
--git a/olm/olm.go b/olm/olm.go index 774a3cb..6d8f7a5 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -2,15 +2,11 @@ package olm import ( "context" - "encoding/json" "fmt" "net" "net/http" _ "net/http/pprof" "os" - "runtime" - "strconv" - "strings" "time" "github.com/fosrl/newt/bind" @@ -30,41 +26,49 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -var ( - privateKey wgtypes.Key - connected bool - dev *device.Device - uapiListener net.Listener - tdev tun.Device - middleDev *olmDevice.MiddleDevice - interfaceName string +type Olm struct { + privateKey wgtypes.Key + logFile *os.File + + connected bool + tunnelRunning bool + + uapiListener net.Listener + dev *device.Device + tdev tun.Device + middleDev *olmDevice.MiddleDevice + sharedBind *bind.SharedBind + dnsProxy *dns.DNSProxy apiServer *api.API olmClient *websocket.Client - tunnelCancel context.CancelFunc - tunnelRunning bool - sharedBind *bind.SharedBind holePunchManager *holepunch.Manager - globalConfig GlobalConfig - tunnelConfig TunnelConfig - globalCtx context.Context - stopRegister func() - stopPeerSend func() - updateRegister func(newData interface{}) - stopPing chan struct{} peerManager *peers.PeerManager -) + + olmCtx context.Context + tunnelCancel context.CancelFunc + + olmConfig OlmConfig + tunnelConfig TunnelConfig + + stopRegister func() + stopPeerSend func() + updateRegister func(newData any) + + stopPing chan struct{} +} // initTunnelInfo creates the shared UDP socket and holepunch manager. // This is used during initial tunnel setup and when switching organizations. 
-func initTunnelInfo(clientID string) error { - var err error - privateKey, err = wgtypes.GeneratePrivateKey() +func (o *Olm) initTunnelInfo(clientID string) error { + privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { logger.Error("Failed to generate private key: %v", err) return err } + o.privateKey = privateKey + sourcePort, err := util.FindAvailableUDPPort(49152, 65535) if err != nil { return fmt.Errorf("failed to find available UDP port: %w", err) @@ -80,27 +84,26 @@ func initTunnelInfo(clientID string) error { return fmt.Errorf("failed to create UDP socket: %w", err) } - sharedBind, err = bind.New(udpConn) + sharedBind, err := bind.New(udpConn) if err != nil { - udpConn.Close() + _ = udpConn.Close() return fmt.Errorf("failed to create shared bind: %w", err) } + o.sharedBind = sharedBind + // Add a reference for the hole punch senders (creator already has one reference for WireGuard) sharedBind.AddRef() logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) // Create the holepunch manager - holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) + o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) return nil } -func Init(ctx context.Context, config GlobalConfig) { - globalConfig = config - globalCtx = ctx - +func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) // Start pprof server if enabled @@ -112,25 +115,27 @@ func Init(ctx context.Context, config GlobalConfig) { } }() } - + + var logFile *os.File if config.LogFilePath != "" { - logFile, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + file, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) if err != nil { logger.Fatal("Failed to open log file: %v", err) + return nil, err } - // TODO: figure out how to 
close file, if set - logger.SetOutput(logFile) - return + logger.SetOutput(file) + logFile = file } logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() if err != nil { logger.Fatal("Insufficient permissions to create native TUN interface: %v", err) - return + return nil, err } + var apiServer *api.API if config.HTTPAddr != "" { apiServer = api.NewAPI(config.HTTPAddr) } else if config.SocketPath != "" { @@ -143,18 +148,24 @@ func Init(ctx context.Context, config GlobalConfig) { apiServer.SetVersion(config.Version) apiServer.SetAgent(config.Agent) - // Set up API handlers - apiServer.SetHandlers( + newOlm := &Olm{ + logFile: logFile, + olmCtx: ctx, + apiServer: apiServer, + olmConfig: config, + } + + newOlm.registerAPICallbacks() + + return newOlm, nil +} + +func (o *Olm) registerAPICallbacks() { + o.apiServer.SetHandlers( // onConnect func(req api.ConnectionRequest) error { logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) - // Stop any existing tunnel before starting a new one - if olmClient != nil { - logger.Info("Stopping existing tunnel before starting new connection") - StopTunnel() - } - tunnelConfig := TunnelConfig{ Endpoint: req.Endpoint, ID: req.ID, @@ -208,7 +219,7 @@ func Init(ctx context.Context, config GlobalConfig) { // Start the tunnel process with the new credentials if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" { logger.Info("Starting tunnel with new credentials") - go StartTunnel(tunnelConfig) + go o.StartTunnel(tunnelConfig) } return nil @@ -216,66 +227,64 @@ func Init(ctx context.Context, config GlobalConfig) { // onSwitchOrg func(req api.SwitchOrgRequest) error { logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) - return SwitchOrg(req.OrgID) + return o.SwitchOrg(req.OrgID) }, // onDisconnect func() error { logger.Info("Processing disconnect request via API") - return 
StopTunnel() + return o.StopTunnel() }, // onExit func() error { logger.Info("Processing shutdown request via API") - Close() - if globalConfig.OnExit != nil { - globalConfig.OnExit() + o.Close() + if o.olmConfig.OnExit != nil { + o.olmConfig.OnExit() } return nil }, ) } -func StartTunnel(config TunnelConfig) { - if tunnelRunning { +func (o *Olm) StartTunnel(config TunnelConfig) { + if o.tunnelRunning { logger.Info("Tunnel already running") return } - tunnelRunning = true // Also set it here in case it is called externally - tunnelConfig = config + o.tunnelRunning = true // Also set it here in case it is called externally + o.tunnelConfig = config // Reset terminated status when tunnel starts - apiServer.SetTerminated(false) + o.apiServer.SetTerminated(false) // debug print out the whole config logger.Debug("Starting tunnel with config: %+v", config) // Create a cancellable context for this tunnel process - tunnelCtx, cancel := context.WithCancel(globalCtx) - tunnelCancel = cancel - defer func() { - tunnelCancel = nil - }() + tunnelCtx, cancel := context.WithCancel(o.olmCtx) + o.tunnelCancel = cancel // Recreate channels for this tunnel session - stopPing = make(chan struct{}) + o.stopPing = make(chan struct{}) var ( id = config.ID secret = config.Secret userToken = config.UserToken ) - interfaceName = config.InterfaceName - apiServer.SetOrgID(config.OrgID) + o.tunnelConfig.InterfaceName = config.InterfaceName - // Create a new olm client using the provided credentials - olm, err := websocket.NewClient( - id, // Use provided ID - secret, // Use provided secret - userToken, // Use provided user token OPTIONAL + o.apiServer.SetOrgID(config.OrgID) + + // Create a new olmClient client using the provided credentials + olmClient, err := websocket.NewClient( + id, + secret, + userToken, config.OrgID, - config.Endpoint, // Use provided endpoint + config.Endpoint, config.PingIntervalDuration, config.PingTimeoutDuration, ) @@ -284,638 +293,70 @@ func StartTunnel(config 
TunnelConfig) { return } - // Store the client reference globally - olmClient = olm - // Create shared UDP socket and holepunch manager - if err := initTunnelInfo(id); err != nil { + if err := o.initTunnelInfo(id); err != nil { logger.Error("%v", err) return } - olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - - var wgData WgData - - if connected { - logger.Info("Already connected. Ignoring new connection request.") - return - } - - if stopRegister != nil { - stopRegister() - stopRegister = nil - } - - if updateRegister != nil { - updateRegister = nil - } - - // if there is an existing tunnel then close it - if dev != nil { - logger.Info("Got new message. Closing existing tunnel!") - dev.Close() - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &wgData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - tdev, err = func() (tun.Device, error) { - if config.FileDescriptorTun != 0 { - return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) - } - ifName := interfaceName - if runtime.GOOS == "darwin" { // this is if we dont pass a fd - ifName, err = network.FindUnusedUTUN() - if err != nil { - return nil, err - } - } - return tun.CreateTUN(ifName, config.MTU) - }() - if err != nil { - logger.Error("Failed to create TUN device: %v", err) - return - } - - // if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? 
- interfaceName = realInterfaceName - } - // } - - // Wrap TUN device with packet filter for DNS proxy - middleDev = olmDevice.NewMiddleDevice(tdev) - - wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") - // Use filtered device instead of raw TUN device - dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger)) - - if config.EnableUAPI { - fileUAPI, err := func() (*os.File, error) { - if config.FileDescriptorUAPI != 0 { - fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32) - if err != nil { - return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) - } - return os.NewFile(uintptr(fd), ""), nil - } - return olmDevice.UapiOpen(interfaceName) - }() - if err != nil { - logger.Error("UAPI listen error: %v", err) - os.Exit(1) - return - } - - uapiListener, err = olmDevice.UapiListen(interfaceName, fileUAPI) - if err != nil { - logger.Error("Failed to listen on uapi socket: %v", err) - os.Exit(1) - } - - go func() { - for { - conn, err := uapiListener.Accept() - if err != nil { - return - } - go dev.IpcHandle(conn) - } - }() - logger.Info("UAPI listener started") - } - - if err = dev.Up(); err != nil { - logger.Error("Failed to bring up WireGuard device: %v", err) - } - - // Extract interface IP (strip CIDR notation if present) - interfaceIP := wgData.TunnelIP - if strings.Contains(interfaceIP, "/") { - interfaceIP = strings.Split(interfaceIP, "/")[0] - } - - // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) - if err != nil { - logger.Error("Failed to create DNS proxy: %v", err) - } - - if err = network.ConfigureInterface(interfaceName, wgData.TunnelIP, config.MTU); err != nil { - logger.Error("Failed to configure interface: %v", err) - } - - if network.AddRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet - logger.Error("Failed to add route for 
utility subnet: %v", err) - } - - // Create peer manager with integrated peer monitoring - peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ - Device: dev, - DNSProxy: dnsProxy, - InterfaceName: interfaceName, - PrivateKey: privateKey, - MiddleDev: middleDev, - LocalIP: interfaceIP, - SharedBind: sharedBind, - WSClient: olm, - APIServer: apiServer, - }) - - for i := range wgData.Sites { - site := wgData.Sites[i] - var siteEndpoint string - // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer - if site.RelayEndpoint != "" { - siteEndpoint = site.RelayEndpoint - } else { - siteEndpoint = site.Endpoint - } - - apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) - - if err := peerManager.AddPeer(site); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - logger.Info("Configured peer %s", site.PublicKey) - } - - peerManager.Start() - - if err := dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime - logger.Error("Failed to start DNS proxy: %v", err) - } - - if config.OverrideDNS { - // Set up DNS override to use our DNS proxy - if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy.GetProxyIP()); err != nil { - logger.Error("Failed to setup DNS override: %v", err) - return - } - - network.SetDNSServers([]string{dnsProxy.GetProxyIP().String()}) - } - - apiServer.SetRegistered(true) - - connected = true - - // Invoke onConnected callback if configured - if globalConfig.OnConnected != nil { - go globalConfig.OnConnected() - } - - logger.Info("WireGuard device created.") - }) - - olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateData peers.SiteConfig - if err := json.Unmarshal(jsonData, &updateData); err != 
nil { - logger.Error("Error unmarshaling update data: %v", err) - return - } - - // Get existing peer from PeerManager - existingPeer, exists := peerManager.GetPeer(updateData.SiteId) - if !exists { - logger.Warn("Peer with site ID %d not found", updateData.SiteId) - return - } - - // Create updated site config by merging with existing data - siteConfig := existingPeer - - if updateData.Endpoint != "" { - siteConfig.Endpoint = updateData.Endpoint - } - if updateData.RelayEndpoint != "" { - siteConfig.RelayEndpoint = updateData.RelayEndpoint - } - if updateData.PublicKey != "" { - siteConfig.PublicKey = updateData.PublicKey - } - if updateData.ServerIP != "" { - siteConfig.ServerIP = updateData.ServerIP - } - if updateData.ServerPort != 0 { - siteConfig.ServerPort = updateData.ServerPort - } - if updateData.RemoteSubnets != nil { - siteConfig.RemoteSubnets = updateData.RemoteSubnets - } - - if err := peerManager.UpdatePeer(siteConfig); err != nil { - logger.Error("Failed to update peer: %v", err) - return - } - - // If the endpoint changed, trigger holepunch to refresh NAT mappings - if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { - logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) - holePunchManager.TriggerHolePunch() - holePunchManager.ResetInterval() - } - - // Update successful - logger.Info("Successfully updated peer for site %d", updateData.SiteId) - }) - - // Handler for adding a new peer - olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-peer message: %v", msg.Data) - - if stopPeerSend != nil { - stopPeerSend() - stopPeerSend = nil - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var siteConfig peers.SiteConfig - if err := json.Unmarshal(jsonData, &siteConfig); err != nil { - logger.Error("Error unmarshaling add data: %v", err) - return - 
} - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it - - if err := peerManager.AddPeer(siteConfig); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - // Add successful - logger.Info("Successfully added peer for site %d", siteConfig.SiteId) - }) - - // Handler for removing a peer - olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeData peers.PeerRemove - if err := json.Unmarshal(jsonData, &removeData); err != nil { - logger.Error("Error unmarshaling remove data: %v", err) - return - } - - if err := peerManager.RemovePeer(removeData.SiteId); err != nil { - logger.Error("Failed to remove peer: %v", err) - return - } - - // Remove any exit nodes associated with this peer from hole punching - if holePunchManager != nil { - removed := holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) - if removed > 0 { - logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) - } - } - - // Remove successful - logger.Info("Successfully removed peer for site %d", removeData.SiteId) - }) - - // Handler for adding remote subnets to a peer - olm.RegisterHandler("olm/wg/peer/data/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var addSubnetsData peers.PeerAdd - if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { - logger.Error("Error unmarshaling add-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(addSubnetsData.SiteId); !exists { 
- logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId) - return - } - - // Add new subnets - for _, subnet := range addSubnetsData.RemoteSubnets { - if err := peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases - for _, alias := range addSubnetsData.Aliases { - if err := peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for removing remote subnets from a peer - olm.RegisterHandler("olm/wg/peer/data/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeSubnetsData peers.RemovePeerData - if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { - logger.Error("Error unmarshaling remove-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(removeSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) - return - } - - // Remove subnets - for _, subnet := range removeSubnetsData.RemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Remove aliases - for _, alias := range removeSubnetsData.Aliases { - if err := peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for updating remote subnets of a peer (remove old, add new in one operation) - olm.RegisterHandler("olm/wg/peer/data/update", func(msg websocket.WSMessage) { - logger.Debug("Received 
update-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateSubnetsData peers.UpdatePeerData - if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { - logger.Error("Error unmarshaling update-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(updateSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", updateSubnetsData.SiteId) - return - } - - // Add new subnets BEFORE removing old ones to preserve shared subnets - // This ensures that if an old and new subnet are the same on different peers, - // the route won't be temporarily removed - for _, subnet := range updateSubnetsData.NewRemoteSubnets { - if err := peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Remove old subnets after new ones are added - for _, subnet := range updateSubnetsData.OldRemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases BEFORE removing old ones to preserve shared IP addresses - // This ensures that if an old and new alias share the same IP, the IP won't be - // temporarily removed from the allowed IPs list - for _, alias := range updateSubnetsData.NewAliases { - if err := peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - - // Remove old aliases after new ones are added - for _, alias := range updateSubnetsData.OldAliases { - if err := peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - - logger.Info("Successfully updated 
remote subnets and aliases for peer %d", updateSubnetsData.SiteId) - }) - - olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { - logger.Debug("Received relay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.RelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) - - peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort) - }) - - olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) { - logger.Debug("Received unrelay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.UnRelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.Endpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, 
relayData.Endpoint, false) - - peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) - }) + // Handlers for managing connection status + olmClient.RegisterHandler("olm/wg/connect", o.handleConnect) + olmClient.RegisterHandler("olm/terminate", o.handleTerminate) + + // Handlers for managing peers + olmClient.RegisterHandler("olm/wg/peer/add", o.handleWgPeerAdd) + olmClient.RegisterHandler("olm/wg/peer/remove", o.handleWgPeerRemove) + olmClient.RegisterHandler("olm/wg/peer/update", o.handleWgPeerUpdate) + olmClient.RegisterHandler("olm/wg/peer/relay", o.handleWgPeerRelay) + olmClient.RegisterHandler("olm/wg/peer/unrelay", o.handleWgPeerUnrelay) + + // Handlers for managing a peer's remote subnets and aliases + olmClient.RegisterHandler("olm/wg/peer/data/add", o.handleWgPeerAddData) + olmClient.RegisterHandler("olm/wg/peer/data/remove", o.handleWgPeerRemoveData) + olmClient.RegisterHandler("olm/wg/peer/data/update", o.handleWgPeerUpdateData) // Handler for peer handshake - adds exit node to holepunch rotation and notifies server - olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) { - logger.Debug("Received peer-handshake message: %v", msg.Data) + olmClient.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling handshake data: %v", err) - return - } - - var handshakeData struct { - SiteId int `json:"siteId"` - ExitNode struct { - PublicKey string `json:"publicKey"` - Endpoint string `json:"endpoint"` - RelayPort uint16 `json:"relayPort"` - } `json:"exitNode"` - } - - if err := json.Unmarshal(jsonData, &handshakeData); err != nil { - logger.Error("Error unmarshaling handshake data: %v", err) - return - } - - // Get existing peer from PeerManager - _, exists := peerManager.GetPeer(handshakeData.SiteId) - if exists { - logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) - return - } - - relayPort :=
handshakeData.ExitNode.RelayPort - if relayPort == 0 { - relayPort = 21820 // default relay port - } - - siteId := handshakeData.SiteId - exitNode := holepunch.ExitNode{ - Endpoint: handshakeData.ExitNode.Endpoint, - RelayPort: relayPort, - PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, - } - - added := holePunchManager.AddExitNode(exitNode) - if added { - logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) - } else { - logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) - } - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud - - // Send handshake acknowledgment back to server with retry - stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second) - - logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) - }) - - olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() - network.ClearNetworkSettings() - Close() - - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() - } - }) - - olm.RegisterHandler("pong", func(msg websocket.WSMessage) { - logger.Debug("Received pong message") - }) - - olm.OnConnect(func() error { + olmClient.OnConnect(func() error { logger.Info("Websocket Connected") - apiServer.SetConnectionStatus(true) + o.apiServer.SetConnectionStatus(true) - if connected { + if o.connected { logger.Debug("Already connected, skipping registration") return nil } - publicKey := privateKey.PublicKey() + publicKey := o.privateKey.PublicKey() // delay for 500ms to allow for time for the 
hp to get processed time.Sleep(500 * time.Millisecond) - if stopRegister == nil { + if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) - stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + o.stopRegister, o.updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]any{ "publicKey": publicKey.String(), "relay": !config.Holepunch, - "olmVersion": globalConfig.Version, - "olmAgent": globalConfig.Agent, + "olmVersion": o.olmConfig.Version, + "olmAgent": o.olmConfig.Agent, "orgId": config.OrgID, "userToken": userToken, }, 1*time.Second) // Invoke onRegistered callback if configured - if globalConfig.OnRegistered != nil { - go globalConfig.OnRegistered() + if o.olmConfig.OnRegistered != nil { + go o.olmConfig.OnRegistered() } } - go keepSendingPing(olm) + go o.keepSendingPing(olmClient) return nil }) - olm.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { - holePunchManager.SetToken(token) + olmClient.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { + o.holePunchManager.SetToken(token) logger.Debug("Got exit nodes for hole punching: %v", exitNodes) @@ -939,141 +380,113 @@ func StartTunnel(config TunnelConfig) { // Start hole punching using the manager logger.Info("Starting hole punch for %d exit nodes", len(exitNodes)) - if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { + if err := o.holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { logger.Warn("Failed to start hole punch: %v", err) } }) - olm.OnAuthError(func(statusCode int, message string) { + olmClient.OnAuthError(func(statusCode int, message string) { logger.Error("Authentication error (status %d): %s. 
Terminating tunnel.", statusCode, message) - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() + o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() - Close() + o.Close() - if globalConfig.OnAuthError != nil { - go globalConfig.OnAuthError(statusCode, message) + if o.olmConfig.OnAuthError != nil { + go o.olmConfig.OnAuthError(statusCode, message) } - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() } }) // Connect to the WebSocket server - if err := olm.Connect(); err != nil { + if err := olmClient.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) return } - defer olm.Close() + defer func() { _ = olmClient.Close() }() + + o.olmClient = olmClient // Wait for context cancellation <-tunnelCtx.Done() logger.Info("Tunnel process context cancelled, cleaning up") } -func AddDevice(fd uint32) error { - if middleDev == nil { - return fmt.Errorf("middle device is not initialized") - } - - if tunnelConfig.MTU == 0 { - // error - return fmt.Errorf("tunnel MTU is not set") - } - - tdev, err := olmDevice.CreateTUNFromFD(fd, tunnelConfig.MTU) - - if err != nil { - return fmt.Errorf("failed to create TUN device from fd: %v", err) - } - - // if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? 
- interfaceName = realInterfaceName - } - - // Here we replace the existing TUN device in the middle device with the new one - middleDev.AddDevice(tdev) - - return nil -} - -func Close() { +func (o *Olm) Close() { // Restore original DNS configuration // we do this first to avoid any DNS issues if something else gets stuck if err := dnsOverride.RestoreDNSOverride(); err != nil { logger.Error("Failed to restore DNS: %v", err) } - // Stop hole punch manager - if holePunchManager != nil { - holePunchManager.Stop() - holePunchManager = nil + if o.holePunchManager != nil { + o.holePunchManager.Stop() + o.holePunchManager = nil } - if stopPing != nil { - select { - case <-stopPing: - // Channel already closed - default: - close(stopPing) - } + if o.stopPing != nil { + close(o.stopPing) + o.stopPing = nil } - if stopRegister != nil { - stopRegister() - stopRegister = nil + if o.stopRegister != nil { + o.stopRegister() + o.stopRegister = nil } - if updateRegister != nil { - updateRegister = nil + // Close() also calls Stop() internally + if o.peerManager != nil { + o.peerManager.Close() + o.peerManager = nil } - if peerManager != nil { - peerManager.Close() // Close() also calls Stop() internally - peerManager = nil + if o.uapiListener != nil { + _ = o.uapiListener.Close() + o.uapiListener = nil } - if uapiListener != nil { - uapiListener.Close() - uapiListener = nil + if o.logFile != nil { + _ = o.logFile.Close() + o.logFile = nil } // Stop DNS proxy first - it uses the middleDev for packet filtering - logger.Debug("Stopping DNS proxy") - if dnsProxy != nil { - dnsProxy.Stop() - dnsProxy = nil + if o.dnsProxy != nil { + logger.Debug("Stopping DNS proxy") + o.dnsProxy.Stop() + o.dnsProxy = nil } // Close MiddleDevice first - this closes the TUN and signals the closed channel // This unblocks the pump goroutine and allows WireGuard's TUN reader to exit - logger.Debug("Closing MiddleDevice") - if middleDev != nil { - middleDev.Close() - middleDev = nil + // Note: o.tdev is 
closed by o.middleDev.Close() since middleDev wraps it + if o.middleDev != nil { + logger.Debug("Closing MiddleDevice") + _ = o.middleDev.Close() + o.middleDev = nil } - // Note: tdev is closed by middleDev.Close() since middleDev wraps it - tdev = nil // Now close WireGuard device - its TUN reader should have exited by now - logger.Debug("Closing WireGuard device") - if dev != nil { - dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference - dev = nil + // This will call sharedBind.Close() which releases WireGuard's reference + if o.dev != nil { + logger.Debug("Closing WireGuard device") + o.dev.Close() + o.dev = nil } - // Release the hole punch reference to the shared bind - if sharedBind != nil { - // Release hole punch reference (WireGuard already released its reference via dev.Close()) - logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount()) - sharedBind.Release() - sharedBind = nil + // Release the hole punch reference to the shared bind (WireGuard already + // released its reference via dev.Close()) + if o.sharedBind != nil { + logger.Debug("Releasing shared bind (refcount before release: %d)", o.sharedBind.GetRefCount()) + _ = o.sharedBind.Release() logger.Info("Released shared UDP bind") + o.sharedBind = nil } logger.Info("Olm service stopped") @@ -1081,78 +494,85 @@ func Close() { // StopTunnel stops just the tunnel process and websocket connection // without shutting down the entire application -func StopTunnel() error { +func (o *Olm) StopTunnel() error { logger.Info("Stopping tunnel process") + if !o.tunnelRunning { + logger.Debug("Tunnel not running, nothing to stop") + return nil + } + // Cancel the tunnel context if it exists - if tunnelCancel != nil { - tunnelCancel() + if o.tunnelCancel != nil { + o.tunnelCancel() // Give it a moment to clean up time.Sleep(200 * time.Millisecond) } // Close the websocket connection - if olmClient != nil { - olmClient.Close() - olmClient = nil + 
if o.olmClient != nil { + _ = o.olmClient.Close() + o.olmClient = nil } - Close() + o.Close() // Reset the connected state - connected = false - tunnelRunning = false + o.connected = false + o.tunnelRunning = false // Update API server status - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) network.ClearNetworkSettings() - apiServer.ClearPeerStatuses() + o.apiServer.ClearPeerStatuses() logger.Info("Tunnel process stopped") return nil } -func StopApi() error { - if apiServer != nil { - err := apiServer.Stop() +func (o *Olm) StopApi() error { + if o.apiServer != nil { + err := o.apiServer.Stop() if err != nil { return fmt.Errorf("failed to stop API server: %w", err) } } + return nil } -func StartApi() error { - if apiServer != nil { - err := apiServer.Start() +func (o *Olm) StartApi() error { + if o.apiServer != nil { + err := o.apiServer.Start() if err != nil { return fmt.Errorf("failed to start API server: %w", err) } } + return nil } -func GetStatus() api.StatusResponse { - return apiServer.GetStatus() +func (o *Olm) GetStatus() api.StatusResponse { + return o.apiServer.GetStatus() } -func SwitchOrg(orgID string) error { +func (o *Olm) SwitchOrg(orgID string) error { logger.Info("Processing org switch request to orgId: %s", orgID) // stop the tunnel - if err := StopTunnel(); err != nil { + if err := o.StopTunnel(); err != nil { return fmt.Errorf("failed to stop existing tunnel: %w", err) } // Update the org ID in the API server and global config - apiServer.SetOrgID(orgID) + o.apiServer.SetOrgID(orgID) - tunnelConfig.OrgID = orgID + o.tunnelConfig.OrgID = orgID // Restart the tunnel with the same config but new org ID - go StartTunnel(tunnelConfig) + go o.StartTunnel(o.tunnelConfig) return nil } diff --git a/olm/peer.go b/olm/peer.go new file mode 100644 index 0000000..8acec42 --- /dev/null +++ b/olm/peer.go @@ -0,0 +1,195 @@ +package olm + +import ( + 
"encoding/json" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" +) + +func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { + logger.Debug("Received add-peer message: %v", msg.Data) + + if o.stopPeerSend != nil { + o.stopPeerSend() + o.stopPeerSend = nil + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var siteConfig peers.SiteConfig + if err := json.Unmarshal(jsonData, &siteConfig); err != nil { + logger.Error("Error unmarshaling add data: %v", err) + return + } + + _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it + + if err := o.peerManager.AddPeer(siteConfig); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Successfully added peer for site %d", siteConfig.SiteId) +} + +func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { + logger.Debug("Received remove-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeData peers.PeerRemove + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return + } + + if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil { + logger.Error("Failed to remove peer: %v", err) + return + } + + // Remove any exit nodes associated with this peer from hole punching + if o.holePunchManager != nil { + removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) + if removed > 0 { + logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) + } + } + + logger.Info("Successfully removed peer for site %d", removeData.SiteId) +} + +func (o *Olm) 
handleWgPeerUpdate(msg websocket.WSMessage) { + logger.Debug("Received update-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateData peers.SiteConfig + if err := json.Unmarshal(jsonData, &updateData); err != nil { + logger.Error("Error unmarshaling update data: %v", err) + return + } + + // Get existing peer from PeerManager + existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId) + if !exists { + logger.Warn("Peer with site ID %d not found", updateData.SiteId) + return + } + + // Create updated site config by merging with existing data + siteConfig := existingPeer + + if updateData.Endpoint != "" { + siteConfig.Endpoint = updateData.Endpoint + } + if updateData.RelayEndpoint != "" { + siteConfig.RelayEndpoint = updateData.RelayEndpoint + } + if updateData.PublicKey != "" { + siteConfig.PublicKey = updateData.PublicKey + } + if updateData.ServerIP != "" { + siteConfig.ServerIP = updateData.ServerIP + } + if updateData.ServerPort != 0 { + siteConfig.ServerPort = updateData.ServerPort + } + if updateData.RemoteSubnets != nil { + siteConfig.RemoteSubnets = updateData.RemoteSubnets + } + + if err := o.peerManager.UpdatePeer(siteConfig); err != nil { + logger.Error("Failed to update peer: %v", err) + return + } + + // If the endpoint changed, trigger holepunch to refresh NAT mappings + if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { + logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) + _ = o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetInterval() + } + + logger.Info("Successfully updated peer for site %d", updateData.SiteId) +} + +func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) { + logger.Debug("Received relay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if o.peerManager == 
nil { + logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData peers.RelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling relay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) + if err != nil { + logger.Error("Failed to resolve primary relay endpoint: %v", err) + return + } + + // Update HTTP server to mark this peer as using relay + o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) + + o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort) +} + +func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { + logger.Debug("Received unrelay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if o.peerManager == nil { + logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData peers.UnRelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling relay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.Endpoint) + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + // Update HTTP server to mark this peer as no longer using relay + o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false) + + o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) +} diff --git a/olm/util.go b/olm/ping.go similarity index 89% rename from olm/util.go rename to olm/ping.go index 6bfd171..bbeee9a 100644 --- a/olm/util.go +++ b/olm/ping.go @@ -9,7 +9,7 @@ import ( ) func sendPing(olm
*websocket.Client) error { - err := olm.SendMessage("olm/ping", map[string]interface{}{ + err := olm.SendMessage("olm/ping", map[string]any{ "timestamp": time.Now().Unix(), "userToken": olm.GetConfig().UserToken, }) @@ -21,7 +21,7 @@ func sendPing(olm *websocket.Client) error { return nil } -func keepSendingPing(olm *websocket.Client) { +func (o *Olm) keepSendingPing(olm *websocket.Client) { // Send ping immediately on startup if err := sendPing(olm); err != nil { logger.Error("Failed to send initial ping: %v", err) @@ -35,7 +35,7 @@ func keepSendingPing(olm *websocket.Client) { for { select { - case <-stopPing: + case <-o.stopPing: logger.Info("Stopping ping messages") return case <-ticker.C: diff --git a/olm/types.go b/olm/types.go index 9187860..77c0b5f 100644 --- a/olm/types.go +++ b/olm/types.go @@ -12,7 +12,7 @@ type WgData struct { UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses } -type GlobalConfig struct { +type OlmConfig struct { // Logging LogLevel string LogFilePath string