mirror of
https://github.com/fosrl/olm.git
synced 2026-02-26 23:06:41 +00:00
18
olm/olm.go
18
olm/olm.go
@@ -799,6 +799,24 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
olm.OnAuthError(func(statusCode int, message string) {
|
||||||
|
logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message)
|
||||||
|
apiServer.SetTerminated(true)
|
||||||
|
apiServer.SetConnectionStatus(false)
|
||||||
|
apiServer.SetRegistered(false)
|
||||||
|
apiServer.ClearPeerStatuses()
|
||||||
|
network.ClearNetworkSettings()
|
||||||
|
Close()
|
||||||
|
|
||||||
|
if globalConfig.OnAuthError != nil {
|
||||||
|
go globalConfig.OnAuthError(statusCode, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
if globalConfig.OnTerminated != nil {
|
||||||
|
go globalConfig.OnTerminated()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Connect to the WebSocket server
|
// Connect to the WebSocket server
|
||||||
if err := olm.Connect(); err != nil {
|
if err := olm.Connect(); err != nil {
|
||||||
logger.Error("Failed to connect to server: %v", err)
|
logger.Error("Failed to connect to server: %v", err)
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ type GlobalConfig struct {
|
|||||||
OnRegistered func()
|
OnRegistered func()
|
||||||
OnConnected func()
|
OnConnected func()
|
||||||
OnTerminated func()
|
OnTerminated func()
|
||||||
|
OnAuthError func(statusCode int, message string) // Called when auth fails (401/403)
|
||||||
|
|
||||||
// Source tracking (not in JSON)
|
// Source tracking (not in JSON)
|
||||||
sources map[string]string
|
sources map[string]string
|
||||||
|
|||||||
@@ -20,6 +20,22 @@ import (
|
|||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AuthError represents an authentication/authorization error (401/403)
|
||||||
|
type AuthError struct {
|
||||||
|
StatusCode int
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AuthError) Error() string {
|
||||||
|
return fmt.Sprintf("authentication error (status %d): %s", e.StatusCode, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAuthError checks if an error is an authentication error
|
||||||
|
func IsAuthError(err error) bool {
|
||||||
|
_, ok := err.(*AuthError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
Data struct {
|
Data struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
@@ -56,6 +72,7 @@ type Client struct {
|
|||||||
pingTimeout time.Duration
|
pingTimeout time.Duration
|
||||||
onConnect func() error
|
onConnect func() error
|
||||||
onTokenUpdate func(token string)
|
onTokenUpdate func(token string)
|
||||||
|
onAuthError func(statusCode int, message string) // Callback for auth errors
|
||||||
writeMux sync.Mutex
|
writeMux sync.Mutex
|
||||||
clientType string // Type of client (e.g., "newt", "olm")
|
clientType string // Type of client (e.g., "newt", "olm")
|
||||||
tlsConfig TLSConfig
|
tlsConfig TLSConfig
|
||||||
@@ -103,6 +120,10 @@ func (c *Client) OnTokenUpdate(callback func(token string)) {
|
|||||||
c.onTokenUpdate = callback
|
c.onTokenUpdate = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) OnAuthError(callback func(statusCode int, message string)) {
|
||||||
|
c.onAuthError = callback
|
||||||
|
}
|
||||||
|
|
||||||
// NewClient creates a new websocket client
|
// NewClient creates a new websocket client
|
||||||
func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
@@ -305,6 +326,16 @@ func (c *Client) getToken() (string, error) {
|
|||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
|
|
||||||
|
// Return AuthError for 401/403 status codes
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
|
return "", &AuthError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Message: string(body),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For other errors (5xx, network issues, etc.), return regular error
|
||||||
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,6 +366,18 @@ func (c *Client) connectWithRetry() {
|
|||||||
default:
|
default:
|
||||||
err := c.establishConnection()
|
err := c.establishConnection()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Check if this is an auth error (401/403)
|
||||||
|
if authErr, ok := err.(*AuthError); ok {
|
||||||
|
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
|
||||||
|
// Trigger auth error callback if set (this should terminate the tunnel)
|
||||||
|
if c.onAuthError != nil {
|
||||||
|
c.onAuthError(authErr.StatusCode, authErr.Message)
|
||||||
|
}
|
||||||
|
// Continue retrying after auth error
|
||||||
|
time.Sleep(c.reconnectInterval)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// For other errors (5xx, network issues), continue retrying
|
||||||
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||||
time.Sleep(c.reconnectInterval)
|
time.Sleep(c.reconnectInterval)
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user