mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Split up concerns so parent can call start and stop
This commit is contained in:
144
api/api.go
144
api/api.go
@@ -13,10 +13,18 @@ import (
|
||||
|
||||
// ConnectionRequest defines the structure for an incoming connection request
|
||||
type ConnectionRequest struct {
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
UserToken string `json:"userToken,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
UserToken string `json:"userToken,omitempty"`
|
||||
MTU int `json:"mtu,omitempty"`
|
||||
DNS string `json:"dns,omitempty"`
|
||||
InterfaceName string `json:"interfaceName,omitempty"`
|
||||
Holepunch bool `json:"holepunch,omitempty"`
|
||||
TlsClientCert string `json:"tlsClientCert,omitempty"`
|
||||
PingInterval string `json:"pingInterval,omitempty"`
|
||||
PingTimeout string `json:"pingTimeout,omitempty"`
|
||||
OrgID string `json:"orgId,omitempty"`
|
||||
}
|
||||
|
||||
// SwitchOrgRequest defines the structure for switching organizations
|
||||
@@ -47,33 +55,29 @@ type StatusResponse struct {
|
||||
|
||||
// API represents the HTTP server and its state
|
||||
type API struct {
|
||||
addr string
|
||||
socketPath string
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
connectionChan chan ConnectionRequest
|
||||
switchOrgChan chan SwitchOrgRequest
|
||||
shutdownChan chan struct{}
|
||||
disconnectChan chan struct{}
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
connectedAt time.Time
|
||||
isConnected bool
|
||||
isRegistered bool
|
||||
tunnelIP string
|
||||
version string
|
||||
orgID string
|
||||
addr string
|
||||
socketPath string
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
onConnect func(ConnectionRequest) error
|
||||
onSwitchOrg func(SwitchOrgRequest) error
|
||||
onDisconnect func() error
|
||||
onExit func() error
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
connectedAt time.Time
|
||||
isConnected bool
|
||||
isRegistered bool
|
||||
tunnelIP string
|
||||
version string
|
||||
orgID string
|
||||
}
|
||||
|
||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||
func NewAPI(addr string) *API {
|
||||
s := &API{
|
||||
addr: addr,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
disconnectChan: make(chan struct{}, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
addr: addr,
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -82,17 +86,26 @@ func NewAPI(addr string) *API {
|
||||
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
|
||||
func NewAPISocket(socketPath string) *API {
|
||||
s := &API{
|
||||
socketPath: socketPath,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
disconnectChan: make(chan struct{}, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
socketPath: socketPath,
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetHandlers sets the callback functions for handling API requests
|
||||
func (s *API) SetHandlers(
|
||||
onConnect func(ConnectionRequest) error,
|
||||
onSwitchOrg func(SwitchOrgRequest) error,
|
||||
onDisconnect func() error,
|
||||
onExit func() error,
|
||||
) {
|
||||
s.onConnect = onConnect
|
||||
s.onSwitchOrg = onSwitchOrg
|
||||
s.onDisconnect = onDisconnect
|
||||
s.onExit = onExit
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
func (s *API) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
@@ -149,26 +162,6 @@ func (s *API) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConnectionChannel returns the channel for receiving connection requests
|
||||
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
|
||||
return s.connectionChan
|
||||
}
|
||||
|
||||
// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
||||
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
||||
return s.switchOrgChan
|
||||
}
|
||||
|
||||
// GetShutdownChannel returns the channel for receiving shutdown requests
|
||||
func (s *API) GetShutdownChannel() <-chan struct{} {
|
||||
return s.shutdownChan
|
||||
}
|
||||
|
||||
// GetDisconnectChannel returns the channel for receiving disconnect requests
|
||||
func (s *API) GetDisconnectChannel() <-chan struct{} {
|
||||
return s.disconnectChan
|
||||
}
|
||||
|
||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
@@ -277,8 +270,13 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send the request to the main goroutine
|
||||
s.connectionChan <- req
|
||||
// Call the connect handler if set
|
||||
if s.onConnect != nil {
|
||||
if err := s.onConnect(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -320,12 +318,12 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logger.Info("Received exit request via API")
|
||||
|
||||
// Send shutdown signal
|
||||
select {
|
||||
case s.shutdownChan <- struct{}{}:
|
||||
// Signal sent successfully
|
||||
default:
|
||||
// Channel already has a signal, don't block
|
||||
// Call the exit handler if set
|
||||
if s.onExit != nil {
|
||||
if err := s.onExit(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Exit failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
@@ -358,14 +356,12 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Send the request to the main goroutine
|
||||
select {
|
||||
case s.switchOrgChan <- req:
|
||||
// Signal sent successfully
|
||||
default:
|
||||
// Channel already has a pending request
|
||||
http.Error(w, "Org switch already in progress", http.StatusConflict)
|
||||
return
|
||||
// Call the switch org handler if set
|
||||
if s.onSwitchOrg != nil {
|
||||
if err := s.onSwitchOrg(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
@@ -394,12 +390,12 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logger.Info("Received disconnect request via API")
|
||||
|
||||
// Send disconnect signal
|
||||
select {
|
||||
case s.disconnectChan <- struct{}{}:
|
||||
// Signal sent successfully
|
||||
default:
|
||||
// Channel already has a signal, don't block
|
||||
// Call the disconnect handler if set
|
||||
if s.onDisconnect != nil {
|
||||
if err := s.onDisconnect(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
|
||||
53
main.go
53
main.go
@@ -205,26 +205,41 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
}
|
||||
|
||||
// Create a new olm.Config struct and copy values from the main config
|
||||
olmConfig := olm.Config{
|
||||
Endpoint: config.Endpoint,
|
||||
ID: config.ID,
|
||||
Secret: config.Secret,
|
||||
UserToken: config.UserToken,
|
||||
MTU: config.MTU,
|
||||
DNS: config.DNS,
|
||||
InterfaceName: config.InterfaceName,
|
||||
LogLevel: config.LogLevel,
|
||||
EnableAPI: config.EnableAPI,
|
||||
HTTPAddr: config.HTTPAddr,
|
||||
SocketPath: config.SocketPath,
|
||||
Holepunch: config.Holepunch,
|
||||
TlsClientCert: config.TlsClientCert,
|
||||
PingIntervalDuration: config.PingIntervalDuration,
|
||||
PingTimeoutDuration: config.PingTimeoutDuration,
|
||||
Version: config.Version,
|
||||
OrgID: config.OrgID,
|
||||
// DoNotCreateNewClient: config.DoNotCreateNewClient,
|
||||
olmConfig := olm.GlobalConfig{
|
||||
LogLevel: config.LogLevel,
|
||||
EnableAPI: config.EnableAPI,
|
||||
HTTPAddr: config.HTTPAddr,
|
||||
SocketPath: config.SocketPath,
|
||||
Version: config.Version,
|
||||
}
|
||||
|
||||
olm.Init(ctx, olmConfig)
|
||||
|
||||
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
|
||||
tunnelConfig := olm.TunnelConfig{
|
||||
Endpoint: config.Endpoint,
|
||||
ID: config.ID,
|
||||
Secret: config.Secret,
|
||||
UserToken: config.UserToken,
|
||||
MTU: config.MTU,
|
||||
DNS: config.DNS,
|
||||
InterfaceName: config.InterfaceName,
|
||||
Holepunch: config.Holepunch,
|
||||
TlsClientCert: config.TlsClientCert,
|
||||
PingIntervalDuration: config.PingIntervalDuration,
|
||||
PingTimeoutDuration: config.PingTimeoutDuration,
|
||||
OrgID: config.OrgID,
|
||||
}
|
||||
go olm.StartTunnel(tunnelConfig)
|
||||
} else {
|
||||
logger.Info("Incomplete tunnel configuration, not starting tunnel")
|
||||
}
|
||||
|
||||
// Wait for context cancellation (from signals or API shutdown)
|
||||
<-ctx.Done()
|
||||
logger.Info("Shutdown signal received, cleaning up...")
|
||||
|
||||
// Clean up resources
|
||||
olm.Close()
|
||||
logger.Info("Shutdown complete")
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||
func ConfigureInterface(interfaceName string, wgData WgData) error {
|
||||
func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
|
||||
logger.Info("The tunnel IP is: %s", wgData.TunnelIP)
|
||||
|
||||
// Parse the IP address and network
|
||||
@@ -32,6 +32,7 @@ func ConfigureInterface(interfaceName string, wgData WgData) error {
|
||||
|
||||
// network.SetTunnelRemoteAddress() // what does this do?
|
||||
network.SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
||||
network.SetMTU(mtu)
|
||||
apiServer.SetTunnelIP(destinationAddress)
|
||||
|
||||
if interfaceName == "" {
|
||||
|
||||
268
olm/olm.go
268
olm/olm.go
@@ -3,6 +3,7 @@ package olm
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
@@ -20,7 +21,21 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
type GlobalConfig struct {
|
||||
// Logging
|
||||
LogLevel string
|
||||
|
||||
// HTTP server
|
||||
EnableAPI bool
|
||||
HTTPAddr string
|
||||
SocketPath string
|
||||
Version string
|
||||
|
||||
// Source tracking (not in JSON)
|
||||
sources map[string]string
|
||||
}
|
||||
|
||||
type TunnelConfig struct {
|
||||
// Connection settings
|
||||
Endpoint string
|
||||
ID string
|
||||
@@ -32,14 +47,6 @@ type Config struct {
|
||||
DNS string
|
||||
InterfaceName string
|
||||
|
||||
// Logging
|
||||
LogLevel string
|
||||
|
||||
// HTTP server
|
||||
EnableAPI bool
|
||||
HTTPAddr string
|
||||
SocketPath string
|
||||
|
||||
// Advanced
|
||||
Holepunch bool
|
||||
TlsClientCert string
|
||||
@@ -48,11 +55,7 @@ type Config struct {
|
||||
PingIntervalDuration time.Duration
|
||||
PingTimeoutDuration time.Duration
|
||||
|
||||
// Source tracking (not in JSON)
|
||||
sources map[string]string
|
||||
|
||||
Version string
|
||||
OrgID string
|
||||
OrgID string
|
||||
// DoNotCreateNewClient bool
|
||||
|
||||
FileDescriptorTun uint32
|
||||
@@ -74,21 +77,21 @@ var (
|
||||
sharedBind *bind.SharedBind
|
||||
holePunchManager *holepunch.Manager
|
||||
peerMonitor *peermonitor.PeerMonitor
|
||||
globalConfig GlobalConfig
|
||||
globalCtx context.Context
|
||||
stopRegister func()
|
||||
stopPing chan struct{}
|
||||
)
|
||||
|
||||
func Init(ctx context.Context, config Config) {
|
||||
func Init(ctx context.Context, config GlobalConfig) {
|
||||
globalConfig = config
|
||||
globalCtx = ctx
|
||||
|
||||
// Create a cancellable context for internal shutdown control
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||
network.SetMTU(config.MTU)
|
||||
|
||||
if config.Holepunch {
|
||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||
}
|
||||
|
||||
if config.HTTPAddr != "" {
|
||||
apiServer = api.NewAPI(config.HTTPAddr)
|
||||
@@ -97,35 +100,15 @@ func Init(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
apiServer.SetVersion(config.Version)
|
||||
apiServer.SetOrgID(config.OrgID)
|
||||
|
||||
if err := apiServer.Start(); err != nil {
|
||||
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||
}
|
||||
|
||||
// Listen for shutdown requests from the API
|
||||
go func() {
|
||||
<-apiServer.GetShutdownChannel()
|
||||
logger.Info("Shutdown requested via API")
|
||||
// Cancel the context to trigger graceful shutdown
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var (
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
endpoint = config.Endpoint
|
||||
userToken = config.UserToken
|
||||
)
|
||||
|
||||
// Main event loop that handles connect, disconnect, and reconnect
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("Context cancelled while waiting for credentials")
|
||||
goto shutdown
|
||||
|
||||
case req := <-apiServer.GetConnectionChannel():
|
||||
// Set up API handlers
|
||||
apiServer.SetHandlers(
|
||||
// onConnect
|
||||
func(req api.ConnectionRequest) error {
|
||||
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
|
||||
|
||||
// Stop any existing tunnel before starting a new one
|
||||
@@ -134,67 +117,120 @@ func Init(ctx context.Context, config Config) {
|
||||
StopTunnel()
|
||||
}
|
||||
|
||||
// Set the connection parameters
|
||||
id = req.ID
|
||||
secret = req.Secret
|
||||
endpoint = req.Endpoint
|
||||
userToken := req.UserToken
|
||||
tunnelConfig := TunnelConfig{
|
||||
Endpoint: req.Endpoint,
|
||||
ID: req.ID,
|
||||
Secret: req.Secret,
|
||||
UserToken: req.UserToken,
|
||||
MTU: req.MTU,
|
||||
DNS: req.DNS,
|
||||
InterfaceName: req.InterfaceName,
|
||||
Holepunch: req.Holepunch,
|
||||
TlsClientCert: req.TlsClientCert,
|
||||
OrgID: req.OrgID,
|
||||
}
|
||||
|
||||
var err error
|
||||
// Parse ping interval
|
||||
if req.PingInterval != "" {
|
||||
tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval)
|
||||
if err != nil {
|
||||
logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval)
|
||||
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
||||
}
|
||||
} else {
|
||||
tunnelConfig.PingIntervalDuration = 3 * time.Second
|
||||
}
|
||||
// Parse ping timeout
|
||||
if req.PingTimeout != "" {
|
||||
tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout)
|
||||
if err != nil {
|
||||
logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout)
|
||||
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
||||
}
|
||||
} else {
|
||||
tunnelConfig.PingTimeoutDuration = 5 * time.Second
|
||||
}
|
||||
if req.MTU == 0 {
|
||||
tunnelConfig.MTU = 1420
|
||||
}
|
||||
if req.DNS == "" {
|
||||
tunnelConfig.DNS = "9.9.9.9"
|
||||
}
|
||||
if req.InterfaceName == "" {
|
||||
tunnelConfig.InterfaceName = "olm"
|
||||
}
|
||||
|
||||
// Start the tunnel process with the new credentials
|
||||
if id != "" && secret != "" && endpoint != "" {
|
||||
if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" {
|
||||
logger.Info("Starting tunnel with new credentials")
|
||||
tunnelRunning = true
|
||||
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
|
||||
go StartTunnel(tunnelConfig)
|
||||
}
|
||||
|
||||
case <-apiServer.GetDisconnectChannel():
|
||||
logger.Info("Received disconnect request via API")
|
||||
return nil
|
||||
},
|
||||
// onSwitchOrg
|
||||
func(req api.SwitchOrgRequest) error {
|
||||
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Ensure we have an active olmClient
|
||||
if olmClient == nil {
|
||||
return fmt.Errorf("no active connection to switch organizations")
|
||||
}
|
||||
|
||||
// Update the orgID in the API server
|
||||
apiServer.SetOrgID(req.OrgID)
|
||||
|
||||
// Mark as not connected to trigger re-registration
|
||||
connected = false
|
||||
|
||||
Close()
|
||||
|
||||
// Clear peer statuses in API
|
||||
apiServer.SetRegistered(false)
|
||||
apiServer.SetTunnelIP("")
|
||||
|
||||
// Trigger re-registration with new orgId
|
||||
logger.Info("Re-registering with new orgId: %s", req.OrgID)
|
||||
publicKey := privateKey.PublicKey()
|
||||
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": true, // Default to relay mode for org switch
|
||||
"olmVersion": globalConfig.Version,
|
||||
"orgId": req.OrgID,
|
||||
}, 1*time.Second)
|
||||
|
||||
return nil
|
||||
},
|
||||
// onDisconnect
|
||||
func() error {
|
||||
logger.Info("Processing disconnect request via API")
|
||||
StopTunnel()
|
||||
// Clear credentials so we wait for new connect call
|
||||
id = ""
|
||||
secret = ""
|
||||
endpoint = ""
|
||||
userToken = ""
|
||||
|
||||
default:
|
||||
// If we have credentials and no tunnel is running, start it
|
||||
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
|
||||
logger.Info("Starting tunnel process with initial credentials")
|
||||
tunnelRunning = true
|
||||
go StartTunnel(ctx, config, id, secret, userToken, endpoint)
|
||||
} else if id == "" || secret == "" || endpoint == "" {
|
||||
// If we don't have credentials, check if API is enabled
|
||||
if !config.EnableAPI {
|
||||
missing := []string{}
|
||||
if id == "" {
|
||||
missing = append(missing, "id")
|
||||
}
|
||||
if secret == "" {
|
||||
missing = append(missing, "secret")
|
||||
}
|
||||
if endpoint == "" {
|
||||
missing = append(missing, "endpoint")
|
||||
}
|
||||
// exit the application because there is no way to provide the missing parameters
|
||||
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
|
||||
goto shutdown
|
||||
}
|
||||
}
|
||||
|
||||
// Sleep briefly to prevent tight loop
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
shutdown:
|
||||
Close()
|
||||
apiServer.Stop()
|
||||
logger.Info("Olm service shutting down")
|
||||
return nil
|
||||
},
|
||||
// onExit
|
||||
func() error {
|
||||
logger.Info("Processing shutdown request via API")
|
||||
cancel()
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func StartTunnel(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
|
||||
func StartTunnel(config TunnelConfig) {
|
||||
if tunnelRunning {
|
||||
logger.Info("Tunnel already running")
|
||||
return
|
||||
}
|
||||
|
||||
tunnelRunning = true // Also set it here in case it is called externally
|
||||
|
||||
if config.Holepunch {
|
||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||
}
|
||||
|
||||
// Create a cancellable context for this tunnel process
|
||||
tunnelCtx, cancel := context.WithCancel(ctx)
|
||||
tunnelCtx, cancel := context.WithCancel(globalCtx)
|
||||
tunnelCancel = cancel
|
||||
defer func() {
|
||||
tunnelCancel = nil
|
||||
@@ -205,8 +241,14 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
||||
|
||||
var (
|
||||
interfaceName = config.InterfaceName
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
endpoint = config.Endpoint
|
||||
userToken = config.UserToken
|
||||
)
|
||||
|
||||
apiServer.SetOrgID(config.OrgID)
|
||||
|
||||
// Create a new olm client using the provided credentials
|
||||
olm, err := websocket.NewClient(
|
||||
id, // Use provided ID
|
||||
@@ -431,7 +473,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
|
||||
logger.Error("Failed to configure interface: %v", err)
|
||||
}
|
||||
|
||||
@@ -753,7 +795,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !config.Holepunch,
|
||||
"olmVersion": config.Version,
|
||||
"olmVersion": globalConfig.Version,
|
||||
"orgId": config.OrgID,
|
||||
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||
}, 1*time.Second)
|
||||
@@ -777,36 +819,6 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u
|
||||
}
|
||||
defer olm.Close()
|
||||
|
||||
// Listen for org switch requests from the API
|
||||
go func() {
|
||||
for req := range apiServer.GetSwitchOrgChannel() {
|
||||
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Update the config with the new orgId
|
||||
config.OrgID = req.OrgID
|
||||
|
||||
// Mark as not connected to trigger re-registration
|
||||
connected = false
|
||||
|
||||
Close()
|
||||
|
||||
// Clear peer statuses in API
|
||||
apiServer.SetRegistered(false)
|
||||
apiServer.SetTunnelIP("")
|
||||
apiServer.SetOrgID(config.OrgID)
|
||||
|
||||
// Trigger re-registration with new orgId
|
||||
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
||||
publicKey := privateKey.PublicKey()
|
||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !config.Holepunch,
|
||||
"olmVersion": config.Version,
|
||||
"orgId": config.OrgID,
|
||||
}, 1*time.Second)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for context cancellation
|
||||
<-tunnelCtx.Done()
|
||||
logger.Info("Tunnel process context cancelled, cleaning up")
|
||||
|
||||
Reference in New Issue
Block a user