Handle holepunches better

Former-commit-id: 136eee3302
This commit is contained in:
Owen
2025-12-01 13:54:01 -05:00
parent fb007e09a9
commit 7270b840cf
4 changed files with 167 additions and 125 deletions

View File

@@ -38,12 +38,18 @@ func IsAuthError(err error) bool {
type TokenResponse struct {
Data struct {
Token string `json:"token"`
Token string `json:"token"`
ExitNodes []ExitNode `json:"exitNodes"`
} `json:"data"`
Success bool `json:"success"`
Message string `json:"message"`
}
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
@@ -71,7 +77,7 @@ type Client struct {
pingInterval time.Duration
pingTimeout time.Duration
onConnect func() error
onTokenUpdate func(token string)
onTokenUpdate func(token string, exitNodes []ExitNode)
onAuthError func(statusCode int, message string) // Callback for auth errors
writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm")
@@ -116,7 +122,7 @@ func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
func (c *Client) OnTokenUpdate(callback func(token string)) {
func (c *Client) OnTokenUpdate(callback func(token string, exitNodes []ExitNode)) {
c.onTokenUpdate = callback
}
@@ -212,13 +218,17 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
return c.conn.WriteJSON(msg)
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
stopChan := make(chan struct{})
updateChan := make(chan interface{})
var dataMux sync.Mutex
currentData := data
go func() {
count := 0
maxAttempts := 10
err := c.SendMessage(messageType, data) // Send immediately
err := c.SendMessage(messageType, currentData) // Send immediately
if err != nil {
logger.Error("Failed to send initial message: %v", err)
}
@@ -233,19 +243,46 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
err = c.SendMessage(messageType, data)
dataMux.Lock()
err = c.SendMessage(messageType, currentData)
dataMux.Unlock()
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case newData := <-updateChan:
dataMux.Lock()
// Merge newData into currentData if both are maps
if currentMap, ok := currentData.(map[string]interface{}); ok {
if newMap, ok := newData.(map[string]interface{}); ok {
// Update or add keys from newData
for key, value := range newMap {
currentMap[key] = value
}
currentData = currentMap
} else {
// If newData is not a map, replace entirely
currentData = newData
}
} else {
// If currentData is not a map, replace entirely
currentData = newData
}
dataMux.Unlock()
case <-stopChan:
return
}
}
}()
return func() {
close(stopChan)
}
close(stopChan)
}, func(newData interface{}) {
select {
case updateChan <- newData:
case <-stopChan:
// Channel is closed, ignore update
}
}
}
// RegisterHandler registers a handler for a specific message type
@@ -255,11 +292,11 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlers[messageType] = handler
}
func (c *Client) getToken() (string, error) {
func (c *Client) getToken() (string, []ExitNode, error) {
// Parse the base URL to ensure we have the correct hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return "", fmt.Errorf("failed to parse base URL: %w", err)
return "", nil, fmt.Errorf("failed to parse base URL: %w", err)
}
// Ensure we have the base URL without trailing slashes
@@ -271,7 +308,7 @@ func (c *Client) getToken() (string, error) {
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
tlsConfig, err = c.setupTLS()
if err != nil {
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
return "", nil, fmt.Errorf("failed to setup TLS configuration: %w", err)
}
}
@@ -293,7 +330,7 @@ func (c *Client) getToken() (string, error) {
jsonData, err := json.Marshal(tokenData)
if err != nil {
return "", fmt.Errorf("failed to marshal token request data: %w", err)
return "", nil, fmt.Errorf("failed to marshal token request data: %w", err)
}
// Create a new request
@@ -303,7 +340,7 @@ func (c *Client) getToken() (string, error) {
bytes.NewBuffer(jsonData),
)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
return "", nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers
@@ -319,7 +356,7 @@ func (c *Client) getToken() (string, error) {
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to request new token: %w", err)
return "", nil, fmt.Errorf("failed to request new token: %w", err)
}
defer resp.Body.Close()
@@ -329,33 +366,33 @@ func (c *Client) getToken() (string, error) {
// Return AuthError for 401/403 status codes
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return "", &AuthError{
return "", nil, &AuthError{
StatusCode: resp.StatusCode,
Message: string(body),
}
}
// For other errors (5xx, network issues, etc.), return regular error
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
return "", nil, fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
}
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.")
return "", fmt.Errorf("failed to decode token response: %w", err)
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
}
if !tokenResp.Success {
return "", fmt.Errorf("failed to get token: %s", tokenResp.Message)
return "", nil, fmt.Errorf("failed to get token: %s", tokenResp.Message)
}
if tokenResp.Data.Token == "" {
return "", fmt.Errorf("received empty token from server")
return "", nil, fmt.Errorf("received empty token from server")
}
logger.Debug("Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, nil
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
}
func (c *Client) connectWithRetry() {
@@ -389,13 +426,13 @@ func (c *Client) connectWithRetry() {
func (c *Client) establishConnection() error {
// Get token for authentication
token, err := c.getToken()
token, exitNodes, err := c.getToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
if c.onTokenUpdate != nil {
c.onTokenUpdate(token)
c.onTokenUpdate(token, exitNodes)
}
// Parse the base URL to determine protocol and hostname