diff --git a/main.go b/main.go index b4025f8..f440db8 100644 --- a/main.go +++ b/main.go @@ -204,6 +204,8 @@ func main() { id, // CLI arg takes precedence secret, // CLI arg takes precedence endpoint, + pingInterval, + pingTimeout, opt, ) if err != nil { @@ -660,14 +662,26 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub publicKey = privateKey.PublicKey() logger.Debug("Public key: %s", publicKey) - // request from the server the list of nodes to ping at newt/ping/request - stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + if !connected { + // request from the server the list of nodes to ping at newt/ping/request + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) - if wgService != nil { - wgService.LoadRemoteConfig() + // Send registration message to the server for backward compatibility + err := client.SendMessage("newt/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "backwardsCompatible": true, + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + logger.Info("Sent registration message") + + if wgService != nil { + wgService.LoadRemoteConfig() + } } - logger.Info("Sent registration message") return nil }) diff --git a/websocket/client.go b/websocket/client.go index 6b34627..4bd2c7d 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -29,9 +29,10 @@ type Client struct { reconnectInterval time.Duration isConnected bool reconnectMux sync.RWMutex - - onConnect func() error - onTokenUpdate func(token string) + pingInterval time.Duration + pingTimeout time.Duration + onConnect func() error + onTokenUpdate func(token string) } type ClientOption func(*Client) @@ -60,7 +61,7 @@ func (c *Client) OnTokenUpdate(callback func(token string)) { } // NewClient creates a new Newt client -func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) { +func NewClient(newtID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { config := &Config{ NewtID: newtID, Secret: secret, @@ -74,16 +75,16 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C done: make(chan struct{}), reconnectInterval: 10 * time.Second, isConnected: false, + pingInterval: pingInterval, + pingTimeout: pingTimeout, } // Apply options before loading config - if opts != nil { - for _, opt := range opts { - if opt == nil { - continue - } - opt(client) + for _, opt := range opts { + if opt == nil { + continue } + opt(client) } // Load existing config if available @@ -160,30 +161,6 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { c.handlers[messageType] = handler } -// readPump pumps messages from the WebSocket connection -func (c *Client) readPump() { - defer c.conn.Close() - - for { - select { - case <-c.done: - return - default: - var msg WSMessage - err := c.conn.ReadJSON(&msg) - if err != nil { - return - } - - c.handlersMux.RLock() - if handler, ok := c.handlers[msg.Type]; ok { - handler(msg) - } - c.handlersMux.RUnlock() - } - } -} - func (c *Client) getToken() (string, error) { // Parse the base URL to ensure we have the correct hostname baseURL, err := url.Parse(c.baseURL) @@ -380,8 +357,8 @@ func (c *Client) establishConnection() error { // Start the ping monitor go c.pingMonitor() - // Start the read pump - go c.readPump() + // Start the read pump with disconnect detection + go c.readPumpWithDisconnectDetection() if c.onConnect != nil { err := c.saveConfig() @@ -396,8 +373,9 @@ func (c *Client) establishConnection() error { return nil } +// pingMonitor sends pings at a short interval and triggers reconnect on failure func (c *Client) pingMonitor() { - ticker := time.NewTicker(30 * time.Second) + ticker := time.NewTicker(c.pingInterval) defer ticker.Stop() for { @@ -405,7 +383,10 @@ func (c *Client) pingMonitor() { case <-c.done: return case <-ticker.C: - if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + if c.conn == nil { + return + } + if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)); err != nil { logger.Error("Ping failed: %v", err) c.reconnect() return @@ -414,10 +395,41 @@ func (c *Client) pingMonitor() { } } +// readPumpWithDisconnectDetection reads messages and triggers reconnect on error +func (c *Client) readPumpWithDisconnectDetection() { + defer func() { + if c.conn != nil { + c.conn.Close() + } + c.reconnect() + }() + + for { + select { + case <-c.done: + return + default: + var msg WSMessage + err := c.conn.ReadJSON(&msg) + if err != nil { + logger.Error("WebSocket read error: %v", err) + return // triggers reconnect via defer + } + + c.handlersMux.RLock() + if handler, ok := c.handlers[msg.Type]; ok { + handler(msg) + } + c.handlersMux.RUnlock() + } + } +} + func (c *Client) reconnect() { c.setConnected(false) if c.conn != nil { c.conn.Close() + c.conn = nil } go c.connectWithRetry()