Split up concerns so parent can call start and stop

This commit is contained in:
Owen
2025-11-18 18:14:21 -05:00
parent e7be7fb281
commit 8f97c43b63
4 changed files with 246 additions and 222 deletions

View File

@@ -13,10 +13,18 @@ import (
// ConnectionRequest defines the structure for an incoming connection request // ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct { type ConnectionRequest struct {
ID string `json:"id"` ID string `json:"id"`
Secret string `json:"secret"` Secret string `json:"secret"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"` UserToken string `json:"userToken,omitempty"`
MTU int `json:"mtu,omitempty"`
DNS string `json:"dns,omitempty"`
InterfaceName string `json:"interfaceName,omitempty"`
Holepunch bool `json:"holepunch,omitempty"`
TlsClientCert string `json:"tlsClientCert,omitempty"`
PingInterval string `json:"pingInterval,omitempty"`
PingTimeout string `json:"pingTimeout,omitempty"`
OrgID string `json:"orgId,omitempty"`
} }
// SwitchOrgRequest defines the structure for switching organizations // SwitchOrgRequest defines the structure for switching organizations
@@ -47,33 +55,29 @@ type StatusResponse struct {
// API represents the HTTP server and its state // API represents the HTTP server and its state
type API struct { type API struct {
addr string addr string
socketPath string socketPath string
listener net.Listener listener net.Listener
server *http.Server server *http.Server
connectionChan chan ConnectionRequest onConnect func(ConnectionRequest) error
switchOrgChan chan SwitchOrgRequest onSwitchOrg func(SwitchOrgRequest) error
shutdownChan chan struct{} onDisconnect func() error
disconnectChan chan struct{} onExit func() error
statusMu sync.RWMutex statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus peerStatuses map[int]*PeerStatus
connectedAt time.Time connectedAt time.Time
isConnected bool isConnected bool
isRegistered bool isRegistered bool
tunnelIP string tunnelIP string
version string version string
orgID string orgID string
} }
// NewAPI creates a new HTTP server that listens on a TCP address // NewAPI creates a new HTTP server that listens on a TCP address
func NewAPI(addr string) *API { func NewAPI(addr string) *API {
s := &API{ s := &API{
addr: addr, addr: addr,
connectionChan: make(chan ConnectionRequest, 1), peerStatuses: make(map[int]*PeerStatus),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
} }
return s return s
@@ -82,17 +86,26 @@ func NewAPI(addr string) *API {
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe // NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
func NewAPISocket(socketPath string) *API { func NewAPISocket(socketPath string) *API {
s := &API{ s := &API{
socketPath: socketPath, socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1), peerStatuses: make(map[int]*PeerStatus),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
} }
return s return s
} }
// SetHandlers sets the callback functions for handling API requests
func (s *API) SetHandlers(
onConnect func(ConnectionRequest) error,
onSwitchOrg func(SwitchOrgRequest) error,
onDisconnect func() error,
onExit func() error,
) {
s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg
s.onDisconnect = onDisconnect
s.onExit = onExit
}
// Start starts the HTTP server // Start starts the HTTP server
func (s *API) Start() error { func (s *API) Start() error {
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -149,26 +162,6 @@ func (s *API) Stop() error {
return nil return nil
} }
// GetConnectionChannel returns the channel for receiving connection requests
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// GetSwitchOrgChannel returns the channel for receiving org switch requests
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
return s.switchOrgChan
}
// GetShutdownChannel returns the channel for receiving shutdown requests
func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
// GetDisconnectChannel returns the channel for receiving disconnect requests
func (s *API) GetDisconnectChannel() <-chan struct{} {
return s.disconnectChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info // UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock() s.statusMu.Lock()
@@ -277,8 +270,13 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
return return
} }
// Send the request to the main goroutine // Call the connect handler if set
s.connectionChan <- req if s.onConnect != nil {
if err := s.onConnect(req); err != nil {
http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError)
return
}
}
// Return a success response // Return a success response
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@@ -320,12 +318,12 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
logger.Info("Received exit request via API") logger.Info("Received exit request via API")
// Send shutdown signal // Call the exit handler if set
select { if s.onExit != nil {
case s.shutdownChan <- struct{}{}: if err := s.onExit(); err != nil {
// Signal sent successfully http.Error(w, fmt.Sprintf("Exit failed: %v", err), http.StatusInternalServerError)
default: return
// Channel already has a signal, don't block }
} }
// Return a success response // Return a success response
@@ -358,14 +356,12 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
logger.Info("Received org switch request to orgId: %s", req.OrgID) logger.Info("Received org switch request to orgId: %s", req.OrgID)
// Send the request to the main goroutine // Call the switch org handler if set
select { if s.onSwitchOrg != nil {
case s.switchOrgChan <- req: if err := s.onSwitchOrg(req); err != nil {
// Signal sent successfully http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError)
default: return
// Channel already has a pending request }
http.Error(w, "Org switch already in progress", http.StatusConflict)
return
} }
// Return a success response // Return a success response
@@ -394,12 +390,12 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
logger.Info("Received disconnect request via API") logger.Info("Received disconnect request via API")
// Send disconnect signal // Call the disconnect handler if set
select { if s.onDisconnect != nil {
case s.disconnectChan <- struct{}{}: if err := s.onDisconnect(); err != nil {
// Signal sent successfully http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError)
default: return
// Channel already has a signal, don't block }
} }
// Return a success response // Return a success response

53
main.go
View File

@@ -205,26 +205,41 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
} }
// Create a new olm.Config struct and copy values from the main config // Create a new olm.Config struct and copy values from the main config
olmConfig := olm.Config{ olmConfig := olm.GlobalConfig{
Endpoint: config.Endpoint, LogLevel: config.LogLevel,
ID: config.ID, EnableAPI: config.EnableAPI,
Secret: config.Secret, HTTPAddr: config.HTTPAddr,
UserToken: config.UserToken, SocketPath: config.SocketPath,
MTU: config.MTU, Version: config.Version,
DNS: config.DNS,
InterfaceName: config.InterfaceName,
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
SocketPath: config.SocketPath,
Holepunch: config.Holepunch,
TlsClientCert: config.TlsClientCert,
PingIntervalDuration: config.PingIntervalDuration,
PingTimeoutDuration: config.PingTimeoutDuration,
Version: config.Version,
OrgID: config.OrgID,
// DoNotCreateNewClient: config.DoNotCreateNewClient,
} }
olm.Init(ctx, olmConfig) olm.Init(ctx, olmConfig)
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
tunnelConfig := olm.TunnelConfig{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,
UserToken: config.UserToken,
MTU: config.MTU,
DNS: config.DNS,
InterfaceName: config.InterfaceName,
Holepunch: config.Holepunch,
TlsClientCert: config.TlsClientCert,
PingIntervalDuration: config.PingIntervalDuration,
PingTimeoutDuration: config.PingTimeoutDuration,
OrgID: config.OrgID,
}
go olm.StartTunnel(tunnelConfig)
} else {
logger.Info("Incomplete tunnel configuration, not starting tunnel")
}
// Wait for context cancellation (from signals or API shutdown)
<-ctx.Done()
logger.Info("Shutdown signal received, cleaning up...")
// Clean up resources
olm.Close()
logger.Info("Shutdown complete")
} }

