From 9ce645035150cfc11ea698e1ee71b4ebc1b41362 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 30 Nov 2025 18:12:06 -0500 Subject: [PATCH] Terminate on auth token 403 or 401 Former-commit-id: 63f0a28b77a1b9b50658c133572f5c3c7302d675 --- olm/olm.go | 18 ++++++++++++++++++ olm/types.go | 1 + websocket/client.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/olm/olm.go b/olm/olm.go index 1781f73..3444a94 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -799,6 +799,24 @@ func StartTunnel(config TunnelConfig) { } }) + olm.OnAuthError(func(statusCode int, message string) { + logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message) + apiServer.SetTerminated(true) + apiServer.SetConnectionStatus(false) + apiServer.SetRegistered(false) + apiServer.ClearPeerStatuses() + network.ClearNetworkSettings() + Close() + + if globalConfig.OnAuthError != nil { + go globalConfig.OnAuthError(statusCode, message) + } + + if globalConfig.OnTerminated != nil { + go globalConfig.OnTerminated() + } + }) + // Connect to the WebSocket server if err := olm.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) diff --git a/olm/types.go b/olm/types.go index da113cc..cae876b 100644 --- a/olm/types.go +++ b/olm/types.go @@ -45,6 +45,7 @@ type GlobalConfig struct { OnRegistered func() OnConnected func() OnTerminated func() + OnAuthError func(statusCode int, message string) // Called when auth fails (401/403) // Source tracking (not in JSON) sources map[string]string diff --git a/websocket/client.go b/websocket/client.go index af46b96..64ffb45 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -20,6 +20,22 @@ import ( "github.com/gorilla/websocket" ) +// AuthError represents an authentication/authorization error (401/403) +type AuthError struct { + StatusCode int + Message string +} + +func (e *AuthError) Error() string { + return fmt.Sprintf("authentication error (status %d): %s", e.StatusCode, e.Message) +} + +// IsAuthError checks if an error is an authentication error +func IsAuthError(err error) bool { + _, ok := err.(*AuthError) + return ok +} + type TokenResponse struct { Data struct { Token string `json:"token"` @@ -56,6 +72,7 @@ type Client struct { pingTimeout time.Duration onConnect func() error onTokenUpdate func(token string) + onAuthError func(statusCode int, message string) // Callback for auth errors writeMux sync.Mutex clientType string // Type of client (e.g., "newt", "olm") tlsConfig TLSConfig @@ -103,6 +120,10 @@ func (c *Client) OnTokenUpdate(callback func(token string)) { c.onTokenUpdate = callback } +func (c *Client) OnAuthError(callback func(statusCode int, message string)) { + c.onAuthError = callback +} + // NewClient creates a new websocket client func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { config := &Config{ @@ -305,6 +326,16 @@ func (c *Client) getToken() (string, error) { if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + + // Return AuthError for 401/403 status codes + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return "", &AuthError{ + StatusCode: resp.StatusCode, + Message: string(body), + } + } + + // For other errors (5xx, network issues, etc.), return regular error return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) } @@ -335,6 +366,18 @@ func (c *Client) connectWithRetry() { default: err := c.establishConnection() if err != nil { + // Check if this is an auth error (401/403) + if authErr, ok := err.(*AuthError); ok { + logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr) + // Trigger auth error callback if set (this should terminate the tunnel) + if c.onAuthError != nil { + c.onAuthError(authErr.StatusCode, authErr.Message) + } + // Continue retrying after auth error + time.Sleep(c.reconnectInterval) + continue + } + // For other errors (5xx, network issues), continue retrying logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) time.Sleep(c.reconnectInterval) continue