mirror of
https://github.com/fosrl/olm.git
synced 2026-02-24 22:06:42 +00:00
Allow connecting and disconnecting
This commit is contained in:
41
api/api.go
41
api/api.go
@@ -13,9 +13,10 @@ import (
|
|||||||
|
|
||||||
// ConnectionRequest defines the structure for an incoming connection request
|
// ConnectionRequest defines the structure for an incoming connection request
|
||||||
type ConnectionRequest struct {
|
type ConnectionRequest struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Secret string `json:"secret"`
|
Secret string `json:"secret"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
|
UserToken string `json:"userToken,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SwitchOrgRequest defines the structure for switching organizations
|
// SwitchOrgRequest defines the structure for switching organizations
|
||||||
@@ -53,6 +54,7 @@ type API struct {
|
|||||||
connectionChan chan ConnectionRequest
|
connectionChan chan ConnectionRequest
|
||||||
switchOrgChan chan SwitchOrgRequest
|
switchOrgChan chan SwitchOrgRequest
|
||||||
shutdownChan chan struct{}
|
shutdownChan chan struct{}
|
||||||
|
disconnectChan chan struct{}
|
||||||
statusMu sync.RWMutex
|
statusMu sync.RWMutex
|
||||||
peerStatuses map[int]*PeerStatus
|
peerStatuses map[int]*PeerStatus
|
||||||
connectedAt time.Time
|
connectedAt time.Time
|
||||||
@@ -70,6 +72,7 @@ func NewAPI(addr string) *API {
|
|||||||
connectionChan: make(chan ConnectionRequest, 1),
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
shutdownChan: make(chan struct{}, 1),
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
disconnectChan: make(chan struct{}, 1),
|
||||||
peerStatuses: make(map[int]*PeerStatus),
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,6 +86,7 @@ func NewAPISocket(socketPath string) *API {
|
|||||||
connectionChan: make(chan ConnectionRequest, 1),
|
connectionChan: make(chan ConnectionRequest, 1),
|
||||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||||
shutdownChan: make(chan struct{}, 1),
|
shutdownChan: make(chan struct{}, 1),
|
||||||
|
disconnectChan: make(chan struct{}, 1),
|
||||||
peerStatuses: make(map[int]*PeerStatus),
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,6 +99,7 @@ func (s *API) Start() error {
|
|||||||
mux.HandleFunc("/connect", s.handleConnect)
|
mux.HandleFunc("/connect", s.handleConnect)
|
||||||
mux.HandleFunc("/status", s.handleStatus)
|
mux.HandleFunc("/status", s.handleStatus)
|
||||||
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||||
|
mux.HandleFunc("/disconnect", s.handleDisconnect)
|
||||||
mux.HandleFunc("/exit", s.handleExit)
|
mux.HandleFunc("/exit", s.handleExit)
|
||||||
|
|
||||||
s.server = &http.Server{
|
s.server = &http.Server{
|
||||||
@@ -159,6 +164,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
|
|||||||
return s.shutdownChan
|
return s.shutdownChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDisconnectChannel returns the channel for receiving disconnect requests
|
||||||
|
func (s *API) GetDisconnectChannel() <-chan struct{} {
|
||||||
|
return s.disconnectChan
|
||||||
|
}
|
||||||
|
|
||||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||||
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||||
s.statusMu.Lock()
|
s.statusMu.Lock()
|
||||||
@@ -356,3 +366,28 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
|||||||
"status": "org switch request accepted",
|
"status": "org switch request accepted",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleDisconnect handles the /disconnect endpoint
|
||||||
|
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Received disconnect request via API")
|
||||||
|
|
||||||
|
// Send disconnect signal
|
||||||
|
select {
|
||||||
|
case s.disconnectChan <- struct{}{}:
|
||||||
|
// Signal sent successfully
|
||||||
|
default:
|
||||||
|
// Channel already has a signal, don't block
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a success response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "disconnect initiated",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
457
olm/olm.go
457
olm/olm.go
@@ -3,7 +3,6 @@ package olm
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -39,10 +38,6 @@ type Config struct {
|
|||||||
HTTPAddr string
|
HTTPAddr string
|
||||||
SocketPath string
|
SocketPath string
|
||||||
|
|
||||||
// Ping settings
|
|
||||||
PingInterval string
|
|
||||||
PingTimeout string
|
|
||||||
|
|
||||||
// Advanced
|
// Advanced
|
||||||
Holepunch bool
|
Holepunch bool
|
||||||
TlsClientCert string
|
TlsClientCert string
|
||||||
@@ -58,133 +53,175 @@ type Config struct {
|
|||||||
OrgID string
|
OrgID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
privateKey wgtypes.Key
|
||||||
|
connected bool
|
||||||
|
dev *device.Device
|
||||||
|
wgData WgData
|
||||||
|
holePunchData HolePunchData
|
||||||
|
uapiListener net.Listener
|
||||||
|
tdev tun.Device
|
||||||
|
apiServer *api.API
|
||||||
|
olmClient *websocket.Client
|
||||||
|
tunnelCancel context.CancelFunc
|
||||||
|
)
|
||||||
|
|
||||||
func Run(ctx context.Context, config Config) {
|
func Run(ctx context.Context, config Config) {
|
||||||
// Create a cancellable context for internal shutdown control
|
// Create a cancellable context for internal shutdown control
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Extract commonly used values from config for convenience
|
logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel))
|
||||||
var (
|
|
||||||
endpoint = config.Endpoint
|
|
||||||
id = config.ID
|
|
||||||
secret = config.Secret
|
|
||||||
mtu = config.MTU
|
|
||||||
logLevel = config.LogLevel
|
|
||||||
interfaceName = config.InterfaceName
|
|
||||||
pingInterval = config.PingIntervalDuration
|
|
||||||
pingTimeout = config.PingTimeoutDuration
|
|
||||||
doHolepunch = config.Holepunch
|
|
||||||
privateKey wgtypes.Key
|
|
||||||
connected bool
|
|
||||||
dev *device.Device
|
|
||||||
wgData WgData
|
|
||||||
holePunchData HolePunchData
|
|
||||||
uapiListener net.Listener
|
|
||||||
tdev tun.Device
|
|
||||||
)
|
|
||||||
|
|
||||||
stopHolepunch = make(chan struct{})
|
|
||||||
stopPing = make(chan struct{})
|
|
||||||
|
|
||||||
loggerLevel := parseLogLevel(logLevel)
|
|
||||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
|
||||||
|
|
||||||
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
|
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
|
||||||
logger.Debug("Failed to check for updates: %v", err)
|
logger.Debug("Failed to check for updates: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log startup information
|
if config.Holepunch {
|
||||||
logger.Debug("Olm service starting...")
|
|
||||||
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
|
|
||||||
|
|
||||||
if doHolepunch {
|
|
||||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiServer *api.API
|
if config.HTTPAddr != "" {
|
||||||
if config.EnableAPI {
|
apiServer = api.NewAPI(config.HTTPAddr)
|
||||||
if config.HTTPAddr != "" {
|
} else if config.SocketPath != "" {
|
||||||
apiServer = api.NewAPI(config.HTTPAddr)
|
apiServer = api.NewAPISocket(config.SocketPath)
|
||||||
} else if config.SocketPath != "" {
|
|
||||||
apiServer = api.NewAPISocket(config.SocketPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
apiServer.SetVersion(config.Version)
|
|
||||||
apiServer.SetOrgID(config.OrgID)
|
|
||||||
if err := apiServer.Start(); err != nil {
|
|
||||||
logger.Fatal("Failed to start HTTP server: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Listen for shutdown requests from the API
|
|
||||||
go func() {
|
|
||||||
<-apiServer.GetShutdownChannel()
|
|
||||||
logger.Info("Shutdown requested via API")
|
|
||||||
// Cancel the context to trigger graceful shutdown
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// // Use a goroutine to handle connection requests
|
apiServer.SetVersion(config.Version)
|
||||||
// go func() {
|
apiServer.SetOrgID(config.OrgID)
|
||||||
// for req := range apiServer.GetConnectionChannel() {
|
|
||||||
// logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
|
||||||
|
|
||||||
// // Set the connection parameters
|
if err := apiServer.Start(); err != nil {
|
||||||
// id = req.ID
|
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||||
// secret = req.Secret
|
}
|
||||||
// endpoint = req.Endpoint
|
|
||||||
// }
|
|
||||||
// }()
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Create a new olm
|
// Listen for shutdown requests from the API
|
||||||
olm, err := websocket.NewClient(
|
go func() {
|
||||||
"olm",
|
<-apiServer.GetShutdownChannel()
|
||||||
id, // CLI arg takes precedence
|
logger.Info("Shutdown requested via API")
|
||||||
secret, // CLI arg takes precedence
|
// Cancel the context to trigger graceful shutdown
|
||||||
endpoint,
|
cancel()
|
||||||
pingInterval,
|
}()
|
||||||
pingTimeout,
|
|
||||||
|
var (
|
||||||
|
id = config.ID
|
||||||
|
secret = config.Secret
|
||||||
|
endpoint = config.Endpoint
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to create olm: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until we have a client id and secret and endpoint
|
// Main event loop that handles connect, disconnect, and reconnect
|
||||||
waitCount := 0
|
for {
|
||||||
for id == "" || secret == "" || endpoint == "" {
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
logger.Info("Context cancelled while waiting for credentials")
|
logger.Info("Context cancelled while waiting for credentials")
|
||||||
return
|
goto shutdown
|
||||||
|
|
||||||
|
case req := <-apiServer.GetConnectionChannel():
|
||||||
|
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
||||||
|
|
||||||
|
// Stop any existing tunnel before starting a new one
|
||||||
|
if olmClient != nil {
|
||||||
|
logger.Info("Stopping existing tunnel before starting new connection")
|
||||||
|
StopTunnel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the connection parameters
|
||||||
|
id = req.ID
|
||||||
|
secret = req.Secret
|
||||||
|
endpoint = req.Endpoint
|
||||||
|
|
||||||
|
// Start the tunnel process with the new credentials
|
||||||
|
if id != "" && secret != "" && endpoint != "" {
|
||||||
|
logger.Info("Starting tunnel with new credentials")
|
||||||
|
go TunnelProcess(ctx, config, id, secret, endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-apiServer.GetDisconnectChannel():
|
||||||
|
logger.Info("Received disconnect request via API")
|
||||||
|
StopTunnel()
|
||||||
|
// Clear credentials so we wait for new connect call
|
||||||
|
id = ""
|
||||||
|
secret = ""
|
||||||
|
endpoint = ""
|
||||||
|
|
||||||
default:
|
default:
|
||||||
missing := []string{}
|
// If we have credentials and no tunnel is running, start it
|
||||||
if id == "" {
|
if id != "" && secret != "" && endpoint != "" && olmClient == nil {
|
||||||
missing = append(missing, "id")
|
logger.Info("Starting tunnel process with initial credentials")
|
||||||
|
go TunnelProcess(ctx, config, id, secret, endpoint)
|
||||||
|
} else if id == "" || secret == "" || endpoint == "" {
|
||||||
|
// If we don't have credentials, check if API is enabled
|
||||||
|
if !config.EnableAPI {
|
||||||
|
missing := []string{}
|
||||||
|
if id == "" {
|
||||||
|
missing = append(missing, "id")
|
||||||
|
}
|
||||||
|
if secret == "" {
|
||||||
|
missing = append(missing, "secret")
|
||||||
|
}
|
||||||
|
if endpoint == "" {
|
||||||
|
missing = append(missing, "endpoint")
|
||||||
|
}
|
||||||
|
// exit the application because there is no way to provide the missing parameters
|
||||||
|
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
|
||||||
|
goto shutdown
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if secret == "" {
|
|
||||||
missing = append(missing, "secret")
|
// Sleep briefly to prevent tight loop
|
||||||
}
|
time.Sleep(100 * time.Millisecond)
|
||||||
if endpoint == "" {
|
|
||||||
missing = append(missing, "endpoint")
|
|
||||||
}
|
|
||||||
waitCount++
|
|
||||||
if waitCount%10 == 1 { // Log every 10 seconds instead of every second
|
|
||||||
logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount)
|
|
||||||
}
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
shutdown:
|
||||||
|
Stop()
|
||||||
|
apiServer.Stop()
|
||||||
|
logger.Info("Olm service shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TunnelProcess(ctx context.Context, config Config, id string, secret string, endpoint string) {
|
||||||
|
// Create a cancellable context for this tunnel process
|
||||||
|
tunnelCtx, cancel := context.WithCancel(ctx)
|
||||||
|
tunnelCancel = cancel
|
||||||
|
defer func() {
|
||||||
|
tunnelCancel = nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Recreate channels for this tunnel session
|
||||||
|
stopHolepunch = make(chan struct{})
|
||||||
|
stopPing = make(chan struct{})
|
||||||
|
|
||||||
|
var (
|
||||||
|
interfaceName = config.InterfaceName
|
||||||
|
loggerLevel = parseLogLevel(config.LogLevel)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create a new olm client using the provided credentials
|
||||||
|
olm, err := websocket.NewClient(
|
||||||
|
"olm",
|
||||||
|
id, // Use provided ID
|
||||||
|
secret, // Use provided secret
|
||||||
|
endpoint, // Use provided endpoint
|
||||||
|
config.PingIntervalDuration,
|
||||||
|
config.PingTimeoutDuration,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create olm: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the client reference globally
|
||||||
|
olmClient = olm
|
||||||
|
|
||||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
logger.Error("Failed to generate private key: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sourcePort, err := FindAvailableUDPPort(49152, 65535)
|
sourcePort, err := FindAvailableUDPPort(49152, 65535)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error finding available port: %v\n", err)
|
logger.Error("Error finding available port: %v", err)
|
||||||
os.Exit(1)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
||||||
@@ -289,12 +326,12 @@ func Run(ctx context.Context, config Config) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return tun.CreateTUN(interfaceName, mtu)
|
return tun.CreateTUN(interfaceName, config.MTU)
|
||||||
}
|
}
|
||||||
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
|
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
|
||||||
return createTUNFromFD(tunFdStr, mtu)
|
return createTUNFromFD(tunFdStr, config.MTU)
|
||||||
}
|
}
|
||||||
return tun.CreateTUN(interfaceName, mtu)
|
return tun.CreateTUN(interfaceName, config.MTU)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -347,27 +384,23 @@ func Run(ctx context.Context, config Config) {
|
|||||||
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||||
logger.Error("Failed to configure interface: %v", err)
|
logger.Error("Failed to configure interface: %v", err)
|
||||||
}
|
}
|
||||||
if apiServer != nil {
|
apiServer.SetTunnelIP(wgData.TunnelIP)
|
||||||
apiServer.SetTunnelIP(wgData.TunnelIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerMonitor = peermonitor.NewPeerMonitor(
|
peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
func(siteID int, connected bool, rtt time.Duration) {
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
if apiServer != nil {
|
// Find the site config to get endpoint information
|
||||||
// Find the site config to get endpoint information
|
var endpoint string
|
||||||
var endpoint string
|
var isRelay bool
|
||||||
var isRelay bool
|
for _, site := range wgData.Sites {
|
||||||
for _, site := range wgData.Sites {
|
if site.SiteId == siteID {
|
||||||
if site.SiteId == siteID {
|
endpoint = site.Endpoint
|
||||||
endpoint = site.Endpoint
|
// TODO: We'll need to track relay status separately
|
||||||
// TODO: We'll need to track relay status separately
|
// For now, assume not using relay unless we get relay data
|
||||||
// For now, assume not using relay unless we get relay data
|
isRelay = !config.Holepunch
|
||||||
isRelay = !doHolepunch
|
break
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
|
|
||||||
}
|
}
|
||||||
|
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
|
||||||
if connected {
|
if connected {
|
||||||
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
|
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
|
||||||
} else {
|
} else {
|
||||||
@@ -377,14 +410,12 @@ func Run(ctx context.Context, config Config) {
|
|||||||
fixKey(privateKey.String()),
|
fixKey(privateKey.String()),
|
||||||
olm,
|
olm,
|
||||||
dev,
|
dev,
|
||||||
doHolepunch,
|
config.Holepunch,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i := range wgData.Sites {
|
for i := range wgData.Sites {
|
||||||
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
|
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
|
||||||
if apiServer != nil {
|
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
|
||||||
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Format the endpoint before configuring the peer.
|
// Format the endpoint before configuring the peer.
|
||||||
site.Endpoint = formatEndpoint(site.Endpoint)
|
site.Endpoint = formatEndpoint(site.Endpoint)
|
||||||
@@ -407,9 +438,7 @@ func Run(ctx context.Context, config Config) {
|
|||||||
|
|
||||||
peerMonitor.Start()
|
peerMonitor.Start()
|
||||||
|
|
||||||
if apiServer != nil {
|
apiServer.SetRegistered(true)
|
||||||
apiServer.SetRegistered(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
connected = true
|
connected = true
|
||||||
|
|
||||||
@@ -637,9 +666,7 @@ func Run(ctx context.Context, config Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update HTTP server to mark this peer as using relay
|
// Update HTTP server to mark this peer as using relay
|
||||||
if apiServer != nil {
|
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
||||||
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||||
})
|
})
|
||||||
@@ -670,9 +697,7 @@ func Run(ctx context.Context, config Config) {
|
|||||||
olm.OnConnect(func() error {
|
olm.OnConnect(func() error {
|
||||||
logger.Info("Websocket Connected")
|
logger.Info("Websocket Connected")
|
||||||
|
|
||||||
if apiServer != nil {
|
apiServer.SetConnectionStatus(true)
|
||||||
apiServer.SetConnectionStatus(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
if connected {
|
if connected {
|
||||||
logger.Debug("Already connected, skipping registration")
|
logger.Debug("Already connected, skipping registration")
|
||||||
@@ -681,11 +706,11 @@ func Run(ctx context.Context, config Config) {
|
|||||||
|
|
||||||
publicKey := privateKey.PublicKey()
|
publicKey := privateKey.PublicKey()
|
||||||
|
|
||||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
||||||
|
|
||||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"relay": !doHolepunch,
|
"relay": !config.Holepunch,
|
||||||
"olmVersion": config.Version,
|
"olmVersion": config.Version,
|
||||||
"orgId": config.OrgID,
|
"orgId": config.OrgID,
|
||||||
}, 1*time.Second)
|
}, 1*time.Second)
|
||||||
@@ -700,89 +725,50 @@ func Run(ctx context.Context, config Config) {
|
|||||||
olmToken = token
|
olmToken = token
|
||||||
})
|
})
|
||||||
|
|
||||||
// Listen for org switch requests from the API
|
|
||||||
if apiServer != nil {
|
|
||||||
go func() {
|
|
||||||
for req := range apiServer.GetSwitchOrgChannel() {
|
|
||||||
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
|
||||||
|
|
||||||
// Update the config with the new orgId
|
|
||||||
config.OrgID = req.OrgID
|
|
||||||
|
|
||||||
// Mark as not connected to trigger re-registration
|
|
||||||
connected = false
|
|
||||||
|
|
||||||
// Stop registration if running
|
|
||||||
if stopRegister != nil {
|
|
||||||
stopRegister()
|
|
||||||
stopRegister = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop hole punching
|
|
||||||
select {
|
|
||||||
case <-stopHolepunch:
|
|
||||||
// Already closed
|
|
||||||
default:
|
|
||||||
close(stopHolepunch)
|
|
||||||
}
|
|
||||||
stopHolepunch = make(chan struct{})
|
|
||||||
|
|
||||||
// Stop peer monitor
|
|
||||||
if peerMonitor != nil {
|
|
||||||
peerMonitor.Stop()
|
|
||||||
peerMonitor = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the WireGuard device
|
|
||||||
if dev != nil {
|
|
||||||
logger.Info("Closing existing WireGuard device for org switch")
|
|
||||||
dev.Close()
|
|
||||||
dev = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close UAPI listener
|
|
||||||
if uapiListener != nil {
|
|
||||||
uapiListener.Close()
|
|
||||||
uapiListener = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close TUN device
|
|
||||||
if tdev != nil {
|
|
||||||
tdev.Close()
|
|
||||||
tdev = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear peer statuses in API
|
|
||||||
if apiServer != nil {
|
|
||||||
apiServer.SetRegistered(false)
|
|
||||||
apiServer.SetTunnelIP("")
|
|
||||||
apiServer.SetOrgID(config.OrgID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trigger re-registration with new orgId
|
|
||||||
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
|
||||||
publicKey := privateKey.PublicKey()
|
|
||||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
||||||
"publicKey": publicKey.String(),
|
|
||||||
"relay": !doHolepunch,
|
|
||||||
"olmVersion": config.Version,
|
|
||||||
"orgId": config.OrgID,
|
|
||||||
}, 1*time.Second)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to the WebSocket server
|
// Connect to the WebSocket server
|
||||||
if err := olm.Connect(); err != nil {
|
if err := olm.Connect(); err != nil {
|
||||||
logger.Fatal("Failed to connect to server: %v", err)
|
logger.Error("Failed to connect to server: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
defer olm.Close()
|
defer olm.Close()
|
||||||
|
|
||||||
select {
|
// Listen for org switch requests from the API
|
||||||
case <-ctx.Done():
|
go func() {
|
||||||
logger.Info("Context cancelled")
|
for req := range apiServer.GetSwitchOrgChannel() {
|
||||||
}
|
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
||||||
|
|
||||||
|
// Update the config with the new orgId
|
||||||
|
config.OrgID = req.OrgID
|
||||||
|
|
||||||
|
// Mark as not connected to trigger re-registration
|
||||||
|
connected = false
|
||||||
|
|
||||||
|
Stop()
|
||||||
|
|
||||||
|
// Clear peer statuses in API
|
||||||
|
apiServer.SetRegistered(false)
|
||||||
|
apiServer.SetTunnelIP("")
|
||||||
|
apiServer.SetOrgID(config.OrgID)
|
||||||
|
|
||||||
|
stopHolepunch = make(chan struct{})
|
||||||
|
// Trigger re-registration with new orgId
|
||||||
|
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
||||||
|
publicKey := privateKey.PublicKey()
|
||||||
|
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": publicKey.String(),
|
||||||
|
"relay": !config.Holepunch,
|
||||||
|
"olmVersion": config.Version,
|
||||||
|
"orgId": config.OrgID,
|
||||||
|
}, 1*time.Second)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for context cancellation
|
||||||
|
<-tunnelCtx.Done()
|
||||||
|
logger.Info("Tunnel process context cancelled, cleaning up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Stop() {
|
||||||
select {
|
select {
|
||||||
case <-stopHolepunch:
|
case <-stopHolepunch:
|
||||||
// Channel already closed, do nothing
|
// Channel already closed, do nothing
|
||||||
@@ -790,11 +776,6 @@ func Run(ctx context.Context, config Config) {
|
|||||||
close(stopHolepunch)
|
close(stopHolepunch)
|
||||||
}
|
}
|
||||||
|
|
||||||
if stopRegister != nil {
|
|
||||||
stopRegister()
|
|
||||||
stopRegister = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-stopPing:
|
case <-stopPing:
|
||||||
// Channel already closed
|
// Channel already closed
|
||||||
@@ -802,20 +783,60 @@ func Run(ctx context.Context, config Config) {
|
|||||||
close(stopPing)
|
close(stopPing)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if stopRegister != nil {
|
||||||
|
stopRegister()
|
||||||
|
stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
if peerMonitor != nil {
|
if peerMonitor != nil {
|
||||||
peerMonitor.Stop()
|
peerMonitor.Stop()
|
||||||
|
peerMonitor = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if uapiListener != nil {
|
if uapiListener != nil {
|
||||||
uapiListener.Close()
|
uapiListener.Close()
|
||||||
|
uapiListener = nil
|
||||||
}
|
}
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
dev.Close()
|
dev.Close()
|
||||||
|
dev = nil
|
||||||
}
|
}
|
||||||
|
// Close TUN device
|
||||||
if apiServer != nil {
|
if tdev != nil {
|
||||||
apiServer.Stop()
|
tdev.Close()
|
||||||
|
tdev = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Olm service stopped")
|
logger.Info("Olm service stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StopTunnel stops just the tunnel process and websocket connection
|
||||||
|
// without shutting down the entire application
|
||||||
|
func StopTunnel() {
|
||||||
|
logger.Info("Stopping tunnel process")
|
||||||
|
|
||||||
|
// Cancel the tunnel context if it exists
|
||||||
|
if tunnelCancel != nil {
|
||||||
|
tunnelCancel()
|
||||||
|
// Give it a moment to clean up
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the websocket connection
|
||||||
|
if olmClient != nil {
|
||||||
|
olmClient.Close()
|
||||||
|
olmClient = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
Stop()
|
||||||
|
|
||||||
|
// Reset the connected state
|
||||||
|
connected = false
|
||||||
|
|
||||||
|
// Update API server status
|
||||||
|
apiServer.SetConnectionStatus(false)
|
||||||
|
apiServer.SetRegistered(false)
|
||||||
|
apiServer.SetTunnelIP("")
|
||||||
|
|
||||||
|
logger.Info("Tunnel process stopped")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user