View File

@@ -15,7 +15,7 @@ import (
) )
// ConfigureInterface configures a network interface with an IP address and brings it up // ConfigureInterface configures a network interface with an IP address and brings it up
func ConfigureInterface(interfaceName string, wgData WgData) error { func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
logger.Info("The tunnel IP is: %s", wgData.TunnelIP) logger.Info("The tunnel IP is: %s", wgData.TunnelIP)
// Parse the IP address and network // Parse the IP address and network
@@ -32,6 +32,7 @@ func ConfigureInterface(interfaceName string, wgData WgData) error {
// network.SetTunnelRemoteAddress() // what does this do? // network.SetTunnelRemoteAddress() // what does this do?
network.SetIPv4Settings([]string{destinationAddress}, []string{mask}) network.SetIPv4Settings([]string{destinationAddress}, []string{mask})
network.SetMTU(mtu)
apiServer.SetTunnelIP(destinationAddress) apiServer.SetTunnelIP(destinationAddress)
if interfaceName == "" { if interfaceName == "" {

View File

@@ -3,6 +3,7 @@ package olm
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"runtime" "runtime"
"time" "time"
@@ -20,7 +21,21 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type Config struct { type GlobalConfig struct {
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
Version string
// Source tracking (not in JSON)
sources map[string]string
}
type TunnelConfig struct {
// Connection settings // Connection settings
Endpoint string Endpoint string
ID string ID string
@@ -32,14 +47,6 @@ type Config struct {
DNS string DNS string
InterfaceName string InterfaceName string
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
// Advanced // Advanced
Holepunch bool Holepunch bool
TlsClientCert string TlsClientCert string
@@ -48,11 +55,7 @@ type Config struct {
PingIntervalDuration time.Duration PingIntervalDuration time.Duration
PingTimeoutDuration time.Duration PingTimeoutDuration time.Duration
// Source tracking (not in JSON) OrgID string
sources map[string]string
Version string
OrgID string
// DoNotCreateNewClient bool // DoNotCreateNewClient bool
FileDescriptorTun uint32 FileDescriptorTun uint32
@@ -74,21 +77,21 @@ var (
sharedBind *bind.SharedBind sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager holePunchManager *holepunch.Manager
peerMonitor *peermonitor.PeerMonitor peerMonitor *peermonitor.PeerMonitor
globalConfig GlobalConfig
globalCtx context.Context
stopRegister func() stopRegister func()
stopPing chan struct{} stopPing chan struct{}
) )
func Init(ctx context.Context, config Config) { func Init(ctx context.Context, config GlobalConfig) {
globalConfig = config
globalCtx = ctx
// Create a cancellable context for internal shutdown control // Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
network.SetMTU(config.MTU)
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
if config.HTTPAddr != "" { if config.HTTPAddr != "" {
apiServer = api.NewAPI(config.HTTPAddr) apiServer = api.NewAPI(config.HTTPAddr)
@@ -97,35 +100,15 @@ func Init(ctx context.Context, config Config) {
} }
apiServer.SetVersion(config.Version) apiServer.SetVersion(config.Version)
apiServer.SetOrgID(config.OrgID)
if err := apiServer.Start(); err != nil { if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err) logger.Fatal("Failed to start HTTP server: %v", err)
} }
// Listen for shutdown requests from the API // Set up API handlers
go func() { apiServer.SetHandlers(
<-apiServer.GetShutdownChannel() // onConnect
logger.Info("Shutdown requested via API") func(req api.ConnectionRequest) error {
// Cancel the context to trigger graceful shutdown
cancel()
}()
var (
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
userToken = config.UserToken
)
// Main event loop that handles connect, disconnect, and reconnect
for {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
goto shutdown
case req := <-apiServer.GetConnectionChannel():
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Stop any existing tunnel before starting a new one // Stop any existing tunnel before starting a new one
@@ -134,67 +117,120 @@ func Init(ctx context.Context, config Config) {
StopTunnel() StopTunnel()
} }
// Set the connection parameters tunnelConfig := TunnelConfig{
id = req.ID Endpoint: req.Endpoint,
secret = req.Secret ID: req.ID,
endpoint = req.Endpoint Secret: req.Secret,
userToken := req.UserToken UserToken: req.UserToken,
MTU: req.MTU,
DNS: req.DNS,
InterfaceName: req.InterfaceName,
Holepunch: req.Holepunch,
TlsClientCert: req.TlsClientCert,
OrgID: req.OrgID,
}
var err error
// Parse ping interval
if req.PingInterval != "" {
tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval)
if err != nil {
logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval)
tunnelConfig.PingIntervalDuration = 3 * time.Second
}
} else {
tunnelConfig.PingIntervalDuration = 3 * time.Second
}
// Parse ping timeout
if req.PingTimeout != "" {
tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout)
if err != nil {
logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout)
tunnelConfig.PingTimeoutDuration = 5 * time.Second
}
} else {
tunnelConfig.PingTimeoutDuration = 5 * time.Second
}
if req.MTU == 0 {
tunnelConfig.MTU = 1420
}
if req.DNS == "" {
tunnelConfig.DNS = "9.9.9.9"
}
if req.InterfaceName == "" {
tunnelConfig.InterfaceName = "olm"
}
// Start the tunnel process with the new credentials // Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" { if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" {
logger.Info("Starting tunnel with new credentials") logger.Info("Starting tunnel with new credentials")
tunnelRunning = true go StartTunnel(tunnelConfig)
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
} }
case <-apiServer.GetDisconnectChannel(): return nil
logger.Info("Received disconnect request via API") },
// onSwitchOrg
func(req api.SwitchOrgRequest) error {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Ensure we have an active olmClient
if olmClient == nil {
return fmt.Errorf("no active connection to switch organizations")
}
// Update the orgID in the API server
apiServer.SetOrgID(req.OrgID)
// Mark as not connected to trigger re-registration
connected = false
Close()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", req.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": true, // Default to relay mode for org switch
"olmVersion": globalConfig.Version,
"orgId": req.OrgID,
}, 1*time.Second)
return nil
},
// onDisconnect
func() error {
logger.Info("Processing disconnect request via API")
StopTunnel() StopTunnel()
// Clear credentials so we wait for new connect call return nil
id = "" },
secret = "" // onExit
endpoint = "" func() error {
userToken = "" logger.Info("Processing shutdown request via API")
cancel()
default: return nil
// If we have credentials and no tunnel is running, start it },
if id != "" && secret != "" && endpoint != "" && !tunnelRunning { )
logger.Info("Starting tunnel process with initial credentials")
tunnelRunning = true
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
} else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled
if !config.EnableAPI {
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
// exit the application because there is no way to provide the missing parameters
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
goto shutdown
}
}
// Sleep briefly to prevent tight loop
time.Sleep(100 * time.Millisecond)
}
}
shutdown:
Close()
apiServer.Stop()
logger.Info("Olm service shutting down")
} }
func StartTunnel(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) { func StartTunnel(config TunnelConfig) {
if tunnelRunning {
logger.Info("Tunnel already running")
return
}
tunnelRunning = true // Also set it here in case it is called externally
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
// Create a cancellable context for this tunnel process // Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx) tunnelCtx, cancel := context.WithCancel(globalCtx)
tunnelCancel = cancel tunnelCancel = cancel
defer func() { defer func() {
tunnelCancel = nil tunnelCancel = nil
@@ -205,8 +241,14 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
var ( var (
interfaceName = config.InterfaceName interfaceName = config.InterfaceName
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
userToken = config.UserToken
) )
apiServer.SetOrgID(config.OrgID)
// Create a new olm client using the provided credentials // Create a new olm client using the provided credentials
olm, err := websocket.NewClient( olm, err := websocket.NewClient(
id, // Use provided ID id, // Use provided ID
@@ -431,7 +473,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
logger.Error("Failed to bring up WireGuard device: %v", err) logger.Error("Failed to bring up WireGuard device: %v", err)
} }
if err = ConfigureInterface(interfaceName, wgData); err != nil { if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
logger.Error("Failed to configure interface: %v", err) logger.Error("Failed to configure interface: %v", err)
} }
@@ -753,7 +795,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"relay": !config.Holepunch, "relay": !config.Holepunch,
"olmVersion": config.Version, "olmVersion": globalConfig.Version,
"orgId": config.OrgID, "orgId": config.OrgID,
// "doNotCreateNewClient": config.DoNotCreateNewClient, // "doNotCreateNewClient": config.DoNotCreateNewClient,
}, 1*time.Second) }, 1*time.Second)
@@ -777,36 +819,6 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
} }
defer olm.Close() defer olm.Close()
// Listen for org switch requests from the API
go func() {
for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Update the config with the new orgId
config.OrgID = req.OrgID
// Mark as not connected to trigger re-registration
connected = false
Close()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
// Wait for context cancellation // Wait for context cancellation
<-tunnelCtx.Done() <-tunnelCtx.Done()
logger.Info("Tunnel process context cancelled, cleaning up") logger.Info("Tunnel process context cancelled, cleaning up")