mirror of
https://github.com/fosrl/olm.git
synced 2026-02-27 07:16:49 +00:00
Split up concerns so parent can call start and stop
This commit is contained in:
144
api/api.go
144
api/api.go
@@ -13,10 +13,18 @@ import (
|
|||||||
|
|
||||||
// ConnectionRequest defines the structure for an incoming connection request
|
// ConnectionRequest defines the structure for an incoming connection request
|
||||||
type ConnectionRequest struct {
|
type ConnectionRequest struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Secret string `json:"secret"`
|
Secret string `json:"secret"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
UserToken string `json:"userToken,omitempty"`
|
UserToken string `json:"userToken,omitempty"`
|
||||||
|
MTU int `json:"mtu,omitempty"`
|
||||||
|
DNS string `json:"dns,omitempty"`
|
||||||
|
InterfaceName string `json:"interfaceName,omitempty"`
|
||||||
|
Holepunch bool `json:"holepunch,omitempty"`
|
||||||
|
TlsClientCert string `json:"tlsClientCert,omitempty"`
|
||||||
|
PingInterval string `json:"pingInterval,omitempty"`
|
||||||
|
PingTimeout string `json:"pingTimeout,omitempty"`
|
||||||
|
OrgID string `json:"orgId,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SwitchOrgRequest defines the structure for switching organizations
|
// SwitchOrgRequest defines the structure for switching organizations
|
||||||
@@ -47,33 +55,29 @@ type StatusResponse struct {
|
|||||||
|
|
||||||
// API represents the HTTP server and its state
|
// API represents the HTTP server and its state
|
||||||
type API struct {
|
type API struct {
|
||||||
addr string
|
addr string
|
||||||
socketPath string
|
socketPath string
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
server *http.Server
|
server *http.Server
|
||||||
connectionChan chan ConnectionRequest
|
onConnect func(ConnectionRequest) error
|
||||||
switchOrgChan chan SwitchOrgRequest
|
onSwitchOrg func(SwitchOrgRequest) error
|
||||||
shutdownChan chan struct{}
|
onDisconnect func() error
|
||||||
disconnectChan chan struct{}
|
onExit func() error
|
||||||
statusMu sync.RWMutex
|
statusMu sync.RWMutex
|
||||||
peerStatuses map[int]*PeerStatus
|
peerStatuses map[int]*PeerStatus
|
||||||
connectedAt time.Time
|
connectedAt time.Time
|
||||||
isConnected bool
|
isConnected bool
|
||||||
isRegistered bool
|
isRegistered bool
|
||||||
tunnelIP string
|
tunnelIP string
|
||||||
version string
|
version string
|
||||||
orgID string
|
orgID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||||
func NewAPI(addr string) *API {
|
func NewAPI(addr string) *API {
|
||||||
s := &API{
|
s := &API{
|
||||||
addr: addr,
|
addr: addr,
|
||||||
connectionChan: make(chan ConnectionRequest, 1),
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
|
||||||
shutdownChan: make(chan struct{}, 1),
|
|
||||||
disconnectChan: make(chan struct{}, 1),
|
|
||||||
peerStatuses: make(map[int]*PeerStatus),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
@@ -82,17 +86,26 @@ func NewAPI(addr string) *API {
|
|||||||
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
|
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
|
||||||
func NewAPISocket(socketPath string) *API {
|
func NewAPISocket(socketPath string) *API {
|
||||||
s := &API{
|
s := &API{
|
||||||
socketPath: socketPath,
|
socketPath: socketPath,
|
||||||
connectionChan: make(chan ConnectionRequest, 1),
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
|
||||||
shutdownChan: make(chan struct{}, 1),
|
|
||||||
disconnectChan: make(chan struct{}, 1),
|
|
||||||
peerStatuses: make(map[int]*PeerStatus),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetHandlers sets the callback functions for handling API requests
|
||||||
|
func (s *API) SetHandlers(
|
||||||
|
onConnect func(ConnectionRequest) error,
|
||||||
|
onSwitchOrg func(SwitchOrgRequest) error,
|
||||||
|
onDisconnect func() error,
|
||||||
|
onExit func() error,
|
||||||
|
) {
|
||||||
|
s.onConnect = onConnect
|
||||||
|
s.onSwitchOrg = onSwitchOrg
|
||||||
|
s.onDisconnect = onDisconnect
|
||||||
|
s.onExit = onExit
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the HTTP server
|
// Start starts the HTTP server
|
||||||
func (s *API) Start() error {
|
func (s *API) Start() error {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
@@ -149,26 +162,6 @@ func (s *API) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnectionChannel returns the channel for receiving connection requests
|
|
||||||
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
|
|
||||||
return s.connectionChan
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
|
||||||
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
|
||||||
return s.switchOrgChan
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetShutdownChannel returns the channel for receiving shutdown requests
|
|
||||||
func (s *API) GetShutdownChannel() <-chan struct{} {
|
|
||||||
return s.shutdownChan
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDisconnectChannel returns the channel for receiving disconnect requests
|
|
||||||
func (s *API) GetDisconnectChannel() <-chan struct{} {
|
|
||||||
return s.disconnectChan
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||||
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||||
s.statusMu.Lock()
|
s.statusMu.Lock()
|
||||||
@@ -277,8 +270,13 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the request to the main goroutine
|
// Call the connect handler if set
|
||||||
s.connectionChan <- req
|
if s.onConnect != nil {
|
||||||
|
if err := s.onConnect(req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Return a success response
|
// Return a success response
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -320,12 +318,12 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
logger.Info("Received exit request via API")
|
logger.Info("Received exit request via API")
|
||||||
|
|
||||||
// Send shutdown signal
|
// Call the exit handler if set
|
||||||
select {
|
if s.onExit != nil {
|
||||||
case s.shutdownChan <- struct{}{}:
|
if err := s.onExit(); err != nil {
|
||||||
// Signal sent successfully
|
http.Error(w, fmt.Sprintf("Exit failed: %v", err), http.StatusInternalServerError)
|
||||||
default:
|
return
|
||||||
// Channel already has a signal, don't block
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a success response
|
// Return a success response
|
||||||
@@ -358,14 +356,12 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||||
|
|
||||||
// Send the request to the main goroutine
|
// Call the switch org handler if set
|
||||||
select {
|
if s.onSwitchOrg != nil {
|
||||||
case s.switchOrgChan <- req:
|
if err := s.onSwitchOrg(req); err != nil {
|
||||||
// Signal sent successfully
|
http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError)
|
||||||
default:
|
return
|
||||||
// Channel already has a pending request
|
}
|
||||||
http.Error(w, "Org switch already in progress", http.StatusConflict)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a success response
|
// Return a success response
|
||||||
@@ -394,12 +390,12 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
logger.Info("Received disconnect request via API")
|
logger.Info("Received disconnect request via API")
|
||||||
|
|
||||||
// Send disconnect signal
|
// Call the disconnect handler if set
|
||||||
select {
|
if s.onDisconnect != nil {
|
||||||
case s.disconnectChan <- struct{}{}:
|
if err := s.onDisconnect(); err != nil {
|
||||||
// Signal sent successfully
|
http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError)
|
||||||
default:
|
return
|
||||||
// Channel already has a signal, don't block
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a success response
|
// Return a success response
|
||||||
|
|||||||
53
main.go
53
main.go
@@ -205,26 +205,41 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new olm.Config struct and copy values from the main config
|
// Create a new olm.Config struct and copy values from the main config
|
||||||
olmConfig := olm.Config{
|
olmConfig := olm.GlobalConfig{
|
||||||
Endpoint: config.Endpoint,
|
LogLevel: config.LogLevel,
|
||||||
ID: config.ID,
|
EnableAPI: config.EnableAPI,
|
||||||
Secret: config.Secret,
|
HTTPAddr: config.HTTPAddr,
|
||||||
UserToken: config.UserToken,
|
SocketPath: config.SocketPath,
|
||||||
MTU: config.MTU,
|
Version: config.Version,
|
||||||
DNS: config.DNS,
|
|
||||||
InterfaceName: config.InterfaceName,
|
|
||||||
LogLevel: config.LogLevel,
|
|
||||||
EnableAPI: config.EnableAPI,
|
|
||||||
HTTPAddr: config.HTTPAddr,
|
|
||||||
SocketPath: config.SocketPath,
|
|
||||||
Holepunch: config.Holepunch,
|
|
||||||
TlsClientCert: config.TlsClientCert,
|
|
||||||
PingIntervalDuration: config.PingIntervalDuration,
|
|
||||||
PingTimeoutDuration: config.PingTimeoutDuration,
|
|
||||||
Version: config.Version,
|
|
||||||
OrgID: config.OrgID,
|
|
||||||
// DoNotCreateNewClient: config.DoNotCreateNewClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
olm.Init(ctx, olmConfig)
|
olm.Init(ctx, olmConfig)
|
||||||
|
|
||||||
|
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
|
||||||
|
tunnelConfig := olm.TunnelConfig{
|
||||||
|
Endpoint: config.Endpoint,
|
||||||
|
ID: config.ID,
|
||||||
|
Secret: config.Secret,
|
||||||
|
UserToken: config.UserToken,
|
||||||
|
MTU: config.MTU,
|
||||||
|
DNS: config.DNS,
|
||||||
|
InterfaceName: config.InterfaceName,
|
||||||
|
Holepunch: config.Holepunch,
|
||||||
|
TlsClientCert: config.TlsClientCert,
|
||||||
|
PingIntervalDuration: config.PingIntervalDuration,
|
||||||
|
PingTimeoutDuration: config.PingTimeoutDuration,
|
||||||
|
OrgID: config.OrgID,
|
||||||
|
}
|
||||||
|
go olm.StartTunnel(tunnelConfig)
|
||||||
|
} else {
|
||||||
|
logger.Info("Incomplete tunnel configuration, not starting tunnel")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for context cancellation (from signals or API shutdown)
|
||||||
|
<-ctx.Done()
|
||||||
|
logger.Info("Shutdown signal received, cleaning up...")
|
||||||
|
|
||||||
|
// Clean up resources
|
||||||
|
olm.Close()
|
||||||
|
logger.Info("Shutdown complete")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ConfigureInterface configures a network interface with an IP address and brings it up
|
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||||
func ConfigureInterface(interfaceName string, wgData WgData) error {
|
func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
|
||||||
logger.Info("The tunnel IP is: %s", wgData.TunnelIP)
|
logger.Info("The tunnel IP is: %s", wgData.TunnelIP)
|
||||||
|
|
||||||
// Parse the IP address and network
|
// Parse the IP address and network
|
||||||
@@ -32,6 +32,7 @@ func ConfigureInterface(interfaceName string, wgData WgData) error {
|
|||||||
|
|
||||||
// network.SetTunnelRemoteAddress() // what does this do?
|
// network.SetTunnelRemoteAddress() // what does this do?
|
||||||
network.SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
network.SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
||||||
|
network.SetMTU(mtu)
|
||||||
apiServer.SetTunnelIP(destinationAddress)
|
apiServer.SetTunnelIP(destinationAddress)
|
||||||
|
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
|
|||||||
268
olm/olm.go
268
olm/olm.go
@@ -3,6 +3,7 @@ package olm
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
@@ -20,7 +21,21 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type GlobalConfig struct {
|
||||||
|
// Logging
|
||||||
|
LogLevel string
|
||||||
|
|
||||||
|
// HTTP server
|
||||||
|
EnableAPI bool
|
||||||
|
HTTPAddr string
|
||||||
|
SocketPath string
|
||||||
|
Version string
|
||||||
|
|
||||||
|
// Source tracking (not in JSON)
|
||||||
|
sources map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelConfig struct {
|
||||||
// Connection settings
|
// Connection settings
|
||||||
Endpoint string
|
Endpoint string
|
||||||
ID string
|
ID string
|
||||||
@@ -32,14 +47,6 @@ type Config struct {
|
|||||||
DNS string
|
DNS string
|
||||||
InterfaceName string
|
InterfaceName string
|
||||||
|
|
||||||
// Logging
|
|
||||||
LogLevel string
|
|
||||||
|
|
||||||
// HTTP server
|
|
||||||
EnableAPI bool
|
|
||||||
HTTPAddr string
|
|
||||||
SocketPath string
|
|
||||||
|
|
||||||
// Advanced
|
// Advanced
|
||||||
Holepunch bool
|
Holepunch bool
|
||||||
TlsClientCert string
|
TlsClientCert string
|
||||||
@@ -48,11 +55,7 @@ type Config struct {
|
|||||||
PingIntervalDuration time.Duration
|
PingIntervalDuration time.Duration
|
||||||
PingTimeoutDuration time.Duration
|
PingTimeoutDuration time.Duration
|
||||||
|
|
||||||
// Source tracking (not in JSON)
|
OrgID string
|
||||||
sources map[string]string
|
|
||||||
|
|
||||||
Version string
|
|
||||||
OrgID string
|
|
||||||
// DoNotCreateNewClient bool
|
// DoNotCreateNewClient bool
|
||||||
|
|
||||||
FileDescriptorTun uint32
|
FileDescriptorTun uint32
|
||||||
@@ -74,21 +77,21 @@ var (
|
|||||||
sharedBind *bind.SharedBind
|
sharedBind *bind.SharedBind
|
||||||
holePunchManager *holepunch.Manager
|
holePunchManager *holepunch.Manager
|
||||||
peerMonitor *peermonitor.PeerMonitor
|
peerMonitor *peermonitor.PeerMonitor
|
||||||
|
globalConfig GlobalConfig
|
||||||
|
globalCtx context.Context
|
||||||
stopRegister func()
|
stopRegister func()
|
||||||
stopPing chan struct{}
|
stopPing chan struct{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init(ctx context.Context, config Config) {
|
func Init(ctx context.Context, config GlobalConfig) {
|
||||||
|
globalConfig = config
|
||||||
|
globalCtx = ctx
|
||||||
|
|
||||||
// Create a cancellable context for internal shutdown control
|
// Create a cancellable context for internal shutdown control
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||||
network.SetMTU(config.MTU)
|
|
||||||
|
|
||||||
if config.Holepunch {
|
|
||||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.HTTPAddr != "" {
|
if config.HTTPAddr != "" {
|
||||||
apiServer = api.NewAPI(config.HTTPAddr)
|
apiServer = api.NewAPI(config.HTTPAddr)
|
||||||
@@ -97,35 +100,15 @@ func Init(ctx context.Context, config Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
apiServer.SetVersion(config.Version)
|
apiServer.SetVersion(config.Version)
|
||||||
apiServer.SetOrgID(config.OrgID)
|
|
||||||
|
|
||||||
if err := apiServer.Start(); err != nil {
|
if err := apiServer.Start(); err != nil {
|
||||||
logger.Fatal("Failed to start HTTP server: %v", err)
|
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen for shutdown requests from the API
|
// Set up API handlers
|
||||||
go func() {
|
apiServer.SetHandlers(
|
||||||
<-apiServer.GetShutdownChannel()
|
// onConnect
|
||||||
logger.Info("Shutdown requested via API")
|
func(req api.ConnectionRequest) error {
|
||||||
// Cancel the context to trigger graceful shutdown
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var (
|
|
||||||
id = config.ID
|
|
||||||
secret = config.Secret
|
|
||||||
endpoint = config.Endpoint
|
|
||||||
userToken = config.UserToken
|
|
||||||
)
|
|
||||||
|
|
||||||
// Main event loop that handles connect, disconnect, and reconnect
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
logger.Info("Context cancelled while waiting for credentials")
|
|
||||||
goto shutdown
|
|
||||||
|
|
||||||
case req := <-apiServer.GetConnectionChannel():
|
|
||||||
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
||||||
|
|
||||||
// Stop any existing tunnel before starting a new one
|
// Stop any existing tunnel before starting a new one
|
||||||
@@ -134,67 +117,120 @@ func Init(ctx context.Context, config Config) {
|
|||||||
StopTunnel()
|
StopTunnel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the connection parameters
|
tunnelConfig := TunnelConfig{
|
||||||
id = req.ID
|
Endpoint: req.Endpoint,
|
||||||
secret = req.Secret
|
ID: req.ID,
|
||||||
endpoint = req.Endpoint
|
Secret: req.Secret,
|
||||||
userToken := req.UserToken
|
UserToken: req.UserToken,
|
||||||
|
MTU: req.MTU,
|
||||||
|
DNS: req.DNS,
|
||||||
|
InterfaceName: req.InterfaceName,
|
||||||
|
Holepunch: req.Holepunch,
|
||||||
|
TlsClientCert: req.TlsClientCert,
|
||||||
|
OrgID: req.OrgID,
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
// Parse ping interval
|
||||||
|
if req.PingInterval != "" {
|
||||||
|
tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval)
|
||||||
|
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
||||||
|
}
|
||||||
|
// Parse ping timeout
|
||||||
|
if req.PingTimeout != "" {
|
||||||
|
tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout)
|
||||||
|
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
||||||
|
}
|
||||||
|
if req.MTU == 0 {
|
||||||
|
tunnelConfig.MTU = 1420
|
||||||
|
}
|
||||||
|
if req.DNS == "" {
|
||||||
|
tunnelConfig.DNS = "9.9.9.9"
|
||||||
|
}
|
||||||
|
if req.InterfaceName == "" {
|
||||||
|
tunnelConfig.InterfaceName = "olm"
|
||||||
|
}
|
||||||
|
|
||||||
// Start the tunnel process with the new credentials
|
// Start the tunnel process with the new credentials
|
||||||
if id != "" && secret != "" && endpoint != "" {
|
if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" {
|
||||||
logger.Info("Starting tunnel with new credentials")
|
logger.Info("Starting tunnel with new credentials")
|
||||||
tunnelRunning = true
|
go StartTunnel(tunnelConfig)
|
||||||
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-apiServer.GetDisconnectChannel():
|
return nil
|
||||||
logger.Info("Received disconnect request via API")
|
},
|
||||||
|
// onSwitchOrg
|
||||||
|
func(req api.SwitchOrgRequest) error {
|
||||||
|
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
||||||
|
|
||||||
|
// Ensure we have an active olmClient
|
||||||
|
if olmClient == nil {
|
||||||
|
return fmt.Errorf("no active connection to switch organizations")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the orgID in the API server
|
||||||
|
apiServer.SetOrgID(req.OrgID)
|
||||||
|
|
||||||
|
// Mark as not connected to trigger re-registration
|
||||||
|
connected = false
|
||||||
|
|
||||||
|
Close()
|
||||||
|
|
||||||
|
// Clear peer statuses in API
|
||||||
|
apiServer.SetRegistered(false)
|
||||||
|
apiServer.SetTunnelIP("")
|
||||||
|
|
||||||
|
// Trigger re-registration with new orgId
|
||||||
|
logger.Info("Re-registering with new orgId: %s", req.OrgID)
|
||||||
|
publicKey := privateKey.PublicKey()
|
||||||
|
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": publicKey.String(),
|
||||||
|
"relay": true, // Default to relay mode for org switch
|
||||||
|
"olmVersion": globalConfig.Version,
|
||||||
|
"orgId": req.OrgID,
|
||||||
|
}, 1*time.Second)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
// onDisconnect
|
||||||
|
func() error {
|
||||||
|
logger.Info("Processing disconnect request via API")
|
||||||
StopTunnel()
|
StopTunnel()
|
||||||
// Clear credentials so we wait for new connect call
|
return nil
|
||||||
id = ""
|
},
|
||||||
secret = ""
|
// onExit
|
||||||
endpoint = ""
|
func() error {
|
||||||
userToken = ""
|
logger.Info("Processing shutdown request via API")
|
||||||
|
cancel()
|
||||||
default:
|
return nil
|
||||||
// If we have credentials and no tunnel is running, start it
|
},
|
||||||
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
|
)
|
||||||
logger.Info("Starting tunnel process with initial credentials")
|
|
||||||
tunnelRunning = true
|
|
||||||
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
|
|
||||||
} else if id == "" || secret == "" || endpoint == "" {
|
|
||||||
// If we don't have credentials, check if API is enabled
|
|
||||||
if !config.EnableAPI {
|
|
||||||
missing := []string{}
|
|
||||||
if id == "" {
|
|
||||||
missing = append(missing, "id")
|
|
||||||
}
|
|
||||||
if secret == "" {
|
|
||||||
missing = append(missing, "secret")
|
|
||||||
}
|
|
||||||
if endpoint == "" {
|
|
||||||
missing = append(missing, "endpoint")
|
|
||||||
}
|
|
||||||
// exit the application because there is no way to provide the missing parameters
|
|
||||||
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
|
|
||||||
goto shutdown
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sleep briefly to prevent tight loop
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
shutdown:
|
|
||||||
Close()
|
|
||||||
apiServer.Stop()
|
|
||||||
logger.Info("Olm service shutting down")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartTunnel(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
|
func StartTunnel(config TunnelConfig) {
|
||||||
|
if tunnelRunning {
|
||||||
|
logger.Info("Tunnel already running")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelRunning = true // Also set it here in case it is called externally
|
||||||
|
|
||||||
|
if config.Holepunch {
|
||||||
|
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||||
|
}
|
||||||
|
|
||||||
// Create a cancellable context for this tunnel process
|
// Create a cancellable context for this tunnel process
|
||||||
tunnelCtx, cancel := context.WithCancel(ctx)
|
tunnelCtx, cancel := context.WithCancel(globalCtx)
|
||||||
tunnelCancel = cancel
|
tunnelCancel = cancel
|
||||||
defer func() {
|
defer func() {
|
||||||
tunnelCancel = nil
|
tunnelCancel = nil
|
||||||
@@ -205,8 +241,14 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
interfaceName = config.InterfaceName
|
interfaceName = config.InterfaceName
|
||||||
|
id = config.ID
|
||||||
|
secret = config.Secret
|
||||||
|
endpoint = config.Endpoint
|
||||||
|
userToken = config.UserToken
|
||||||
)
|
)
|
||||||
|
|
||||||
|
apiServer.SetOrgID(config.OrgID)
|
||||||
|
|
||||||
// Create a new olm client using the provided credentials
|
// Create a new olm client using the provided credentials
|
||||||
olm, err := websocket.NewClient(
|
olm, err := websocket.NewClient(
|
||||||
id, // Use provided ID
|
id, // Use provided ID
|
||||||
@@ -431,7 +473,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
|||||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
|
||||||
logger.Error("Failed to configure interface: %v", err)
|
logger.Error("Failed to configure interface: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -753,7 +795,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
|||||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"relay": !config.Holepunch,
|
"relay": !config.Holepunch,
|
||||||
"olmVersion": config.Version,
|
"olmVersion": globalConfig.Version,
|
||||||
"orgId": config.OrgID,
|
"orgId": config.OrgID,
|
||||||
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||||
}, 1*time.Second)
|
}, 1*time.Second)
|
||||||
@@ -777,36 +819,6 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
|||||||
}
|
}
|
||||||
defer olm.Close()
|
defer olm.Close()
|
||||||
|
|
||||||
// Listen for org switch requests from the API
|
|
||||||
go func() {
|
|
||||||
for req := range apiServer.GetSwitchOrgChannel() {
|
|
||||||
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
|
||||||
|
|
||||||
// Update the config with the new orgId
|
|
||||||
config.OrgID = req.OrgID
|
|
||||||
|
|
||||||
// Mark as not connected to trigger re-registration
|
|
||||||
connected = false
|
|
||||||
|
|
||||||
Close()
|
|
||||||
|
|
||||||
// Clear peer statuses in API
|
|
||||||
apiServer.SetRegistered(false)
|
|
||||||
apiServer.SetTunnelIP("")
|
|
||||||
apiServer.SetOrgID(config.OrgID)
|
|
||||||
|
|
||||||
// Trigger re-registration with new orgId
|
|
||||||
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
|
||||||
publicKey := privateKey.PublicKey()
|
|
||||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
||||||
"publicKey": publicKey.String(),
|
|
||||||
"relay": !config.Holepunch,
|
|
||||||
"olmVersion": config.Version,
|
|
||||||
"orgId": config.OrgID,
|
|
||||||
}, 1*time.Second)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for context cancellation
|
// Wait for context cancellation
|
||||||
<-tunnelCtx.Done()
|
<-tunnelCtx.Done()
|
||||||
logger.Info("Tunnel process context cancelled, cleaning up")
|
logger.Info("Tunnel process context cancelled, cleaning up")
|
||||||
|
|||||||
Reference in New Issue
Block a user