mirror of
https://github.com/fosrl/olm.git
synced 2026-02-20 20:06:43 +00:00
Split up concerns so parent can call start and stop
This commit is contained in:
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||
func ConfigureInterface(interfaceName string, wgData WgData) error {
|
||||
func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
|
||||
logger.Info("The tunnel IP is: %s", wgData.TunnelIP)
|
||||
|
||||
// Parse the IP address and network
|
||||
@@ -32,6 +32,7 @@ func ConfigureInterface(interfaceName string, wgData WgData) error {
|
||||
|
||||
// network.SetTunnelRemoteAddress() // what does this do?
|
||||
network.SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
||||
network.SetMTU(mtu)
|
||||
apiServer.SetTunnelIP(destinationAddress)
|
||||
|
||||
if interfaceName == "" {
|
||||
|
||||
268
olm/olm.go
268
olm/olm.go
@@ -3,6 +3,7 @@ package olm
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
@@ -20,7 +21,21 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
type GlobalConfig struct {
|
||||
// Logging
|
||||
LogLevel string
|
||||
|
||||
// HTTP server
|
||||
EnableAPI bool
|
||||
HTTPAddr string
|
||||
SocketPath string
|
||||
Version string
|
||||
|
||||
// Source tracking (not in JSON)
|
||||
sources map[string]string
|
||||
}
|
||||
|
||||
type TunnelConfig struct {
|
||||
// Connection settings
|
||||
Endpoint string
|
||||
ID string
|
||||
@@ -32,14 +47,6 @@ type Config struct {
|
||||
DNS string
|
||||
InterfaceName string
|
||||
|
||||
// Logging
|
||||
LogLevel string
|
||||
|
||||
// HTTP server
|
||||
EnableAPI bool
|
||||
HTTPAddr string
|
||||
SocketPath string
|
||||
|
||||
// Advanced
|
||||
Holepunch bool
|
||||
TlsClientCert string
|
||||
@@ -48,11 +55,7 @@ type Config struct {
|
||||
PingIntervalDuration time.Duration
|
||||
PingTimeoutDuration time.Duration
|
||||
|
||||
// Source tracking (not in JSON)
|
||||
sources map[string]string
|
||||
|
||||
Version string
|
||||
OrgID string
|
||||
OrgID string
|
||||
// DoNotCreateNewClient bool
|
||||
|
||||
FileDescriptorTun uint32
|
||||
@@ -74,21 +77,21 @@ var (
|
||||
sharedBind *bind.SharedBind
|
||||
holePunchManager *holepunch.Manager
|
||||
peerMonitor *peermonitor.PeerMonitor
|
||||
globalConfig GlobalConfig
|
||||
globalCtx context.Context
|
||||
stopRegister func()
|
||||
stopPing chan struct{}
|
||||
)
|
||||
|
||||
func Init(ctx context.Context, config Config) {
|
||||
func Init(ctx context.Context, config GlobalConfig) {
|
||||
globalConfig = config
|
||||
globalCtx = ctx
|
||||
|
||||
// Create a cancellable context for internal shutdown control
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||
network.SetMTU(config.MTU)
|
||||
|
||||
if config.Holepunch {
|
||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||
}
|
||||
|
||||
if config.HTTPAddr != "" {
|
||||
apiServer = api.NewAPI(config.HTTPAddr)
|
||||
@@ -97,35 +100,15 @@ func Init(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
apiServer.SetVersion(config.Version)
|
||||
apiServer.SetOrgID(config.OrgID)
|
||||
|
||||
if err := apiServer.Start(); err != nil {
|
||||
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||
}
|
||||
|
||||
// Listen for shutdown requests from the API
|
||||
go func() {
|
||||
<-apiServer.GetShutdownChannel()
|
||||
logger.Info("Shutdown requested via API")
|
||||
// Cancel the context to trigger graceful shutdown
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var (
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
endpoint = config.Endpoint
|
||||
userToken = config.UserToken
|
||||
)
|
||||
|
||||
// Main event loop that handles connect, disconnect, and reconnect
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("Context cancelled while waiting for credentials")
|
||||
goto shutdown
|
||||
|
||||
case req := <-apiServer.GetConnectionChannel():
|
||||
// Set up API handlers
|
||||
apiServer.SetHandlers(
|
||||
// onConnect
|
||||
func(req api.ConnectionRequest) error {
|
||||
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
||||
|
||||
// Stop any existing tunnel before starting a new one
|
||||
@@ -134,67 +117,120 @@ func Init(ctx context.Context, config Config) {
|
||||
StopTunnel()
|
||||
}
|
||||
|
||||
// Set the connection parameters
|
||||
id = req.ID
|
||||
secret = req.Secret
|
||||
endpoint = req.Endpoint
|
||||
userToken := req.UserToken
|
||||
tunnelConfig := TunnelConfig{
|
||||
Endpoint: req.Endpoint,
|
||||
ID: req.ID,
|
||||
Secret: req.Secret,
|
||||
UserToken: req.UserToken,
|
||||
MTU: req.MTU,
|
||||
DNS: req.DNS,
|
||||
InterfaceName: req.InterfaceName,
|
||||
Holepunch: req.Holepunch,
|
||||
TlsClientCert: req.TlsClientCert,
|
||||
OrgID: req.OrgID,
|
||||
}
|
||||
|
||||
var err error
|
||||
// Parse ping interval
|
||||
if req.PingInterval != "" {
|
||||
tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval)
|
||||
if err != nil {
|
||||
logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval)
|
||||
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
||||
}
|
||||
} else {
|
||||
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
||||
}
|
||||
// Parse ping timeout
|
||||
if req.PingTimeout != "" {
|
||||
tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout)
|
||||
if err != nil {
|
||||
logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout)
|
||||
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
||||
}
|
||||
} else {
|
||||
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
||||
}
|
||||
if req.MTU == 0 {
|
||||
tunnelConfig.MTU = 1420
|
||||
}
|
||||
if req.DNS == "" {
|
||||
tunnelConfig.DNS = "9.9.9.9"
|
||||
}
|
||||
if req.InterfaceName == "" {
|
||||
tunnelConfig.InterfaceName = "olm"
|
||||
}
|
||||
|
||||
// Start the tunnel process with the new credentials
|
||||
if id != "" && secret != "" && endpoint != "" {
|
||||
if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" {
|
||||
logger.Info("Starting tunnel with new credentials")
|
||||
tunnelRunning = true
|
||||
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
|
||||
go StartTunnel(tunnelConfig)
|
||||
}
|
||||
|
||||
case <-apiServer.GetDisconnectChannel():
|
||||
logger.Info("Received disconnect request via API")
|
||||
return nil
|
||||
},
|
||||
// onSwitchOrg
|
||||
func(req api.SwitchOrgRequest) error {
|
||||
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Ensure we have an active olmClient
|
||||
if olmClient == nil {
|
||||
return fmt.Errorf("no active connection to switch organizations")
|
||||
}
|
||||
|
||||
// Update the orgID in the API server
|
||||
apiServer.SetOrgID(req.OrgID)
|
||||
|
||||
// Mark as not connected to trigger re-registration
|
||||
connected = false
|
||||
|
||||
Close()
|
||||
|
||||
// Clear peer statuses in API
|
||||
apiServer.SetRegistered(false)
|
||||
apiServer.SetTunnelIP("")
|
||||
|
||||
// Trigger re-registration with new orgId
|
||||
logger.Info("Re-registering with new orgId: %s", req.OrgID)
|
||||
publicKey := privateKey.PublicKey()
|
||||
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": true, // Default to relay mode for org switch
|
||||
"olmVersion": globalConfig.Version,
|
||||
"orgId": req.OrgID,
|
||||
}, 1*time.Second)
|
||||
|
||||
return nil
|
||||
},
|
||||
// onDisconnect
|
||||
func() error {
|
||||
logger.Info("Processing disconnect request via API")
|
||||
StopTunnel()
|
||||
// Clear credentials so we wait for new connect call
|
||||
id = ""
|
||||
secret = ""
|
||||
endpoint = ""
|
||||
userToken = ""
|
||||
|
||||
default:
|
||||
// If we have credentials and no tunnel is running, start it
|
||||
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
|
||||
logger.Info("Starting tunnel process with initial credentials")
|
||||
tunnelRunning = true
|
||||
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
|
||||
} else if id == "" || secret == "" || endpoint == "" {
|
||||
// If we don't have credentials, check if API is enabled
|
||||
if !config.EnableAPI {
|
||||
missing := []string{}
|
||||
if id == "" {
|
||||
missing = append(missing, "id")
|
||||
}
|
||||
if secret == "" {
|
||||
missing = append(missing, "secret")
|
||||
}
|
||||
if endpoint == "" {
|
||||
missing = append(missing, "endpoint")
|
||||
}
|
||||
// exit the application because there is no way to provide the missing parameters
|
||||
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
|
||||
goto shutdown
|
||||
}
|
||||
}
|
||||
|
||||
// Sleep briefly to prevent tight loop
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
shutdown:
|
||||
Close()
|
||||
apiServer.Stop()
|
||||
logger.Info("Olm service shutting down")
|
||||
return nil
|
||||
},
|
||||
// onExit
|
||||
func() error {
|
||||
logger.Info("Processing shutdown request via API")
|
||||
cancel()
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func StartTunnel(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
|
||||
func StartTunnel(config TunnelConfig) {
|
||||
if tunnelRunning {
|
||||
logger.Info("Tunnel already running")
|
||||
return
|
||||
}
|
||||
|
||||
tunnelRunning = true // Also set it here in case it is called externally
|
||||
|
||||
if config.Holepunch {
|
||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||
}
|
||||
|
||||
// Create a cancellable context for this tunnel process
|
||||
tunnelCtx, cancel := context.WithCancel(ctx)
|
||||
tunnelCtx, cancel := context.WithCancel(globalCtx)
|
||||
tunnelCancel = cancel
|
||||
defer func() {
|
||||
tunnelCancel = nil
|
||||
@@ -205,8 +241,14 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
||||
|
||||
var (
|
||||
interfaceName = config.InterfaceName
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
endpoint = config.Endpoint
|
||||
userToken = config.UserToken
|
||||
)
|
||||
|
||||
apiServer.SetOrgID(config.OrgID)
|
||||
|
||||
// Create a new olm client using the provided credentials
|
||||
olm, err := websocket.NewClient(
|
||||
id, // Use provided ID
|
||||
@@ -431,7 +473,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
|
||||
logger.Error("Failed to configure interface: %v", err)
|
||||
}
|
||||
|
||||
@@ -753,7 +795,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !config.Holepunch,
|
||||
"olmVersion": config.Version,
|
||||
"olmVersion": globalConfig.Version,
|
||||
"orgId": config.OrgID,
|
||||
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||
}, 1*time.Second)
|
||||
@@ -777,36 +819,6 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
||||
}
|
||||
defer olm.Close()
|
||||
|
||||
// Listen for org switch requests from the API
|
||||
go func() {
|
||||
for req := range apiServer.GetSwitchOrgChannel() {
|
||||
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Update the config with the new orgId
|
||||
config.OrgID = req.OrgID
|
||||
|
||||
// Mark as not connected to trigger re-registration
|
||||
connected = false
|
||||
|
||||
Close()
|
||||
|
||||
// Clear peer statuses in API
|
||||
apiServer.SetRegistered(false)
|
||||
apiServer.SetTunnelIP("")
|
||||
apiServer.SetOrgID(config.OrgID)
|
||||
|
||||
// Trigger re-registration with new orgId
|
||||
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
||||
publicKey := privateKey.PublicKey()
|
||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !config.Holepunch,
|
||||
"olmVersion": config.Version,
|
||||
"orgId": config.OrgID,
|
||||
}, 1*time.Second)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for context cancellation
|
||||
<-tunnelCtx.Done()
|
||||
logger.Info("Tunnel process context cancelled, cleaning up")
|
||||
|
||||
Reference in New Issue
Block a user