Faster detection on ws side

This commit is contained in:
Owen
2025-06-19 16:30:31 -04:00
parent bb1318278a
commit a14f70dbaa
2 changed files with 69 additions and 43 deletions

24
main.go
View File

@@ -204,6 +204,8 @@ func main() {
id, // CLI arg takes precedence id, // CLI arg takes precedence
secret, // CLI arg takes precedence secret, // CLI arg takes precedence
endpoint, endpoint,
pingInterval,
pingTimeout,
opt, opt,
) )
if err != nil { if err != nil {
@@ -660,14 +662,26 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
publicKey = privateKey.PublicKey() publicKey = privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey) logger.Debug("Public key: %s", publicKey)
// request from the server the list of nodes to ping at newt/ping/request if !connected {
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) // request from the server the list of nodes to ping at newt/ping/request
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
if wgService != nil { // Send registration message to the server for backward compatibility
wgService.LoadRemoteConfig() err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"backwardsCompatible": true,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent registration message")
if wgService != nil {
wgService.LoadRemoteConfig()
}
} }
logger.Info("Sent registration message")
return nil return nil
}) })

View File

@@ -29,9 +29,10 @@ type Client struct {
reconnectInterval time.Duration reconnectInterval time.Duration
isConnected bool isConnected bool
reconnectMux sync.RWMutex reconnectMux sync.RWMutex
pingInterval time.Duration
onConnect func() error pingTimeout time.Duration
onTokenUpdate func(token string) onConnect func() error
onTokenUpdate func(token string)
} }
type ClientOption func(*Client) type ClientOption func(*Client)
@@ -60,7 +61,7 @@ func (c *Client) OnTokenUpdate(callback func(token string)) {
} }
// NewClient creates a new Newt client // NewClient creates a new Newt client
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) { func NewClient(newtID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{ config := &Config{
NewtID: newtID, NewtID: newtID,
Secret: secret, Secret: secret,
@@ -74,16 +75,16 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
done: make(chan struct{}), done: make(chan struct{}),
reconnectInterval: 10 * time.Second, reconnectInterval: 10 * time.Second,
isConnected: false, isConnected: false,
pingInterval: pingInterval,
pingTimeout: pingTimeout,
} }
// Apply options before loading config // Apply options before loading config
if opts != nil { for _, opt := range opts {
for _, opt := range opts { if opt == nil {
if opt == nil { continue
continue
}
opt(client)
} }
opt(client)
} }
// Load existing config if available // Load existing config if available
@@ -160,30 +161,6 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlers[messageType] = handler c.handlers[messageType] = handler
} }
// readPump pumps messages from the WebSocket connection
func (c *Client) readPump() {
defer c.conn.Close()
for {
select {
case <-c.done:
return
default:
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
return
}
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
handler(msg)
}
c.handlersMux.RUnlock()
}
}
}
func (c *Client) getToken() (string, error) { func (c *Client) getToken() (string, error) {
// Parse the base URL to ensure we have the correct hostname // Parse the base URL to ensure we have the correct hostname
baseURL, err := url.Parse(c.baseURL) baseURL, err := url.Parse(c.baseURL)
@@ -380,8 +357,8 @@ func (c *Client) establishConnection() error {
// Start the ping monitor // Start the ping monitor
go c.pingMonitor() go c.pingMonitor()
// Start the read pump // Start the read pump with disconnect detection
go c.readPump() go c.readPumpWithDisconnectDetection()
if c.onConnect != nil { if c.onConnect != nil {
err := c.saveConfig() err := c.saveConfig()
@@ -396,8 +373,9 @@ func (c *Client) establishConnection() error {
return nil return nil
} }
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) pingMonitor() { func (c *Client) pingMonitor() {
ticker := time.NewTicker(30 * time.Second) ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop() defer ticker.Stop()
for { for {
@@ -405,7 +383,10 @@ func (c *Client) pingMonitor() {
case <-c.done: case <-c.done:
return return
case <-ticker.C: case <-ticker.C:
if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { if c.conn == nil {
return
}
if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)); err != nil {
logger.Error("Ping failed: %v", err) logger.Error("Ping failed: %v", err)
c.reconnect() c.reconnect()
return return
@@ -414,10 +395,41 @@ func (c *Client) pingMonitor() {
} }
} }
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
func (c *Client) readPumpWithDisconnectDetection() {
defer func() {
if c.conn != nil {
c.conn.Close()
}
c.reconnect()
}()
for {
select {
case <-c.done:
return
default:
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
logger.Error("WebSocket read error: %v", err)
return // triggers reconnect via defer
}
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
handler(msg)
}
c.handlersMux.RUnlock()
}
}
}
func (c *Client) reconnect() { func (c *Client) reconnect() {
c.setConnected(false) c.setConnected(false)
if c.conn != nil { if c.conn != nil {
c.conn.Close() c.conn.Close()
c.conn = nil
} }
go c.connectWithRetry() go c.connectWithRetry()