mirror of
https://github.com/fosrl/olm.git
synced 2026-02-28 07:46:41 +00:00
Fixing endpoint handling
This commit is contained in:
11
api/api.go
11
api/api.go
@@ -53,6 +53,7 @@ type StatusResponse struct {
|
|||||||
Registered bool `json:"registered"`
|
Registered bool `json:"registered"`
|
||||||
Terminated bool `json:"terminated"`
|
Terminated bool `json:"terminated"`
|
||||||
Version string `json:"version,omitempty"`
|
Version string `json:"version,omitempty"`
|
||||||
|
Agent string `json:"agent,omitempty"`
|
||||||
OrgID string `json:"orgId,omitempty"`
|
OrgID string `json:"orgId,omitempty"`
|
||||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||||
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
|
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
|
||||||
@@ -75,6 +76,7 @@ type API struct {
|
|||||||
isRegistered bool
|
isRegistered bool
|
||||||
isTerminated bool
|
isTerminated bool
|
||||||
version string
|
version string
|
||||||
|
agent string
|
||||||
orgID string
|
orgID string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,6 +231,13 @@ func (s *API) SetVersion(version string) {
|
|||||||
s.version = version
|
s.version = version
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAgent sets the olm agent
|
||||||
|
func (s *API) SetAgent(agent string) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.agent = agent
|
||||||
|
}
|
||||||
|
|
||||||
// SetOrgID sets the organization ID
|
// SetOrgID sets the organization ID
|
||||||
func (s *API) SetOrgID(orgID string) {
|
func (s *API) SetOrgID(orgID string) {
|
||||||
s.statusMu.Lock()
|
s.statusMu.Lock()
|
||||||
@@ -329,6 +338,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
Registered: s.isRegistered,
|
Registered: s.isRegistered,
|
||||||
Terminated: s.isTerminated,
|
Terminated: s.isTerminated,
|
||||||
Version: s.version,
|
Version: s.version,
|
||||||
|
Agent: s.agent,
|
||||||
OrgID: s.orgID,
|
OrgID: s.orgID,
|
||||||
PeerStatuses: s.peerStatuses,
|
PeerStatuses: s.peerStatuses,
|
||||||
NetworkSettings: network.GetSettings(),
|
NetworkSettings: network.GetSettings(),
|
||||||
@@ -458,6 +468,7 @@ func (s *API) GetStatus() StatusResponse {
|
|||||||
Registered: s.isRegistered,
|
Registered: s.isRegistered,
|
||||||
Terminated: s.isTerminated,
|
Terminated: s.isTerminated,
|
||||||
Version: s.version,
|
Version: s.version,
|
||||||
|
Agent: s.agent,
|
||||||
OrgID: s.orgID,
|
OrgID: s.orgID,
|
||||||
PeerStatuses: s.peerStatuses,
|
PeerStatuses: s.peerStatuses,
|
||||||
NetworkSettings: network.GetSettings(),
|
NetworkSettings: network.GetSettings(),
|
||||||
|
|||||||
@@ -537,7 +537,7 @@ func SaveConfig(config *OlmConfig) error {
|
|||||||
func (c *OlmConfig) ShowConfig() {
|
func (c *OlmConfig) ShowConfig() {
|
||||||
configPath := getOlmConfigPath()
|
configPath := getOlmConfigPath()
|
||||||
|
|
||||||
fmt.Println("\n=== Olm Configuration ===\n")
|
fmt.Print("\n=== Olm Configuration ===\n\n")
|
||||||
fmt.Printf("Config File: %s\n", configPath)
|
fmt.Printf("Config File: %s\n", configPath)
|
||||||
|
|
||||||
// Check if config file exists
|
// Check if config file exists
|
||||||
@@ -548,7 +548,7 @@ func (c *OlmConfig) ShowConfig() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("\n--- Configuration Values ---")
|
fmt.Println("\n--- Configuration Values ---")
|
||||||
fmt.Println("(Format: Setting = Value [source])\n")
|
fmt.Print("(Format: Setting = Value [source])\n\n")
|
||||||
|
|
||||||
// Helper to get source or default
|
// Helper to get source or default
|
||||||
getSource := func(key string) string {
|
getSource := func(key string) string {
|
||||||
|
|||||||
3
main.go
3
main.go
@@ -194,7 +194,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
|
|||||||
fmt.Println("Olm version " + olmVersion)
|
fmt.Println("Olm version " + olmVersion)
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
logger.Info("Olm version " + olmVersion)
|
logger.Info("Olm version %s", olmVersion)
|
||||||
|
|
||||||
config.Version = olmVersion
|
config.Version = olmVersion
|
||||||
|
|
||||||
@@ -215,6 +215,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
|
|||||||
HTTPAddr: config.HTTPAddr,
|
HTTPAddr: config.HTTPAddr,
|
||||||
SocketPath: config.SocketPath,
|
SocketPath: config.SocketPath,
|
||||||
Version: config.Version,
|
Version: config.Version,
|
||||||
|
Agent: "olm-cli",
|
||||||
OnExit: cancel, // Pass cancel function directly to trigger shutdown
|
OnExit: cancel, // Pass cancel function directly to trigger shutdown
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
11
olm/olm.go
11
olm/olm.go
@@ -106,6 +106,7 @@ func Init(ctx context.Context, config GlobalConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
apiServer.SetVersion(config.Version)
|
apiServer.SetVersion(config.Version)
|
||||||
|
apiServer.SetAgent(config.Agent)
|
||||||
|
|
||||||
// Set up API handlers
|
// Set up API handlers
|
||||||
apiServer.SetHandlers(
|
apiServer.SetHandlers(
|
||||||
@@ -228,7 +229,6 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
interfaceName = config.InterfaceName
|
interfaceName = config.InterfaceName
|
||||||
id = config.ID
|
id = config.ID
|
||||||
secret = config.Secret
|
secret = config.Secret
|
||||||
endpoint = config.Endpoint
|
|
||||||
userToken = config.UserToken
|
userToken = config.UserToken
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -240,7 +240,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
secret, // Use provided secret
|
secret, // Use provided secret
|
||||||
userToken, // Use provided user token OPTIONAL
|
userToken, // Use provided user token OPTIONAL
|
||||||
config.OrgID,
|
config.OrgID,
|
||||||
endpoint, // Use provided endpoint
|
config.Endpoint, // Use provided endpoint
|
||||||
config.PingIntervalDuration,
|
config.PingIntervalDuration,
|
||||||
config.PingTimeoutDuration,
|
config.PingTimeoutDuration,
|
||||||
)
|
)
|
||||||
@@ -417,7 +417,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
apiServer.UpdatePeerStatus(site.SiteId, false, 0, siteEndpoint, false)
|
apiServer.UpdatePeerStatus(site.SiteId, false, 0, siteEndpoint, false)
|
||||||
|
|
||||||
if err := peerManager.AddPeer(site, siteEndpoint); err != nil {
|
if err := peerManager.AddPeer(site); err != nil {
|
||||||
logger.Error("Failed to add peer: %v", err)
|
logger.Error("Failed to add peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -495,7 +495,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
siteConfig.RemoteSubnets = updateData.RemoteSubnets
|
siteConfig.RemoteSubnets = updateData.RemoteSubnets
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := peerManager.UpdatePeer(siteConfig, endpoint); err != nil {
|
if err := peerManager.UpdatePeer(siteConfig); err != nil {
|
||||||
logger.Error("Failed to update peer: %v", err)
|
logger.Error("Failed to update peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -527,7 +527,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||||
|
|
||||||
if err := peerManager.AddPeer(siteConfig, endpoint); err != nil {
|
if err := peerManager.AddPeer(siteConfig); err != nil {
|
||||||
logger.Error("Failed to add peer: %v", err)
|
logger.Error("Failed to add peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -822,6 +822,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"relay": !config.Holepunch,
|
"relay": !config.Holepunch,
|
||||||
"olmVersion": globalConfig.Version,
|
"olmVersion": globalConfig.Version,
|
||||||
|
"olmAgent": globalConfig.Agent,
|
||||||
"orgId": config.OrgID,
|
"orgId": config.OrgID,
|
||||||
"userToken": userToken,
|
"userToken": userToken,
|
||||||
}, 1*time.Second)
|
}, 1*time.Second)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type GlobalConfig struct {
|
|||||||
HTTPAddr string
|
HTTPAddr string
|
||||||
SocketPath string
|
SocketPath string
|
||||||
Version string
|
Version string
|
||||||
|
Agent string
|
||||||
|
|
||||||
// Callbacks
|
// Callbacks
|
||||||
OnRegistered func()
|
OnRegistered func()
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func (pm *PeerManager) GetAllPeers() []SiteConfig {
|
|||||||
return peers
|
return peers
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PeerManager) AddPeer(siteConfig SiteConfig, endpoint string) error {
|
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
@@ -120,7 +120,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig, endpoint string) error {
|
|||||||
wgConfig := siteConfig
|
wgConfig := siteConfig
|
||||||
wgConfig.AllowedIps = ownedIPs
|
wgConfig.AllowedIps = ownedIPs
|
||||||
|
|
||||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint); err != nil {
|
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,7 +211,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
|
|||||||
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||||
wgConfig := promotedPeer
|
wgConfig := promotedPeer
|
||||||
wgConfig.AllowedIps = ownedIPs
|
wgConfig.AllowedIps = ownedIPs
|
||||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, promotedPeer.Endpoint); err != nil {
|
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
|
||||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -225,7 +225,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error {
|
func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
@@ -234,16 +234,6 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
|
|||||||
return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId)
|
return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine which endpoint to use based on relay state
|
|
||||||
// If the peer is currently relayed, use the relay endpoint; otherwise use the direct endpoint
|
|
||||||
actualEndpoint := endpoint
|
|
||||||
if pm.peerMonitor != nil && pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId) {
|
|
||||||
if oldPeer.RelayEndpoint != "" {
|
|
||||||
actualEndpoint = oldPeer.RelayEndpoint
|
|
||||||
logger.Info("Peer %d is relayed, using relay endpoint: %s", siteConfig.SiteId, actualEndpoint)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If public key changed, remove old peer first
|
// If public key changed, remove old peer first
|
||||||
if siteConfig.PublicKey != oldPeer.PublicKey {
|
if siteConfig.PublicKey != oldPeer.PublicKey {
|
||||||
if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil {
|
if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil {
|
||||||
@@ -295,7 +285,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
|
|||||||
wgConfig := siteConfig
|
wgConfig := siteConfig
|
||||||
wgConfig.AllowedIps = ownedIPs
|
wgConfig.AllowedIps = ownedIPs
|
||||||
|
|
||||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, actualEndpoint); err != nil {
|
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -305,7 +295,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
|
|||||||
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||||
promotedWgConfig := promotedPeer
|
promotedWgConfig := promotedPeer
|
||||||
promotedWgConfig.AllowedIps = promotedOwnedIPs
|
promotedWgConfig.AllowedIps = promotedOwnedIPs
|
||||||
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, promotedPeer.Endpoint); err != nil {
|
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
|
||||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -464,7 +454,7 @@ func (pm *PeerManager) addAllowedIp(siteId int, ip string) error {
|
|||||||
|
|
||||||
// Only update WireGuard if we own this IP
|
// Only update WireGuard if we own this IP
|
||||||
if pm.allowedIPOwners[ip] == siteId {
|
if pm.allowedIPOwners[ip] == siteId {
|
||||||
if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint); err != nil {
|
if err := ConfigurePeer(pm.device, peer, pm.privateKey, pm.peerMonitor.IsPeerRelayed(peer.SiteId)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -504,14 +494,14 @@ func (pm *PeerManager) removeAllowedIp(siteId int, cidr string) error {
|
|||||||
newOwner, promoted := pm.releaseAllowedIP(siteId, cidr)
|
newOwner, promoted := pm.releaseAllowedIP(siteId, cidr)
|
||||||
|
|
||||||
// Update WireGuard for this peer (to remove the IP from its config)
|
// Update WireGuard for this peer (to remove the IP from its config)
|
||||||
if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint); err != nil {
|
if err := ConfigurePeer(pm.device, peer, pm.privateKey, pm.peerMonitor.IsPeerRelayed(peer.SiteId)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If another peer was promoted to owner, update their WireGuard config
|
// If another peer was promoted to owner, update their WireGuard config
|
||||||
if promoted && newOwner >= 0 {
|
if promoted && newOwner >= 0 {
|
||||||
if newOwnerPeer, exists := pm.peers[newOwner]; exists {
|
if newOwnerPeer, exists := pm.peers[newOwner]; exists {
|
||||||
if err := ConfigurePeer(pm.device, newOwnerPeer, pm.privateKey, newOwnerPeer.Endpoint); err != nil {
|
if err := ConfigurePeer(pm.device, newOwnerPeer, pm.privateKey, pm.peerMonitor.IsPeerRelayed(peer.SiteId)); err != nil {
|
||||||
logger.Error("Failed to promote peer %d for IP %s: %v", newOwner, cidr, err)
|
logger.Error("Failed to promote peer %d for IP %s: %v", newOwner, cidr, err)
|
||||||
} else {
|
} else {
|
||||||
logger.Info("Promoted peer %d to owner of IP %s", newOwner, cidr)
|
logger.Info("Promoted peer %d to owner of IP %s", newOwner, cidr)
|
||||||
|
|||||||
@@ -11,8 +11,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
|
||||||
siteHost, err := util.ResolveDomain(formatEndpoint(siteConfig.Endpoint))
|
var endpoint string
|
||||||
|
if relay && siteConfig.RelayEndpoint != "" {
|
||||||
|
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
|
||||||
|
} else {
|
||||||
|
endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
}
|
||||||
|
siteHost, err := util.ResolveDomain(endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -646,7 +646,9 @@ func (c *Client) readPumpWithDisconnectDetection() {
|
|||||||
|
|
||||||
c.handlersMux.RLock()
|
c.handlersMux.RLock()
|
||||||
if handler, ok := c.handlers[msg.Type]; ok {
|
if handler, ok := c.handlers[msg.Type]; ok {
|
||||||
|
logger.Debug("***********************************Running handler for message type: %s", msg.Type)
|
||||||
handler(msg)
|
handler(msg)
|
||||||
|
logger.Debug("***********************************Finished handler for message type: %s", msg.Type)
|
||||||
}
|
}
|
||||||
c.handlersMux.RUnlock()
|
c.handlersMux.RUnlock()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user