Merge pull request #10 from marcschaeferger/codex/review-opentelemetry-metrics-and-tracing

Enhance telemetry metrics and context propagation
This commit is contained in:
Marc Schäfer
2025-10-10 18:21:14 +02:00
committed by GitHub
6 changed files with 392 additions and 207 deletions

View File

@@ -45,14 +45,19 @@ var (
mBuildInfo metric.Int64ObservableGauge mBuildInfo metric.Int64ObservableGauge
// WebSocket // WebSocket
mWSConnectLatency metric.Float64Histogram mWSConnectLatency metric.Float64Histogram
mWSMessages metric.Int64Counter mWSMessages metric.Int64Counter
mWSDisconnects metric.Int64Counter
mWSKeepaliveFailure metric.Int64Counter
mWSSessionDuration metric.Float64Histogram
// Proxy // Proxy
mProxyActiveConns metric.Int64ObservableGauge mProxyActiveConns metric.Int64ObservableGauge
mProxyBufferBytes metric.Int64ObservableGauge mProxyBufferBytes metric.Int64ObservableGauge
mProxyAsyncBacklogByte metric.Int64ObservableGauge mProxyAsyncBacklogByte metric.Int64ObservableGauge
mProxyDropsTotal metric.Int64Counter mProxyDropsTotal metric.Int64Counter
mProxyAcceptsTotal metric.Int64Counter
mProxyConnDuration metric.Float64Histogram
buildVersion string buildVersion string
buildCommit string buildCommit string
@@ -179,6 +184,13 @@ func registerBuildWSProxyInstruments() error {
metric.WithUnit("s")) metric.WithUnit("s"))
mWSMessages, _ = meter.Int64Counter("newt_websocket_messages_total", mWSMessages, _ = meter.Int64Counter("newt_websocket_messages_total",
metric.WithDescription("WebSocket messages by direction and type")) metric.WithDescription("WebSocket messages by direction and type"))
mWSDisconnects, _ = meter.Int64Counter("newt_websocket_disconnects_total",
metric.WithDescription("WebSocket disconnects by reason/result"))
mWSKeepaliveFailure, _ = meter.Int64Counter("newt_websocket_keepalive_failures_total",
metric.WithDescription("WebSocket keepalive (ping/pong) failures"))
mWSSessionDuration, _ = meter.Float64Histogram("newt_websocket_session_duration_seconds",
metric.WithDescription("Duration of established WebSocket sessions"),
metric.WithUnit("s"))
// Proxy // Proxy
mProxyActiveConns, _ = meter.Int64ObservableGauge("newt_proxy_active_connections", mProxyActiveConns, _ = meter.Int64ObservableGauge("newt_proxy_active_connections",
metric.WithDescription("Proxy active connections per tunnel and protocol")) metric.WithDescription("Proxy active connections per tunnel and protocol"))
@@ -190,6 +202,11 @@ func registerBuildWSProxyInstruments() error {
metric.WithUnit("By")) metric.WithUnit("By"))
mProxyDropsTotal, _ = meter.Int64Counter("newt_proxy_drops_total", mProxyDropsTotal, _ = meter.Int64Counter("newt_proxy_drops_total",
metric.WithDescription("Proxy drops due to write errors")) metric.WithDescription("Proxy drops due to write errors"))
mProxyAcceptsTotal, _ = meter.Int64Counter("newt_proxy_accept_total",
metric.WithDescription("Proxy connection accepts by protocol and result"))
mProxyConnDuration, _ = meter.Float64Histogram("newt_proxy_connection_duration_seconds",
metric.WithDescription("Duration of completed proxy connections"),
metric.WithUnit("s"))
// Register a default callback for build info if version/commit set // Register a default callback for build info if version/commit set
reg, e := meter.RegisterCallback(func(ctx context.Context, o metric.Observer) error { reg, e := meter.RegisterCallback(func(ctx context.Context, o metric.Observer) error {
if buildVersion == "" && buildCommit == "" { if buildVersion == "" && buildCommit == "" {
@@ -328,6 +345,25 @@ func IncWSMessage(ctx context.Context, direction, msgType string) {
)...)) )...))
} }
func IncWSDisconnect(ctx context.Context, reason, result string) {
mWSDisconnects.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("reason", reason),
attribute.String("result", result),
)...))
}
func IncWSKeepaliveFailure(ctx context.Context, reason string) {
mWSKeepaliveFailure.Add(ctx, 1, metric.WithAttributes(attrsWithSite(
attribute.String("reason", reason),
)...))
}
func ObserveWSSessionDuration(ctx context.Context, seconds float64, result string) {
mWSSessionDuration.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(
attribute.String("result", result),
)...))
}
// --- Proxy helpers --- // --- Proxy helpers ---
func ObserveProxyActiveConnsObs(o metric.Observer, value int64, attrs []attribute.KeyValue) { func ObserveProxyActiveConnsObs(o metric.Observer, value int64, attrs []attribute.KeyValue) {
@@ -352,6 +388,31 @@ func IncProxyDrops(ctx context.Context, tunnelID, protocol string) {
mProxyDropsTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...)) mProxyDropsTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
} }
// IncProxyAccept records one proxy accept attempt, labeled with the
// protocol and result. The reason label is attached only when non-empty
// (failures), and tunnel_id only when the tunnel-id label is enabled
// and a non-empty ID is available — both guards keep label cardinality low.
func IncProxyAccept(ctx context.Context, tunnelID, protocol, result, reason string) {
	attrs := make([]attribute.KeyValue, 0, 4)
	attrs = append(attrs,
		attribute.String("protocol", protocol),
		attribute.String("result", result),
	)
	if reason != "" {
		attrs = append(attrs, attribute.String("reason", reason))
	}
	if ShouldIncludeTunnelID() && tunnelID != "" {
		attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
	}
	mProxyAcceptsTotal.Add(ctx, 1, metric.WithAttributes(attrsWithSite(attrs...)...))
}
// ObserveProxyConnectionDuration records the lifetime (in seconds) of a
// finished proxy connection, labeled with protocol and result. The
// tunnel_id label is attached only when enabled and non-empty, to keep
// label cardinality bounded.
func ObserveProxyConnectionDuration(ctx context.Context, tunnelID, protocol, result string, seconds float64) {
	attrs := make([]attribute.KeyValue, 0, 3)
	attrs = append(attrs,
		attribute.String("protocol", protocol),
		attribute.String("result", result),
	)
	if ShouldIncludeTunnelID() && tunnelID != "" {
		attrs = append(attrs, attribute.String("tunnel_id", tunnelID))
	}
	mProxyConnDuration.Record(ctx, seconds, metric.WithAttributes(attrsWithSite(attrs...)...))
}
// --- Config/PKI helpers --- // --- Config/PKI helpers ---
func ObserveConfigApply(ctx context.Context, phase, result string, seconds float64) { func ObserveConfigApply(ctx context.Context, phase, result string, seconds float64) {

25
main.go
View File

@@ -586,6 +586,10 @@ func main() {
// Register handlers for different message types // Register handlers for different message types
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) { client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
logger.Info("Received registration message") logger.Info("Received registration message")
regResult := "success"
defer func() {
telemetry.IncSiteRegistration(ctx, regResult)
}()
if stopFunc != nil { if stopFunc != nil {
stopFunc() // stop the ws from sending more requests stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping stopFunc = nil // reset stopFunc to nil to avoid double stopping
@@ -605,11 +609,13 @@ func main() {
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Info(fmtErrMarshaling, err) logger.Info(fmtErrMarshaling, err)
regResult = "failure"
return return
} }
if err := json.Unmarshal(jsonData, &wgData); err != nil { if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err) logger.Info("Error unmarshaling target data: %v", err)
regResult = "failure"
return return
} }
@@ -620,6 +626,7 @@ func main() {
mtuInt) mtuInt)
if err != nil { if err != nil {
logger.Error("Failed to create TUN device: %v", err) logger.Error("Failed to create TUN device: %v", err)
regResult = "failure"
} }
setDownstreamTNetstack(tnet) setDownstreamTNetstack(tnet)
@@ -633,6 +640,7 @@ func main() {
host, _, err := net.SplitHostPort(wgData.Endpoint) host, _, err := net.SplitHostPort(wgData.Endpoint)
if err != nil { if err != nil {
logger.Error("Failed to split endpoint: %v", err) logger.Error("Failed to split endpoint: %v", err)
regResult = "failure"
return return
} }
@@ -641,6 +649,7 @@ func main() {
endpoint, err := resolveDomain(wgData.Endpoint) endpoint, err := resolveDomain(wgData.Endpoint)
if err != nil { if err != nil {
logger.Error("Failed to resolve endpoint: %v", err) logger.Error("Failed to resolve endpoint: %v", err)
regResult = "failure"
return return
} }
@@ -656,12 +665,14 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
err = dev.IpcSet(config) err = dev.IpcSet(config)
if err != nil { if err != nil {
logger.Error("Failed to configure WireGuard device: %v", err) logger.Error("Failed to configure WireGuard device: %v", err)
regResult = "failure"
} }
// Bring up the device // Bring up the device
err = dev.Up() err = dev.Up()
if err != nil { if err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err) logger.Error("Failed to bring up WireGuard device: %v", err)
regResult = "failure"
} }
logger.Debug("WireGuard device created. Lets ping the server now...") logger.Debug("WireGuard device created. Lets ping the server now...")
@@ -676,10 +687,11 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Debug("Testing initial connection with reliable ping...") logger.Debug("Testing initial connection with reliable ping...")
lat, err := reliablePing(tnet, wgData.ServerIP, pingTimeout, 5) lat, err := reliablePing(tnet, wgData.ServerIP, pingTimeout, 5)
if err == nil && wgData.PublicKey != "" { if err == nil && wgData.PublicKey != "" {
telemetry.ObserveTunnelLatency(context.Background(), wgData.PublicKey, "wireguard", lat.Seconds()) telemetry.ObserveTunnelLatency(ctx, wgData.PublicKey, "wireguard", lat.Seconds())
} }
if err != nil { if err != nil {
logger.Warn("Initial reliable ping failed, but continuing: %v", err) logger.Warn("Initial reliable ping failed, but continuing: %v", err)
regResult = "failure"
} else { } else {
logger.Info("Initial connection test successful") logger.Info("Initial connection test successful")
} }
@@ -701,9 +713,6 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
connected = true connected = true
// telemetry: record a successful site registration (omit region unless available)
telemetry.IncSiteRegistration(context.Background(), "success")
// add the targets if there are any // add the targets if there are any
if len(wgData.Targets.TCP) > 0 { if len(wgData.Targets.TCP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP}) updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
@@ -738,7 +747,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
client.RegisterHandler("newt/wg/reconnect", func(msg websocket.WSMessage) { client.RegisterHandler("newt/wg/reconnect", func(msg websocket.WSMessage) {
logger.Info("Received reconnect message") logger.Info("Received reconnect message")
if wgData.PublicKey != "" { if wgData.PublicKey != "" {
telemetry.IncReconnect(context.Background(), wgData.PublicKey, "server", telemetry.ReasonServerRequest) telemetry.IncReconnect(ctx, wgData.PublicKey, "server", telemetry.ReasonServerRequest)
} }
// Close the WireGuard device and TUN // Close the WireGuard device and TUN
@@ -767,7 +776,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) { client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) {
logger.Info("Received termination message") logger.Info("Received termination message")
if wgData.PublicKey != "" { if wgData.PublicKey != "" {
telemetry.IncReconnect(context.Background(), wgData.PublicKey, "server", telemetry.ReasonServerRequest) telemetry.IncReconnect(ctx, wgData.PublicKey, "server", telemetry.ReasonServerRequest)
} }
// Close the WireGuard device and TUN // Close the WireGuard device and TUN
@@ -837,7 +846,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}, },
} }
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,
@@ -940,7 +949,7 @@ stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
} }
// Send the ping results to the cloud for selection // Send the ping results to the cloud for selection
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,

