Split up concerns so parent can call start and stop

This commit is contained in:
Owen
2025-11-18 18:14:21 -05:00
parent e7be7fb281
commit 8f97c43b63
4 changed files with 246 additions and 222 deletions

View File

@@ -13,10 +13,18 @@ import (
// ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
MTU int `json:"mtu,omitempty"`
DNS string `json:"dns,omitempty"`
InterfaceName string `json:"interfaceName,omitempty"`
Holepunch bool `json:"holepunch,omitempty"`
TlsClientCert string `json:"tlsClientCert,omitempty"`
PingInterval string `json:"pingInterval,omitempty"`
PingTimeout string `json:"pingTimeout,omitempty"`
OrgID string `json:"orgId,omitempty"`
}
// SwitchOrgRequest defines the structure for switching organizations
@@ -47,33 +55,29 @@ type StatusResponse struct {
// API represents the HTTP server and its state
type API struct {
addr string
socketPath string
listener net.Listener
server *http.Server
connectionChan chan ConnectionRequest
switchOrgChan chan SwitchOrgRequest
shutdownChan chan struct{}
disconnectChan chan struct{}
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
isRegistered bool
tunnelIP string
version string
orgID string
addr string
socketPath string
listener net.Listener
server *http.Server
onConnect func(ConnectionRequest) error
onSwitchOrg func(SwitchOrgRequest) error
onDisconnect func() error
onExit func() error
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
isRegistered bool
tunnelIP string
version string
orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
func NewAPI(addr string) *API {
s := &API{
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
addr: addr,
peerStatuses: make(map[int]*PeerStatus),
}
return s
@@ -82,17 +86,26 @@ func NewAPI(addr string) *API {
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
func NewAPISocket(socketPath string) *API {
s := &API{
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
socketPath: socketPath,
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// SetHandlers sets the callback functions for handling API requests
func (s *API) SetHandlers(
onConnect func(ConnectionRequest) error,
onSwitchOrg func(SwitchOrgRequest) error,
onDisconnect func() error,
onExit func() error,
) {
s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg
s.onDisconnect = onDisconnect
s.onExit = onExit
}
// Start starts the HTTP server
func (s *API) Start() error {
mux := http.NewServeMux()
@@ -149,26 +162,6 @@ func (s *API) Stop() error {
return nil
}
// GetConnectionChannel returns the channel for receiving connection requests
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// GetSwitchOrgChannel returns the channel for receiving org switch requests
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
return s.switchOrgChan
}
// GetShutdownChannel returns the channel for receiving shutdown requests
func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
// GetDisconnectChannel returns the channel for receiving disconnect requests
func (s *API) GetDisconnectChannel() <-chan struct{} {
return s.disconnectChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -277,8 +270,13 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
return
}
// Send the request to the main goroutine
s.connectionChan <- req
// Call the connect handler if set
if s.onConnect != nil {
if err := s.onConnect(req); err != nil {
http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError)
return
}
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
@@ -320,12 +318,12 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
logger.Info("Received exit request via API")
// Send shutdown signal
select {
case s.shutdownChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
// Call the exit handler if set
if s.onExit != nil {
if err := s.onExit(); err != nil {
http.Error(w, fmt.Sprintf("Exit failed: %v", err), http.StatusInternalServerError)
return
}
}
// Return a success response
@@ -358,14 +356,12 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
logger.Info("Received org switch request to orgId: %s", req.OrgID)
// Send the request to the main goroutine
select {
case s.switchOrgChan <- req:
// Signal sent successfully
default:
// Channel already has a pending request
http.Error(w, "Org switch already in progress", http.StatusConflict)
return
// Call the switch org handler if set
if s.onSwitchOrg != nil {
if err := s.onSwitchOrg(req); err != nil {
http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError)
return
}
}
// Return a success response
@@ -394,12 +390,12 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
logger.Info("Received disconnect request via API")
// Send disconnect signal
select {
case s.disconnectChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
// Call the disconnect handler if set
if s.onDisconnect != nil {
if err := s.onDisconnect(); err != nil {
http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError)
return
}
}
// Return a success response

53
main.go
View File

@@ -205,26 +205,41 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
}
// Create a new olm.Config struct and copy values from the main config
olmConfig := olm.Config{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,
UserToken: config.UserToken,
MTU: config.MTU,
DNS: config.DNS,
InterfaceName: config.InterfaceName,
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
SocketPath: config.SocketPath,
Holepunch: config.Holepunch,
TlsClientCert: config.TlsClientCert,
PingIntervalDuration: config.PingIntervalDuration,
PingTimeoutDuration: config.PingTimeoutDuration,
Version: config.Version,
OrgID: config.OrgID,
// DoNotCreateNewClient: config.DoNotCreateNewClient,
olmConfig := olm.GlobalConfig{
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
SocketPath: config.SocketPath,
Version: config.Version,
}
olm.Init(ctx, olmConfig)
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
tunnelConfig := olm.TunnelConfig{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,
UserToken: config.UserToken,
MTU: config.MTU,
DNS: config.DNS,
InterfaceName: config.InterfaceName,
Holepunch: config.Holepunch,
TlsClientCert: config.TlsClientCert,
PingIntervalDuration: config.PingIntervalDuration,
PingTimeoutDuration: config.PingTimeoutDuration,
OrgID: config.OrgID,
}
go olm.StartTunnel(tunnelConfig)
} else {
logger.Info("Incomplete tunnel configuration, not starting tunnel")
}
// Wait for context cancellation (from signals or API shutdown)
<-ctx.Done()
logger.Info("Shutdown signal received, cleaning up...")
// Clean up resources
olm.Close()
logger.Info("Shutdown complete")
}

View File

@@ -15,7 +15,7 @@ import (
)
// ConfigureInterface configures a network interface with an IP address and brings it up
func ConfigureInterface(interfaceName string, wgData WgData) error {
func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
logger.Info("The tunnel IP is: %s", wgData.TunnelIP)
// Parse the IP address and network
@@ -32,6 +32,7 @@ func ConfigureInterface(interfaceName string, wgData WgData) error {
// network.SetTunnelRemoteAddress() // what does this do?
network.SetIPv4Settings([]string{destinationAddress}, []string{mask})
network.SetMTU(mtu)
apiServer.SetTunnelIP(destinationAddress)
if interfaceName == "" {

View File

@@ -3,6 +3,7 @@ package olm
import (
"context"
"encoding/json"
"fmt"
"net"
"runtime"
"time"
@@ -20,7 +21,21 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Config struct {
type GlobalConfig struct {
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
Version string
// Source tracking (not in JSON)
sources map[string]string
}
type TunnelConfig struct {
// Connection settings
Endpoint string
ID string
@@ -32,14 +47,6 @@ type Config struct {
DNS string
InterfaceName string
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
// Advanced
Holepunch bool
TlsClientCert string
@@ -48,11 +55,7 @@ type Config struct {
PingIntervalDuration time.Duration
PingTimeoutDuration time.Duration
// Source tracking (not in JSON)
sources map[string]string
Version string
OrgID string
OrgID string
// DoNotCreateNewClient bool
FileDescriptorTun uint32
@@ -74,21 +77,21 @@ var (
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
peerMonitor *peermonitor.PeerMonitor
globalConfig GlobalConfig
globalCtx context.Context
stopRegister func()
stopPing chan struct{}
)
func Init(ctx context.Context, config Config) {
func Init(ctx context.Context, config GlobalConfig) {
globalConfig = config
globalCtx = ctx
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
network.SetMTU(config.MTU)
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
if config.HTTPAddr != "" {
apiServer = api.NewAPI(config.HTTPAddr)
@@ -97,35 +100,15 @@ func Init(ctx context.Context, config Config) {
}
apiServer.SetVersion(config.Version)
apiServer.SetOrgID(config.OrgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Listen for shutdown requests from the API
go func() {
<-apiServer.GetShutdownChannel()
logger.Info("Shutdown requested via API")
// Cancel the context to trigger graceful shutdown
cancel()
}()
var (
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
userToken = config.UserToken
)
// Main event loop that handles connect, disconnect, and reconnect
for {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
goto shutdown
case req := <-apiServer.GetConnectionChannel():
// Set up API handlers
apiServer.SetHandlers(
// onConnect
func(req api.ConnectionRequest) error {
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Stop any existing tunnel before starting a new one
@@ -134,67 +117,120 @@ func Init(ctx context.Context, config Config) {
StopTunnel()
}
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
userToken := req.UserToken
tunnelConfig := TunnelConfig{
Endpoint: req.Endpoint,
ID: req.ID,
Secret: req.Secret,
UserToken: req.UserToken,
MTU: req.MTU,
DNS: req.DNS,
InterfaceName: req.InterfaceName,
Holepunch: req.Holepunch,
TlsClientCert: req.TlsClientCert,
OrgID: req.OrgID,
}
var err error
// Parse ping interval
if req.PingInterval != "" {
tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval)
if err != nil {
logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval)
tunnelConfig.PingIntervalDuration = 3 * time.Second
}
} else {
tunnelConfig.PingIntervalDuration = 3 * time.Second
}
// Parse ping timeout
if req.PingTimeout != "" {
tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout)
if err != nil {
logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout)
tunnelConfig.PingTimeoutDuration = 5 * time.Second
}
} else {
tunnelConfig.PingTimeoutDuration = 5 * time.Second
}
if req.MTU == 0 {
tunnelConfig.MTU = 1420
}
if req.DNS == "" {
tunnelConfig.DNS = "9.9.9.9"
}
if req.InterfaceName == "" {
tunnelConfig.InterfaceName = "olm"
}
// Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" {
if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" {
logger.Info("Starting tunnel with new credentials")
tunnelRunning = true
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
go StartTunnel(tunnelConfig)
}
case <-apiServer.GetDisconnectChannel():
logger.Info("Received disconnect request via API")
return nil
},
// onSwitchOrg
func(req api.SwitchOrgRequest) error {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Ensure we have an active olmClient
if olmClient == nil {
return fmt.Errorf("no active connection to switch organizations")
}
// Update the orgID in the API server
apiServer.SetOrgID(req.OrgID)
// Mark as not connected to trigger re-registration
connected = false
Close()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", req.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": true, // Default to relay mode for org switch
"olmVersion": globalConfig.Version,
"orgId": req.OrgID,
}, 1*time.Second)
return nil
},
// onDisconnect
func() error {
logger.Info("Processing disconnect request via API")
StopTunnel()
// Clear credentials so we wait for new connect call
id = ""
secret = ""
endpoint = ""
userToken = ""
default:
// If we have credentials and no tunnel is running, start it
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
logger.Info("Starting tunnel process with initial credentials")
tunnelRunning = true
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
} else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled
if !config.EnableAPI {
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
// exit the application because there is no way to provide the missing parameters
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
goto shutdown
}
}
// Sleep briefly to prevent tight loop
time.Sleep(100 * time.Millisecond)
}
}
shutdown:
Close()
apiServer.Stop()
logger.Info("Olm service shutting down")
return nil
},
// onExit
func() error {
logger.Info("Processing shutdown request via API")
cancel()
return nil
},
)
}
func StartTunnel(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
func StartTunnel(config TunnelConfig) {
if tunnelRunning {
logger.Info("Tunnel already running")
return
}
tunnelRunning = true // Also set it here in case it is called externally
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
// Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx)
tunnelCtx, cancel := context.WithCancel(globalCtx)
tunnelCancel = cancel
defer func() {
tunnelCancel = nil
@@ -205,8 +241,14 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
var (
interfaceName = config.InterfaceName
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
userToken = config.UserToken
)
apiServer.SetOrgID(config.OrgID)
// Create a new olm client using the provided credentials
olm, err := websocket.NewClient(
id, // Use provided ID
@@ -431,7 +473,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
@@ -753,7 +795,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"olmVersion": globalConfig.Version,
"orgId": config.OrgID,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}, 1*time.Second)
@@ -777,36 +819,6 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
}
defer olm.Close()
// Listen for org switch requests from the API
go func() {
for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Update the config with the new orgId
config.OrgID = req.OrgID
// Mark as not connected to trigger re-registration
connected = false
Close()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
// Wait for context cancellation
<-tunnelCtx.Done()
logger.Info("Tunnel process context cancelled, cleaning up")