Split up concerns so parent can call start and stop

This commit is contained in:
Owen
2025-11-18 18:14:21 -05:00
parent e7be7fb281
commit 8f97c43b63
4 changed files with 246 additions and 222 deletions

View File

@@ -15,7 +15,7 @@ import (
)
// ConfigureInterface configures a network interface with an IP address and brings it up
func ConfigureInterface(interfaceName string, wgData WgData) error {
func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
logger.Info("The tunnel IP is: %s", wgData.TunnelIP)
// Parse the IP address and network
@@ -32,6 +32,7 @@ func ConfigureInterface(interfaceName string, wgData WgData) error {
// network.SetTunnelRemoteAddress() // what does this do?
network.SetIPv4Settings([]string{destinationAddress}, []string{mask})
network.SetMTU(mtu)
apiServer.SetTunnelIP(destinationAddress)
if interfaceName == "" {

View File

@@ -3,6 +3,7 @@ package olm
import (
"context"
"encoding/json"
"fmt"
"net"
"runtime"
"time"
@@ -20,7 +21,21 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Config struct {
type GlobalConfig struct {
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
Version string
// Source tracking (not in JSON)
sources map[string]string
}
type TunnelConfig struct {
// Connection settings
Endpoint string
ID string
@@ -32,14 +47,6 @@ type Config struct {
DNS string
InterfaceName string
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
// Advanced
Holepunch bool
TlsClientCert string
@@ -48,11 +55,7 @@ type Config struct {
PingIntervalDuration time.Duration
PingTimeoutDuration time.Duration
// Source tracking (not in JSON)
sources map[string]string
Version string
OrgID string
OrgID string
// DoNotCreateNewClient bool
FileDescriptorTun uint32
@@ -74,21 +77,21 @@ var (
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
peerMonitor *peermonitor.PeerMonitor
globalConfig GlobalConfig
globalCtx context.Context
stopRegister func()
stopPing chan struct{}
)
func Init(ctx context.Context, config Config) {
func Init(ctx context.Context, config GlobalConfig) {
globalConfig = config
globalCtx = ctx
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
network.SetMTU(config.MTU)
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
if config.HTTPAddr != "" {
apiServer = api.NewAPI(config.HTTPAddr)
@@ -97,35 +100,15 @@ func Init(ctx context.Context, config Config) {
}
apiServer.SetVersion(config.Version)
apiServer.SetOrgID(config.OrgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Listen for shutdown requests from the API
go func() {
<-apiServer.GetShutdownChannel()
logger.Info("Shutdown requested via API")
// Cancel the context to trigger graceful shutdown
cancel()
}()
var (
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
userToken = config.UserToken
)
// Main event loop that handles connect, disconnect, and reconnect
for {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
goto shutdown
case req := <-apiServer.GetConnectionChannel():
// Set up API handlers
apiServer.SetHandlers(
// onConnect
func(req api.ConnectionRequest) error {
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Stop any existing tunnel before starting a new one
@@ -134,67 +117,120 @@ func Init(ctx context.Context, config Config) {
StopTunnel()
}
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
userToken := req.UserToken
tunnelConfig := TunnelConfig{
Endpoint: req.Endpoint,
ID: req.ID,
Secret: req.Secret,
UserToken: req.UserToken,
MTU: req.MTU,
DNS: req.DNS,
InterfaceName: req.InterfaceName,
Holepunch: req.Holepunch,
TlsClientCert: req.TlsClientCert,
OrgID: req.OrgID,
}
var err error
// Parse ping interval
if req.PingInterval != "" {
tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval)
if err != nil {
logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval)
tunnelConfig.PingIntervalDuration = 3 * time.Second
}
} else {
tunnelConfig.PingIntervalDuration = 3 * time.Second
}
// Parse ping timeout
if req.PingTimeout != "" {
tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout)
if err != nil {
logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout)
tunnelConfig.PingTimeoutDuration = 5 * time.Second
}
} else {
tunnelConfig.PingTimeoutDuration = 5 * time.Second
}
if req.MTU == 0 {
tunnelConfig.MTU = 1420
}
if req.DNS == "" {
tunnelConfig.DNS = "9.9.9.9"
}
if req.InterfaceName == "" {
tunnelConfig.InterfaceName = "olm"
}
// Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" {
if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" {
logger.Info("Starting tunnel with new credentials")
tunnelRunning = true
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
go StartTunnel(tunnelConfig)
}
case <-apiServer.GetDisconnectChannel():
logger.Info("Received disconnect request via API")
return nil
},
// onSwitchOrg
func(req api.SwitchOrgRequest) error {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Ensure we have an active olmClient
if olmClient == nil {
return fmt.Errorf("no active connection to switch organizations")
}
// Update the orgID in the API server
apiServer.SetOrgID(req.OrgID)
// Mark as not connected to trigger re-registration
connected = false
Close()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", req.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": true, // Default to relay mode for org switch
"olmVersion": globalConfig.Version,
"orgId": req.OrgID,
}, 1*time.Second)
return nil
},
// onDisconnect
func() error {
logger.Info("Processing disconnect request via API")
StopTunnel()
// Clear credentials so we wait for new connect call
id = ""
secret = ""
endpoint = ""
userToken = ""
default:
// If we have credentials and no tunnel is running, start it
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
logger.Info("Starting tunnel process with initial credentials")
tunnelRunning = true
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
} else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled
if !config.EnableAPI {
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
// exit the application because there is no way to provide the missing parameters
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
goto shutdown
}
}
// Sleep briefly to prevent tight loop
time.Sleep(100 * time.Millisecond)
}
}
shutdown:
Close()
apiServer.Stop()
logger.Info("Olm service shutting down")
return nil
},
// onExit
func() error {
logger.Info("Processing shutdown request via API")
cancel()
return nil
},
)
}
func StartTunnel(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
func StartTunnel(config TunnelConfig) {
if tunnelRunning {
logger.Info("Tunnel already running")
return
}
tunnelRunning = true // Also set it here in case it is called externally
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
// Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx)
tunnelCtx, cancel := context.WithCancel(globalCtx)
tunnelCancel = cancel
defer func() {
tunnelCancel = nil
@@ -205,8 +241,14 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
var (
interfaceName = config.InterfaceName
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
userToken = config.UserToken
)
apiServer.SetOrgID(config.OrgID)
// Create a new olm client using the provided credentials
olm, err := websocket.NewClient(
id, // Use provided ID
@@ -431,7 +473,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
@@ -753,7 +795,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"olmVersion": globalConfig.Version,
"orgId": config.OrgID,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}, 1*time.Second)
@@ -777,36 +819,6 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
}
defer olm.Close()
// Listen for org switch requests from the API
go func() {
for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Update the config with the new orgId
config.OrgID = req.OrgID
// Mark as not connected to trigger re-registration
connected = false
Close()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
// Wait for context cancellation
<-tunnelCtx.Done()
logger.Info("Tunnel process context cancelled, cleaning up")