View File

@@ -2,6 +2,7 @@ package proxy
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -97,6 +98,32 @@ func (cw *countingWriter) Write(p []byte) (int, error) {
return n, err return n, err
} }
func classifyProxyError(err error) string {
if err == nil {
return ""
}
if errors.Is(err, net.ErrClosed) {
return "closed"
}
if ne, ok := err.(net.Error); ok {
if ne.Timeout() {
return "timeout"
}
if ne.Temporary() {
return "temporary"
}
}
msg := strings.ToLower(err.Error())
switch {
case strings.Contains(msg, "refused"):
return "refused"
case strings.Contains(msg, "reset"):
return "reset"
default:
return "io_error"
}
}
// NewProxyManager creates a new proxy manager instance // NewProxyManager creates a new proxy manager instance
func NewProxyManager(tnet *netstack.Net) *ProxyManager { func NewProxyManager(tnet *netstack.Net) *ProxyManager {
return &ProxyManager{ return &ProxyManager{
@@ -467,72 +494,69 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
// Check if we're shutting down or the listener was closed telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "tcp", "failure", classifyProxyError(err))
if !pm.running { if !pm.running {
return return
} }
// Check for specific network errors that indicate the listener is closed
if ne, ok := err.(net.Error); ok && !ne.Temporary() { if ne, ok := err.(net.Error); ok && !ne.Temporary() {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr()) logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return return
} }
logger.Error("Error accepting TCP connection: %v", err) logger.Error("Error accepting TCP connection: %v", err)
// Don't hammer the CPU if we hit a temporary error
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
continue continue
} }
// Count sessions only once per accepted TCP connection tunnelID := pm.currentTunnelID
if pm.currentTunnelID != "" { telemetry.IncProxyAccept(context.Background(), tunnelID, "tcp", "success", "")
state.Global().IncSessions(pm.currentTunnelID) if tunnelID != "" {
if e := pm.getEntry(pm.currentTunnelID); e != nil { state.Global().IncSessions(tunnelID)
if e := pm.getEntry(tunnelID); e != nil {
e.activeTCP.Add(1) e.activeTCP.Add(1)
} }
} }
go func() { go func(tunnelID string, accepted net.Conn) {
connStart := time.Now()
target, err := net.Dial("tcp", targetAddr) target, err := net.Dial("tcp", targetAddr)
if err != nil { if err != nil {
logger.Error("Error connecting to target: %v", err) logger.Error("Error connecting to target: %v", err)
conn.Close() accepted.Close()
telemetry.IncProxyAccept(context.Background(), tunnelID, "tcp", "failure", classifyProxyError(err))
telemetry.ObserveProxyConnectionDuration(context.Background(), tunnelID, "tcp", "failure", time.Since(connStart).Seconds())
return return
} }
// already incremented on accept entry := pm.getEntry(tunnelID)
if entry == nil {
// Create a WaitGroup to ensure both copy operations complete entry = &tunnelEntry{}
}
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
// client -> target (direction=in) go func(ent *tunnelEntry) {
go func() {
defer wg.Done() defer wg.Done()
e := pm.getEntry(pm.currentTunnelID) cw := &countingWriter{ctx: context.Background(), w: target, set: ent.attrInTCP, pm: pm, ent: ent, out: false, proto: "tcp"}
cw := &countingWriter{ctx: context.Background(), w: target, set: e.attrInTCP, pm: pm, ent: e, out: false, proto: "tcp"} _, _ = io.Copy(cw, accepted)
_, _ = io.Copy(cw, conn)
_ = target.Close() _ = target.Close()
}() }(entry)
// target -> client (direction=out) go func(ent *tunnelEntry) {
go func() {
defer wg.Done() defer wg.Done()
e := pm.getEntry(pm.currentTunnelID) cw := &countingWriter{ctx: context.Background(), w: accepted, set: ent.attrOutTCP, pm: pm, ent: ent, out: true, proto: "tcp"}
cw := &countingWriter{ctx: context.Background(), w: conn, set: e.attrOutTCP, pm: pm, ent: e, out: true, proto: "tcp"}
_, _ = io.Copy(cw, target) _, _ = io.Copy(cw, target)
_ = conn.Close() _ = accepted.Close()
}() }(entry)
// Wait for both copies to complete then session -1
wg.Wait() wg.Wait()
if pm.currentTunnelID != "" { if tunnelID != "" {
state.Global().DecSessions(pm.currentTunnelID) state.Global().DecSessions(tunnelID)
if e := pm.getEntry(pm.currentTunnelID); e != nil { if e := pm.getEntry(tunnelID); e != nil {
e.activeTCP.Add(-1) e.activeTCP.Add(-1)
} }
} }
}() telemetry.ObserveProxyConnectionDuration(context.Background(), tunnelID, "tcp", "success", time.Since(connStart).Seconds())
}(tunnelID, conn)
} }
} }
@@ -595,16 +619,20 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr) targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil { if err != nil {
logger.Error("Error resolving target address: %v", err) logger.Error("Error resolving target address: %v", err)
telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "udp", "failure", "resolve")
continue continue
} }
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr) targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
if err != nil { if err != nil {
logger.Error("Error connecting to target: %v", err) logger.Error("Error connecting to target: %v", err)
telemetry.IncProxyAccept(context.Background(), pm.currentTunnelID, "udp", "failure", classifyProxyError(err))
continue continue
} }
tunnelID := pm.currentTunnelID
telemetry.IncProxyAccept(context.Background(), tunnelID, "udp", "success", "")
// Only increment activeUDP after a successful DialUDP // Only increment activeUDP after a successful DialUDP
if e := pm.getEntry(pm.currentTunnelID); e != nil { if e := pm.getEntry(tunnelID); e != nil {
e.activeUDP.Add(1) e.activeUDP.Add(1)
} }
@@ -612,18 +640,21 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
clientConns[clientKey] = targetConn clientConns[clientKey] = targetConn
clientsMutex.Unlock() clientsMutex.Unlock()
go func(clientKey string, targetConn *net.UDPConn, remoteAddr net.Addr) { go func(clientKey string, targetConn *net.UDPConn, remoteAddr net.Addr, tunnelID string) {
start := time.Now()
result := "success"
defer func() { defer func() {
// Always clean up when this goroutine exits // Always clean up when this goroutine exits
clientsMutex.Lock() clientsMutex.Lock()
if storedConn, exists := clientConns[clientKey]; exists && storedConn == targetConn { if storedConn, exists := clientConns[clientKey]; exists && storedConn == targetConn {
delete(clientConns, clientKey) delete(clientConns, clientKey)
targetConn.Close() targetConn.Close()
if e := pm.getEntry(pm.currentTunnelID); e != nil { if e := pm.getEntry(tunnelID); e != nil {
e.activeUDP.Add(-1) e.activeUDP.Add(-1)
} }
} }
clientsMutex.Unlock() clientsMutex.Unlock()
telemetry.ObserveProxyConnectionDuration(context.Background(), tunnelID, "udp", result, time.Since(start).Seconds())
}() }()
buffer := make([]byte, 65507) buffer := make([]byte, 65507)
@@ -631,6 +662,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
n, _, err := targetConn.ReadFromUDP(buffer) n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil { if err != nil {
logger.Error("Error reading from target: %v", err) logger.Error("Error reading from target: %v", err)
result = "failure"
return // defer will handle cleanup return // defer will handle cleanup
} }
@@ -651,10 +683,11 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
if err != nil { if err != nil {
logger.Error("Error writing to client: %v", err) logger.Error("Error writing to client: %v", err)
telemetry.IncProxyDrops(context.Background(), pm.currentTunnelID, "udp") telemetry.IncProxyDrops(context.Background(), pm.currentTunnelID, "udp")
result = "failure"
return // defer will handle cleanup return // defer will handle cleanup
} }
} }
}(clientKey, targetConn, remoteAddr) }(clientKey, targetConn, remoteAddr, tunnelID)
} }
written, err := targetConn.Write(buffer[:n]) written, err := targetConn.Write(buffer[:n])

