mirror of
https://github.com/fosrl/olm.git
synced 2026-02-28 15:56:43 +00:00
Handle holepunches better
This commit is contained in:
2
main.go
2
main.go
@@ -235,7 +235,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
OrgID: config.OrgID,
|
OrgID: config.OrgID,
|
||||||
OverrideDNS: config.OverrideDNS,
|
OverrideDNS: config.OverrideDNS,
|
||||||
EnableUAPI: true,
|
EnableUAPI: true,
|
||||||
DisableRelay: true,
|
DisableRelay: false, // allow it to relay
|
||||||
}
|
}
|
||||||
go olm.StartTunnel(tunnelConfig)
|
go olm.StartTunnel(tunnelConfig)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
190
olm/olm.go
190
olm/olm.go
@@ -33,7 +33,6 @@ var (
|
|||||||
connected bool
|
connected bool
|
||||||
dev *device.Device
|
dev *device.Device
|
||||||
wgData WgData
|
wgData WgData
|
||||||
holePunchData HolePunchData
|
|
||||||
uapiListener net.Listener
|
uapiListener net.Listener
|
||||||
tdev tun.Device
|
tdev tun.Device
|
||||||
middleDev *olmDevice.MiddleDevice
|
middleDev *olmDevice.MiddleDevice
|
||||||
@@ -48,13 +47,22 @@ var (
|
|||||||
globalConfig GlobalConfig
|
globalConfig GlobalConfig
|
||||||
globalCtx context.Context
|
globalCtx context.Context
|
||||||
stopRegister func()
|
stopRegister func()
|
||||||
|
stopPeerSend func()
|
||||||
|
updateRegister func(newData interface{})
|
||||||
stopPing chan struct{}
|
stopPing chan struct{}
|
||||||
peerManager *peers.PeerManager
|
peerManager *peers.PeerManager
|
||||||
)
|
)
|
||||||
|
|
||||||
// initSharedBindAndHolepunch creates the shared UDP socket and holepunch manager.
|
// initTunnelInfo creates the shared UDP socket and holepunch manager.
|
||||||
// This is used during initial tunnel setup and when switching organizations.
|
// This is used during initial tunnel setup and when switching organizations.
|
||||||
func initSharedBindAndHolepunch(clientID string) error {
|
func initTunnelInfo(clientID string) error {
|
||||||
|
var err error
|
||||||
|
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to generate private key: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
sourcePort, err := util.FindAvailableUDPPort(49152, 65535)
|
sourcePort, err := util.FindAvailableUDPPort(49152, 65535)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to find available UDP port: %w", err)
|
return fmt.Errorf("failed to find available UDP port: %w", err)
|
||||||
@@ -82,7 +90,7 @@ func initSharedBindAndHolepunch(clientID string) error {
|
|||||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||||
|
|
||||||
// Create the holepunch manager
|
// Create the holepunch manager
|
||||||
holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm")
|
holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -249,82 +257,12 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
// Store the client reference globally
|
// Store the client reference globally
|
||||||
olmClient = olm
|
olmClient = olm
|
||||||
|
|
||||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to generate private key: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create shared UDP socket and holepunch manager
|
// Create shared UDP socket and holepunch manager
|
||||||
if err := initSharedBindAndHolepunch(id); err != nil {
|
if err := initTunnelInfo(id); err != nil {
|
||||||
logger.Error("%v", err)
|
logger.Error("%v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
|
||||||
logger.Debug("Received message: %v", msg.Data)
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
|
|
||||||
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
|
|
||||||
for i, node := range holePunchData.ExitNodes {
|
|
||||||
exitNodes[i] = holepunch.ExitNode{
|
|
||||||
Endpoint: node.Endpoint,
|
|
||||||
PublicKey: node.PublicKey,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start hole punching using the manager
|
|
||||||
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
|
||||||
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
|
|
||||||
logger.Warn("Failed to start hole punch: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
|
||||||
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
|
|
||||||
logger.Debug("Received message: %v", msg.Data)
|
|
||||||
|
|
||||||
type LegacyHolePunchData struct {
|
|
||||||
ServerPubKey string `json:"serverPubKey"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var legacyHolePunchData LegacyHolePunchData
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop any existing hole punch operations
|
|
||||||
if holePunchManager != nil {
|
|
||||||
holePunchManager.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start hole punching for the exit node
|
|
||||||
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
|
||||||
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
|
|
||||||
logger.Warn("Failed to start hole punch: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received message: %v", msg.Data)
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
@@ -338,9 +276,9 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
stopRegister = nil
|
stopRegister = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait 10 milliseconds to ensure the previous connection is closed
|
if updateRegister != nil {
|
||||||
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
|
updateRegister = nil
|
||||||
time.Sleep(500 * time.Millisecond)
|
}
|
||||||
|
|
||||||
// if there is an existing tunnel then close it
|
// if there is an existing tunnel then close it
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
@@ -572,6 +510,11 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received add-peer message: %v", msg.Data)
|
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
if stopPeerSend != nil {
|
||||||
|
stopPeerSend()
|
||||||
|
stopPeerSend = nil
|
||||||
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error marshaling data: %v", err)
|
||||||
@@ -584,6 +527,8 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||||
|
|
||||||
if err := peerManager.AddPeer(siteConfig, endpoint); err != nil {
|
if err := peerManager.AddPeer(siteConfig, endpoint); err != nil {
|
||||||
logger.Error("Failed to add peer: %v", err)
|
logger.Error("Failed to add peer: %v", err)
|
||||||
return
|
return
|
||||||
@@ -753,6 +698,59 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
|
||||||
|
olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling handshake data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var handshakeData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
ExitNode struct {
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
} `json:"exitNode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling handshake data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add exit node to holepunch rotation if we have a holepunch manager
|
||||||
|
if holePunchManager != nil {
|
||||||
|
exitNode := holepunch.ExitNode{
|
||||||
|
Endpoint: handshakeData.ExitNode.Endpoint,
|
||||||
|
PublicKey: handshakeData.ExitNode.PublicKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
added := holePunchManager.AddExitNode(exitNode)
|
||||||
|
if added {
|
||||||
|
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start holepunching if not already running
|
||||||
|
if !holePunchManager.IsRunning() {
|
||||||
|
if err := holePunchManager.Start(); err != nil {
|
||||||
|
logger.Error("Failed to start holepunch manager: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send handshake acknowledgment back to server with retry
|
||||||
|
stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
|
"siteId": handshakeData.SiteId,
|
||||||
|
}, 1*time.Second)
|
||||||
|
|
||||||
|
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||||
|
})
|
||||||
|
|
||||||
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
||||||
logger.Info("Received terminate message")
|
logger.Info("Received terminate message")
|
||||||
apiServer.SetTerminated(true)
|
apiServer.SetTerminated(true)
|
||||||
@@ -779,15 +777,17 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
publicKey := privateKey.PublicKey()
|
publicKey := privateKey.PublicKey()
|
||||||
|
|
||||||
|
// delay for 500ms to allow for time for the hp to get processed
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
if stopRegister == nil {
|
if stopRegister == nil {
|
||||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
||||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"relay": !config.Holepunch,
|
"relay": !config.Holepunch,
|
||||||
"olmVersion": globalConfig.Version,
|
"olmVersion": globalConfig.Version,
|
||||||
"orgId": config.OrgID,
|
"orgId": config.OrgID,
|
||||||
"userToken": userToken,
|
"userToken": userToken,
|
||||||
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
|
||||||
}, 1*time.Second)
|
}, 1*time.Second)
|
||||||
|
|
||||||
// Invoke onRegistered callback if configured
|
// Invoke onRegistered callback if configured
|
||||||
@@ -801,9 +801,28 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
olm.OnTokenUpdate(func(token string) {
|
olm.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) {
|
||||||
if holePunchManager != nil {
|
if holePunchManager != nil {
|
||||||
holePunchManager.SetToken(token)
|
holePunchManager.SetToken(token)
|
||||||
|
|
||||||
|
logger.Debug("Got exit nodes for hole punching: %v", exitNodes)
|
||||||
|
|
||||||
|
// Convert websocket.ExitNode to holepunch.ExitNode
|
||||||
|
hpExitNodes := make([]holepunch.ExitNode, len(exitNodes))
|
||||||
|
for i, node := range exitNodes {
|
||||||
|
hpExitNodes[i] = holepunch.ExitNode{
|
||||||
|
Endpoint: node.Endpoint,
|
||||||
|
PublicKey: node.PublicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Updated hole punch exit nodes: %v", hpExitNodes)
|
||||||
|
|
||||||
|
// Start hole punching using the manager
|
||||||
|
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
||||||
|
if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil {
|
||||||
|
logger.Warn("Failed to start hole punch: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -814,6 +833,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
apiServer.SetRegistered(false)
|
apiServer.SetRegistered(false)
|
||||||
apiServer.ClearPeerStatuses()
|
apiServer.ClearPeerStatuses()
|
||||||
network.ClearNetworkSettings()
|
network.ClearNetworkSettings()
|
||||||
|
|
||||||
Close()
|
Close()
|
||||||
|
|
||||||
if globalConfig.OnAuthError != nil {
|
if globalConfig.OnAuthError != nil {
|
||||||
@@ -864,6 +884,10 @@ func Close() {
|
|||||||
stopRegister = nil
|
stopRegister = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if updateRegister != nil {
|
||||||
|
updateRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
if peerMonitor != nil {
|
if peerMonitor != nil {
|
||||||
peerMonitor.Close() // Close() also calls Stop() internally
|
peerMonitor.Close() // Close() also calls Stop() internally
|
||||||
peerMonitor = nil
|
peerMonitor = nil
|
||||||
@@ -992,7 +1016,7 @@ func SwitchOrg(orgID string) error {
|
|||||||
Close()
|
Close()
|
||||||
|
|
||||||
// Recreate sharedBind and holepunch manager - needed because Close() releases them
|
// Recreate sharedBind and holepunch manager - needed because Close() releases them
|
||||||
if err := initSharedBindAndHolepunch(olmClient.GetConfig().ID); err != nil {
|
if err := initTunnelInfo(olmClient.GetConfig().ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1002,7 +1026,7 @@ func SwitchOrg(orgID string) error {
|
|||||||
// Trigger re-registration with new orgId
|
// Trigger re-registration with new orgId
|
||||||
logger.Info("Re-registering with new orgId: %s", orgID)
|
logger.Info("Re-registering with new orgId: %s", orgID)
|
||||||
publicKey := privateKey.PublicKey()
|
publicKey := privateKey.PublicKey()
|
||||||
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
stopRegister, updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"relay": true, // Default to relay mode for org switch
|
"relay": true, // Default to relay mode for org switch
|
||||||
"olmVersion": globalConfig.Version,
|
"olmVersion": globalConfig.Version,
|
||||||
|
|||||||
19
olm/types.go
19
olm/types.go
@@ -12,25 +12,6 @@ type WgData struct {
|
|||||||
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
|
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
|
||||||
}
|
}
|
||||||
|
|
||||||
type HolePunchMessage struct {
|
|
||||||
NewtID string `json:"newtId"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ExitNode struct {
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type HolePunchData struct {
|
|
||||||
ExitNodes []ExitNode `json:"exitNodes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type EncryptedHolePunchMessage struct {
|
|
||||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
|
||||||
Nonce []byte `json:"nonce"`
|
|
||||||
Ciphertext []byte `json:"ciphertext"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GlobalConfig struct {
|
type GlobalConfig struct {
|
||||||
// Logging
|
// Logging
|
||||||
LogLevel string
|
LogLevel string
|
||||||
|
|||||||
@@ -38,12 +38,18 @@ func IsAuthError(err error) bool {
|
|||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
Data struct {
|
Data struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
|
ExitNodes []ExitNode `json:"exitNodes"`
|
||||||
} `json:"data"`
|
} `json:"data"`
|
||||||
Success bool `json:"success"`
|
Success bool `json:"success"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ExitNode struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
|
|
||||||
type WSMessage struct {
|
type WSMessage struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Data interface{} `json:"data"`
|
Data interface{} `json:"data"`
|
||||||
@@ -71,7 +77,7 @@ type Client struct {
|
|||||||
pingInterval time.Duration
|
pingInterval time.Duration
|
||||||
pingTimeout time.Duration
|
pingTimeout time.Duration
|
||||||
onConnect func() error
|
onConnect func() error
|
||||||
onTokenUpdate func(token string)
|
onTokenUpdate func(token string, exitNodes []ExitNode)
|
||||||
onAuthError func(statusCode int, message string) // Callback for auth errors
|
onAuthError func(statusCode int, message string) // Callback for auth errors
|
||||||
writeMux sync.Mutex
|
writeMux sync.Mutex
|
||||||
clientType string // Type of client (e.g., "newt", "olm")
|
clientType string // Type of client (e.g., "newt", "olm")
|
||||||
@@ -116,7 +122,7 @@ func (c *Client) OnConnect(callback func() error) {
|
|||||||
c.onConnect = callback
|
c.onConnect = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) OnTokenUpdate(callback func(token string)) {
|
func (c *Client) OnTokenUpdate(callback func(token string, exitNodes []ExitNode)) {
|
||||||
c.onTokenUpdate = callback
|
c.onTokenUpdate = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,13 +218,17 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
|||||||
return c.conn.WriteJSON(msg)
|
return c.conn.WriteJSON(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
|
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
|
||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
|
updateChan := make(chan interface{})
|
||||||
|
var dataMux sync.Mutex
|
||||||
|
currentData := data
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
count := 0
|
count := 0
|
||||||
maxAttempts := 10
|
maxAttempts := 10
|
||||||
|
|
||||||
err := c.SendMessage(messageType, data) // Send immediately
|
err := c.SendMessage(messageType, currentData) // Send immediately
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send initial message: %v", err)
|
logger.Error("Failed to send initial message: %v", err)
|
||||||
}
|
}
|
||||||
@@ -233,19 +243,46 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
|||||||
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = c.SendMessage(messageType, data)
|
dataMux.Lock()
|
||||||
|
err = c.SendMessage(messageType, currentData)
|
||||||
|
dataMux.Unlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send message: %v", err)
|
logger.Error("Failed to send message: %v", err)
|
||||||
}
|
}
|
||||||
count++
|
count++
|
||||||
|
case newData := <-updateChan:
|
||||||
|
dataMux.Lock()
|
||||||
|
// Merge newData into currentData if both are maps
|
||||||
|
if currentMap, ok := currentData.(map[string]interface{}); ok {
|
||||||
|
if newMap, ok := newData.(map[string]interface{}); ok {
|
||||||
|
// Update or add keys from newData
|
||||||
|
for key, value := range newMap {
|
||||||
|
currentMap[key] = value
|
||||||
|
}
|
||||||
|
currentData = currentMap
|
||||||
|
} else {
|
||||||
|
// If newData is not a map, replace entirely
|
||||||
|
currentData = newData
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If currentData is not a map, replace entirely
|
||||||
|
currentData = newData
|
||||||
|
}
|
||||||
|
dataMux.Unlock()
|
||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return func() {
|
return func() {
|
||||||
close(stopChan)
|
close(stopChan)
|
||||||
}
|
}, func(newData interface{}) {
|
||||||
|
select {
|
||||||
|
case updateChan <- newData:
|
||||||
|
case <-stopChan:
|
||||||
|
// Channel is closed, ignore update
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterHandler registers a handler for a specific message type
|
// RegisterHandler registers a handler for a specific message type
|
||||||
@@ -255,11 +292,11 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
|||||||
c.handlers[messageType] = handler
|
c.handlers[messageType] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getToken() (string, error) {
|
func (c *Client) getToken() (string, []ExitNode, error) {
|
||||||
// Parse the base URL to ensure we have the correct hostname
|
// Parse the base URL to ensure we have the correct hostname
|
||||||
baseURL, err := url.Parse(c.baseURL)
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to parse base URL: %w", err)
|
return "", nil, fmt.Errorf("failed to parse base URL: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we have the base URL without trailing slashes
|
// Ensure we have the base URL without trailing slashes
|
||||||
@@ -271,7 +308,7 @@ func (c *Client) getToken() (string, error) {
|
|||||||
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||||
tlsConfig, err = c.setupTLS()
|
tlsConfig, err = c.setupTLS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
|
return "", nil, fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,7 +330,7 @@ func (c *Client) getToken() (string, error) {
|
|||||||
jsonData, err := json.Marshal(tokenData)
|
jsonData, err := json.Marshal(tokenData)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to marshal token request data: %w", err)
|
return "", nil, fmt.Errorf("failed to marshal token request data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new request
|
// Create a new request
|
||||||
@@ -303,7 +340,7 @@ func (c *Client) getToken() (string, error) {
|
|||||||
bytes.NewBuffer(jsonData),
|
bytes.NewBuffer(jsonData),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to create request: %w", err)
|
return "", nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set headers
|
// Set headers
|
||||||
@@ -319,7 +356,7 @@ func (c *Client) getToken() (string, error) {
|
|||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to request new token: %w", err)
|
return "", nil, fmt.Errorf("failed to request new token: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
@@ -329,33 +366,33 @@ func (c *Client) getToken() (string, error) {
|
|||||||
|
|
||||||
// Return AuthError for 401/403 status codes
|
// Return AuthError for 401/403 status codes
|
||||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
return "", &AuthError{
|
return "", nil, &AuthError{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
Message: string(body),
|
Message: string(body),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For other errors (5xx, network issues, etc.), return regular error
|
// For other errors (5xx, network issues, etc.), return regular error
|
||||||
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
return "", nil, fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenResp TokenResponse
|
var tokenResp TokenResponse
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||||
logger.Error("Failed to decode token response.")
|
logger.Error("Failed to decode token response.")
|
||||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tokenResp.Success {
|
if !tokenResp.Success {
|
||||||
return "", fmt.Errorf("failed to get token: %s", tokenResp.Message)
|
return "", nil, fmt.Errorf("failed to get token: %s", tokenResp.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokenResp.Data.Token == "" {
|
if tokenResp.Data.Token == "" {
|
||||||
return "", fmt.Errorf("received empty token from server")
|
return "", nil, fmt.Errorf("received empty token from server")
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
||||||
|
|
||||||
return tokenResp.Data.Token, nil
|
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) connectWithRetry() {
|
func (c *Client) connectWithRetry() {
|
||||||
@@ -389,13 +426,13 @@ func (c *Client) connectWithRetry() {
|
|||||||
|
|
||||||
func (c *Client) establishConnection() error {
|
func (c *Client) establishConnection() error {
|
||||||
// Get token for authentication
|
// Get token for authentication
|
||||||
token, err := c.getToken()
|
token, exitNodes, err := c.getToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get token: %w", err)
|
return fmt.Errorf("failed to get token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.onTokenUpdate != nil {
|
if c.onTokenUpdate != nil {
|
||||||
c.onTokenUpdate(token)
|
c.onTokenUpdate(token, exitNodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the base URL to determine protocol and hostname
|
// Parse the base URL to determine protocol and hostname
|
||||||
|
|||||||
Reference in New Issue
Block a user