View File

@@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@@ -42,6 +43,8 @@ type Client struct {
writeMux sync.Mutex writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm") clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig tlsConfig TLSConfig
metricsCtxMu sync.RWMutex
metricsCtx context.Context
} }
type ClientOption func(*Client) type ClientOption func(*Client)
@@ -85,6 +88,26 @@ func (c *Client) OnTokenUpdate(callback func(token string)) {
c.onTokenUpdate = callback c.onTokenUpdate = callback
} }
// metricsContext returns the connection-scoped context used for
// telemetry emission, falling back to context.Background() when no
// context has been set yet. Safe for concurrent use.
func (c *Client) metricsContext() context.Context {
	c.metricsCtxMu.RLock()
	ctx := c.metricsCtx
	c.metricsCtxMu.RUnlock()
	if ctx == nil {
		return context.Background()
	}
	return ctx
}
// setMetricsContext stores ctx as the connection-scoped context for
// subsequent telemetry emission. Safe for concurrent use.
func (c *Client) setMetricsContext(ctx context.Context) {
	c.metricsCtxMu.Lock()
	defer c.metricsCtxMu.Unlock()
	c.metricsCtx = ctx
}
// MetricsContext exposes the context used for telemetry emission when a
// connection is active. It is the exported wrapper around
// metricsContext, which falls back to context.Background() when no
// connection-scoped context has been set.
func (c *Client) MetricsContext() context.Context {
	return c.metricsContext()
}
// NewClient creates a new websocket client // NewClient creates a new websocket client
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{ config := &Config{
@@ -177,7 +200,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
if err := c.conn.WriteJSON(msg); err != nil { if err := c.conn.WriteJSON(msg); err != nil {
return err return err
} }
telemetry.IncWSMessage(context.Background(), "out", "text") telemetry.IncWSMessage(c.metricsContext(), "out", "text")
return nil return nil
} }
@@ -273,8 +296,12 @@ func (c *Client) getToken() (string, error) {
return "", fmt.Errorf("failed to marshal token request data: %w", err) return "", fmt.Errorf("failed to marshal token request data: %w", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create a new request // Create a new request
req, err := http.NewRequest( req, err := http.NewRequestWithContext(
ctx,
"POST", "POST",
baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token", baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
bytes.NewBuffer(jsonData), bytes.NewBuffer(jsonData),
@@ -296,7 +323,8 @@ func (c *Client) getToken() (string, error) {
} }
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
telemetry.IncConnError(context.Background(), "auth", classifyConnError(err)) telemetry.IncConnAttempt(ctx, "auth", "failure")
telemetry.IncConnError(ctx, "auth", classifyConnError(err))
return "", fmt.Errorf("failed to request new token: %w", err) return "", fmt.Errorf("failed to request new token: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -304,15 +332,15 @@ func (c *Client) getToken() (string, error) {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
telemetry.IncConnAttempt(context.Background(), "auth", "failure") telemetry.IncConnAttempt(ctx, "auth", "failure")
etype := "io_error" etype := "io_error"
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
etype = "auth_failed" etype = "auth_failed"
} }
telemetry.IncConnError(context.Background(), "auth", etype) telemetry.IncConnError(ctx, "auth", etype)
// Reconnect reason mapping for auth failures // Reconnect reason mapping for auth failures
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
telemetry.IncReconnect(context.Background(), c.config.ID, "client", telemetry.ReasonAuthError) telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonAuthError)
} }
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
} }
@@ -332,7 +360,7 @@ func (c *Client) getToken() (string, error) {
} }
logger.Debug("Received token: %s", tokenResp.Data.Token) logger.Debug("Received token: %s", tokenResp.Data.Token)
telemetry.IncConnAttempt(context.Background(), "auth", "success") telemetry.IncConnAttempt(ctx, "auth", "success")
return tokenResp.Data.Token, nil return tokenResp.Data.Token, nil
} }
@@ -357,6 +385,30 @@ func classifyConnError(err error) string {
} }
} }
// classifyWSDisconnect maps a WebSocket read error to a (result, reason)
// label pair for disconnect metrics. nil and a normal close code count
// as a successful disconnect; everything else is classified as an error
// with a best-effort reason, falling back to "read_error".
func classifyWSDisconnect(err error) (result, reason string) {
	if err == nil || websocket.IsCloseError(err, websocket.CloseNormalClosure) {
		return "success", "normal"
	}
	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		return "error", "timeout"
	}
	if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
		return "error", "unexpected_close"
	}
	lower := strings.ToLower(err.Error())
	switch {
	case strings.Contains(lower, "eof"):
		return "error", "eof"
	case strings.Contains(lower, "reset"):
		return "error", "connection_reset"
	}
	return "error", "read_error"
}
func (c *Client) connectWithRetry() { func (c *Client) connectWithRetry() {
for { for {
select { select {
@@ -375,13 +427,13 @@ func (c *Client) connectWithRetry() {
} }
func (c *Client) establishConnection() error { func (c *Client) establishConnection() error {
ctx := context.Background()
// Get token for authentication // Get token for authentication
token, err := c.getToken() token, err := c.getToken()
if err != nil { if err != nil {
// telemetry: connection attempt failed before dialing telemetry.IncConnAttempt(ctx, "websocket", "failure")
// site_id isn't globally available here; use client ID as site_id (low cardinality) telemetry.IncConnError(ctx, "websocket", classifyConnError(err))
telemetry.IncConnAttempt(context.Background(), "websocket", "failure")
telemetry.IncConnError(context.Background(), "websocket", classifyConnError(err))
return fmt.Errorf("failed to get token: %w", err) return fmt.Errorf("failed to get token: %w", err)
} }
@@ -416,7 +468,7 @@ func (c *Client) establishConnection() error {
// Connect to WebSocket (optional span) // Connect to WebSocket (optional span)
tr := otel.Tracer("newt") tr := otel.Tracer("newt")
spanCtx, span := tr.Start(context.Background(), "ws.connect") ctx, span := tr.Start(ctx, "ws.connect")
defer span.End() defer span.End()
start := time.Now() start := time.Now()
@@ -441,38 +493,40 @@ func (c *Client) establishConnection() error {
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
} }
conn, _, err := dialer.DialContext(spanCtx, u.String(), nil) conn, _, err := dialer.DialContext(ctx, u.String(), nil)
lat := time.Since(start).Seconds() lat := time.Since(start).Seconds()
if err != nil { if err != nil {
telemetry.IncConnAttempt(context.Background(), "websocket", "failure") telemetry.IncConnAttempt(ctx, "websocket", "failure")
etype := classifyConnError(err) etype := classifyConnError(err)
telemetry.IncConnError(context.Background(), "websocket", etype) telemetry.IncConnError(ctx, "websocket", etype)
telemetry.ObserveWSConnectLatency(context.Background(), lat, "failure", etype) telemetry.ObserveWSConnectLatency(ctx, lat, "failure", etype)
// Map handshake-related errors to reconnect reasons where appropriate // Map handshake-related errors to reconnect reasons where appropriate
if etype == "tls_handshake" { if etype == "tls_handshake" {
telemetry.IncReconnect(context.Background(), c.config.ID, "client", telemetry.ReasonHandshakeError) telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonHandshakeError)
} else if etype == "dial_timeout" { } else if etype == "dial_timeout" {
telemetry.IncReconnect(context.Background(), c.config.ID, "client", telemetry.ReasonTimeout) telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonTimeout)
} else { } else {
telemetry.IncReconnect(context.Background(), c.config.ID, "client", telemetry.ReasonError) telemetry.IncReconnect(ctx, c.config.ID, "client", telemetry.ReasonError)
} }
return fmt.Errorf("failed to connect to WebSocket: %w", err) return fmt.Errorf("failed to connect to WebSocket: %w", err)
} }
telemetry.IncConnAttempt(context.Background(), "websocket", "success") telemetry.IncConnAttempt(ctx, "websocket", "success")
telemetry.ObserveWSConnectLatency(context.Background(), lat, "success", "") telemetry.ObserveWSConnectLatency(ctx, lat, "success", "")
c.conn = conn c.conn = conn
c.setConnected(true) c.setConnected(true)
c.setMetricsContext(ctx)
sessionStart := time.Now()
// Wire up pong handler for metrics // Wire up pong handler for metrics
c.conn.SetPongHandler(func(appData string) error { c.conn.SetPongHandler(func(appData string) error {
telemetry.IncWSMessage(context.Background(), "in", "pong") telemetry.IncWSMessage(c.metricsContext(), "in", "pong")
return nil return nil
}) })
// Start the ping monitor // Start the ping monitor
go c.pingMonitor() go c.pingMonitor()
// Start the read pump with disconnect detection // Start the read pump with disconnect detection
go c.readPumpWithDisconnectDetection() go c.readPumpWithDisconnectDetection(sessionStart)
if c.onConnect != nil { if c.onConnect != nil {
err := c.saveConfig() err := c.saveConfig()
@@ -566,7 +620,7 @@ func (c *Client) pingMonitor() {
c.writeMux.Lock() c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)) err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
if err == nil { if err == nil {
telemetry.IncWSMessage(context.Background(), "out", "ping") telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
} }
c.writeMux.Unlock() c.writeMux.Unlock()
if err != nil { if err != nil {
@@ -577,6 +631,7 @@ func (c *Client) pingMonitor() {
return return
default: default:
logger.Error("Ping failed: %v", err) logger.Error("Ping failed: %v", err)
telemetry.IncWSKeepaliveFailure(c.metricsContext(), "ping_write")
c.reconnect() c.reconnect()
return return
} }
@@ -586,11 +641,19 @@ func (c *Client) pingMonitor() {
} }
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error // readPumpWithDisconnectDetection reads messages and triggers reconnect on error
func (c *Client) readPumpWithDisconnectDetection() { func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
ctx := c.metricsContext()
disconnectReason := "shutdown"
disconnectResult := "success"
defer func() { defer func() {
if c.conn != nil { if c.conn != nil {
c.conn.Close() c.conn.Close()
} }
if !started.IsZero() {
telemetry.ObserveWSSessionDuration(ctx, time.Since(started).Seconds(), disconnectResult)
}
telemetry.IncWSDisconnect(ctx, disconnectReason, disconnectResult)
// Only attempt reconnect if we're not shutting down // Only attempt reconnect if we're not shutting down
select { select {
case <-c.done: case <-c.done:
@@ -604,12 +667,14 @@ func (c *Client) readPumpWithDisconnectDetection() {
for { for {
select { select {
case <-c.done: case <-c.done:
disconnectReason = "shutdown"
disconnectResult = "success"
return return
default: default:
var msg WSMessage var msg WSMessage
err := c.conn.ReadJSON(&msg) err := c.conn.ReadJSON(&msg)
if err == nil { if err == nil {
telemetry.IncWSMessage(context.Background(), "in", "text") telemetry.IncWSMessage(c.metricsContext(), "in", "text")
} }
if err != nil { if err != nil {
// Check if we're shutting down before logging error // Check if we're shutting down before logging error
@@ -617,13 +682,18 @@ func (c *Client) readPumpWithDisconnectDetection() {
case <-c.done: case <-c.done:
// Expected during shutdown, don't log as error // Expected during shutdown, don't log as error
logger.Debug("WebSocket connection closed during shutdown") logger.Debug("WebSocket connection closed during shutdown")
disconnectReason = "shutdown"
disconnectResult = "success"
return return
default: default:
// Unexpected error during normal operation // Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { disconnectResult, disconnectReason = classifyWSDisconnect(err)
logger.Error("WebSocket read error: %v", err) if disconnectResult == "error" {
} else { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Debug("WebSocket connection closed: %v", err) logger.Error("WebSocket read error: %v", err)
} else {
logger.Debug("WebSocket connection closed: %v", err)
}
} }
return // triggers reconnect via defer return // triggers reconnect via defer
} }

View File

@@ -280,6 +280,15 @@ func (s *WireGuardService) LoadRemoteConfig() error {
} }
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
ctx := context.Background()
if s.client != nil {
ctx = s.client.MetricsContext()
}
result := "success"
defer func() {
telemetry.IncConfigReload(ctx, result)
}()
var config WgConfig var config WgConfig
logger.Debug("Received message: %v", msg) logger.Debug("Received message: %v", msg)
@@ -288,11 +297,13 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Info("Error marshaling data: %v", err) logger.Info("Error marshaling data: %v", err)
result = "failure"
return return
} }
if err := json.Unmarshal(jsonData, &config); err != nil { if err := json.Unmarshal(jsonData, &config); err != nil {
logger.Info("Error unmarshaling target data: %v", err) logger.Info("Error unmarshaling target data: %v", err)
result = "failure"
return return
} }
s.config = config s.config = config
@@ -303,27 +314,28 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
} }
// telemetry: config reload success // telemetry: config reload success
telemetry.IncConfigReload(context.Background(), "success")
// Optional reconnect reason mapping: config change // Optional reconnect reason mapping: config change
if s.serverPubKey != "" { if s.serverPubKey != "" {
telemetry.IncReconnect(context.Background(), s.serverPubKey, "client", telemetry.ReasonConfigChange) telemetry.IncReconnect(ctx, s.serverPubKey, "client", telemetry.ReasonConfigChange)
} }
// Ensure the WireGuard interface and peers are configured // Ensure the WireGuard interface and peers are configured
start := time.Now() start := time.Now()
if err := s.ensureWireguardInterface(config); err != nil { if err := s.ensureWireguardInterface(config); err != nil {
logger.Error("Failed to ensure WireGuard interface: %v", err) logger.Error("Failed to ensure WireGuard interface: %v", err)
telemetry.ObserveConfigApply(context.Background(), "interface", "failure", time.Since(start).Seconds()) telemetry.ObserveConfigApply(ctx, "interface", "failure", time.Since(start).Seconds())
result = "failure"
} else { } else {
telemetry.ObserveConfigApply(context.Background(), "interface", "success", time.Since(start).Seconds()) telemetry.ObserveConfigApply(ctx, "interface", "success", time.Since(start).Seconds())
} }
startPeers := time.Now() startPeers := time.Now()
if err := s.ensureWireguardPeers(config.Peers); err != nil { if err := s.ensureWireguardPeers(config.Peers); err != nil {
logger.Error("Failed to ensure WireGuard peers: %v", err) logger.Error("Failed to ensure WireGuard peers: %v", err)
telemetry.ObserveConfigApply(context.Background(), "peer", "failure", time.Since(startPeers).Seconds()) telemetry.ObserveConfigApply(ctx, "peer", "failure", time.Since(startPeers).Seconds())
result = "failure"
} else { } else {
telemetry.ObserveConfigApply(context.Background(), "peer", "success", time.Since(startPeers).Seconds()) telemetry.ObserveConfigApply(ctx, "peer", "success", time.Since(startPeers).Seconds())
} }
} }