From 3822b1a0657690dffa13f187462b124654fcf5cb Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 19 Dec 2025 16:45:11 -0500 Subject: [PATCH 01/52] Add version and send it down Former-commit-id: 52273a81c8d2498768511768beaefb4c5ac71043 --- websocket/client.go | 41 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/websocket/client.go b/websocket/client.go index 1c5afaf..f620f8a 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -54,8 +54,9 @@ type ExitNode struct { } type WSMessage struct { - Type string `json:"type"` - Data interface{} `json:"data"` + Type string `json:"type"` + Data interface{} `json:"data"` + ConfigVersion int `json:"configVersion,omitempty"` } // this is not json anymore @@ -87,6 +88,8 @@ type Client struct { clientType string // Type of client (e.g., "newt", "olm") tlsConfig TLSConfig configNeedsSave bool // Flag to track if config needs to be saved + configVersion int // Latest config version received from server + configVersionMux sync.RWMutex } type ClientOption func(*Client) @@ -590,8 +593,19 @@ func (c *Client) pingMonitor() { if c.conn == nil { return } + // Send application-level ping with config version + c.configVersionMux.RLock() + configVersion := c.configVersion + c.configVersionMux.RUnlock() + + pingMsg := WSMessage{ + Type: "ping", + Data: map[string]interface{}{}, + ConfigVersion: configVersion, + } + c.writeMux.Lock() - err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)) + err := c.conn.WriteJSON(pingMsg) c.writeMux.Unlock() if err != nil { // Check if we're shutting down before logging error and reconnecting @@ -609,6 +623,22 @@ func (c *Client) pingMonitor() { } } +// GetConfigVersion returns the current config version +func (c *Client) GetConfigVersion() int { + c.configVersionMux.RLock() + defer c.configVersionMux.RUnlock() + return c.configVersion +} + +// setConfigVersion updates the config version if the new version is 
higher +func (c *Client) setConfigVersion(version int) { + c.configVersionMux.Lock() + defer c.configVersionMux.Unlock() + if version > c.configVersion { + c.configVersion = version + } +} + // readPumpWithDisconnectDetection reads messages and triggers reconnect on error func (c *Client) readPumpWithDisconnectDetection() { defer func() { @@ -650,6 +680,11 @@ func (c *Client) readPumpWithDisconnectDetection() { } } + // Update config version from incoming message + if msg.ConfigVersion > 0 { + c.setConfigVersion(msg.ConfigVersion) + } + c.handlersMux.RLock() if handler, ok := c.handlers[msg.Type]; ok { handler(msg) From dde79bb2dc769c86be17ba57349757333683d7fa Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 21 Dec 2025 20:57:20 -0500 Subject: [PATCH 02/52] Fix go mod Former-commit-id: e355d8db5fb9d629a2155640121d265287cf41fe --- go.mod | 45 --------------------------- go.sum | 98 ---------------------------------------------------------- 2 files changed, 143 deletions(-) diff --git a/go.mod b/go.mod index baf9a13..59992a3 100644 --- a/go.mod +++ b/go.mod @@ -16,64 +16,19 @@ require ( ) require ( - github.com/beorn7/perks v1.0.1 // indirect - github.com/cenkalti/backoff/v5 v5.0.3 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/containerd/errdefs v0.3.0 // indirect - github.com/containerd/errdefs/pkg v0.3.0 // indirect - github.com/distribution/reference v0.6.0 // indirect - github.com/docker/docker v28.5.2+incompatible // indirect - github.com/docker/go-connections v0.6.0 // indirect - github.com/docker/go-units v0.4.0 // indirect - github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-logr/logr v1.4.3 // indirect - github.com/go-logr/stdr v1.2.2 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.7.0 // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc // indirect - github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect - 
github.com/moby/docker-image-spec v1.3.1 // indirect - github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/opencontainers/go-digest v1.0.0 // indirect - github.com/opencontainers/image-spec v1.1.0 // indirect - github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/client_golang v1.23.2 // indirect - github.com/prometheus/client_model v0.6.2 // indirect - github.com/prometheus/common v0.66.1 // indirect - github.com/prometheus/otlptranslator v0.0.2 // indirect - github.com/prometheus/procfs v0.17.0 // indirect github.com/vishvananda/netlink v1.3.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect - go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 // indirect - go.opentelemetry.io/otel v1.38.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 // indirect - go.opentelemetry.io/otel/exporters/prometheus v0.60.0 // indirect - go.opentelemetry.io/otel/metric v1.38.0 // indirect - go.opentelemetry.io/otel/sdk v1.38.0 // indirect - go.opentelemetry.io/otel/sdk/metric v1.38.0 // indirect - go.opentelemetry.io/otel/trace v1.38.0 // indirect - go.opentelemetry.io/proto/otlp v1.7.1 // indirect - go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/mod v0.30.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.18.0 // indirect - golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect - 
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect - google.golang.org/grpc v1.76.0 // indirect - google.golang.org/protobuf v1.36.8 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) replace github.com/fosrl/newt => ../newt diff --git a/go.sum b/go.sum index f37df33..f6ca61a 100644 --- a/go.sum +++ b/go.sum @@ -1,103 +1,19 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= -github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= -github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/containerd/errdefs v0.3.0 h1:FSZgGOeK4yuT/+DnF07/Olde/q4KBoMsaamhXxIMDp4= -github.com/containerd/errdefs v0.3.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= -github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= -github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= -github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= -github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM= -github.com/docker/docker v28.5.2+incompatible/go.mod 
h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= -github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= -github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= -github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= -github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= -github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552 h1:51pHUtoqQhYPS9OiBDHLgYV44X/CBzR5J7GuWO3izhU= -github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= -github.com/fosrl/newt v0.0.0-20251216233525-ff7fe1275b26 h1:ocuDvo6/bgoVByu8yhCnBVEhaQGwkilN9HUIPw00yYI= -github.com/fosrl/newt v0.0.0-20251216233525-ff7fe1275b26/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= -github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= -github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/uuid v1.6.0 
h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc h1:GN2Lv3MGO7AS6PrRoT6yV5+wkrOpcszoIsO4+4ds248= -github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= -github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= -github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= -github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= -github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= -github.com/prometheus/client_golang 
v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= -github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= -github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= -github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= -github.com/prometheus/otlptranslator v0.0.2 h1:+1CdeLVrRQ6Psmhnobldo0kTp96Rj80DRXRd5OSnMEQ= -github.com/prometheus/otlptranslator v0.0.2/go.mod h1:P8AwMgdD7XEr6QRUJ2QWLpiAZTgTE2UYgjlu3svompI= -github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= -github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= -go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 h1:PeBoRj6af6xMI7qCupwFvTbbnd49V7n5YpG6pg8iDYQ= -go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0/go.mod h1:ingqBCtMCe8I4vpz/UVzCW6sxoqgZB37nao91mLQ3Bw= -go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= -go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= 
-go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 h1:vl9obrcoWVKp/lwl8tRE33853I8Xru9HFbw/skNeLs8= -go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0/go.mod h1:GAXRxmLJcVM3u22IjTg74zWBrRCKq8BnOqUVLodpcpw= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk= -go.opentelemetry.io/otel/exporters/prometheus v0.60.0 h1:cGtQxGvZbnrWdC2GyjZi0PDKVSLWP/Jocix3QWfXtbo= -go.opentelemetry.io/otel/exporters/prometheus v0.60.0/go.mod h1:hkd1EekxNo69PTV4OWFGZcKQiIqg0RfuWExcPKFvepk= -go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= -go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= -go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= -go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= -go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= -go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= -go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= -go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= -go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4= -go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE= -go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= -go.yaml.in/yaml/v2 v2.4.2/go.mod 
h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= @@ -112,8 +28,6 @@ golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= @@ -126,18 +40,6 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdI golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= -google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 h1:vlzZttNJGVqTsRFU9AmdnrcO1Znh8Ew9kCD//yjigk0= -google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY= -google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE= -google.golang.org/genproto/googleapis/rpc 
v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc= -google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= -google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= -google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= -google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU= From b76259bc31aba4782c5de8df2ae699f6e5c2587a Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 24 Dec 2025 10:06:25 -0500 Subject: [PATCH 03/52] Add sync message Former-commit-id: d01f180941c6c854f73274c86c281260bd653875 --- go.sum | 3 -- olm/olm.go | 147 ++++++++++++++++++++++++++++++++++++++++++++++++++++ olm/util.go | 43 +++++++++++++++ 3 files changed, 190 insertions(+), 3 deletions(-) diff --git a/go.sum b/go.sum index 7e94e2a..9bf88e2 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,7 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -<<<<<<< HEAD -======= github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 h1:VpuI42l4enih//6IFFQDln/B7WukfMePxIRIpXsNe/0= github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= ->>>>>>> dev github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= 
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= diff --git a/olm/olm.go b/olm/olm.go index f84ee4f..4cbb391 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -453,6 +453,153 @@ func StartTunnel(config TunnelConfig) { logger.Info("WireGuard device created.") }) + // Handler for syncing peer configuration - reconciles expected state with actual state + olm.RegisterHandler("olm/sync", func(msg websocket.WSMessage) { + logger.Debug("Received sync message: %v", msg.Data) + + if !connected { + logger.Warn("Not connected, ignoring sync request") + return + } + + if peerManager == nil { + logger.Warn("Peer manager not initialized, ignoring sync request") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling sync data: %v", err) + return + } + + var wgData WgData + if err := json.Unmarshal(jsonData, &wgData); err != nil { + logger.Error("Error unmarshaling sync data: %v", err) + return + } + + // Build a map of expected peers from the incoming data + expectedPeers := make(map[int]peers.SiteConfig) + for _, site := range wgData.Sites { + expectedPeers[site.SiteId] = site + } + + // Get all current peers + currentPeers := peerManager.GetAllPeers() + currentPeerMap := make(map[int]peers.SiteConfig) + for _, peer := range currentPeers { + currentPeerMap[peer.SiteId] = peer + } + + // Find peers to remove (in current but not in expected) + for siteId := range currentPeerMap { + if _, exists := expectedPeers[siteId]; !exists { + logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId) + if err := peerManager.RemovePeer(siteId); err != nil { + logger.Error("Sync: Failed to remove peer %d: %v", siteId, err) + } else { + // Remove any exit nodes associated with this peer from hole punching + if holePunchManager != nil { + removed := holePunchManager.RemoveExitNodesByPeer(siteId) + if 
removed > 0 { + logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId) + } + } + } + } + } + + // Find peers to add (in expected but not in current) and peers to update + for siteId, expectedSite := range expectedPeers { + if _, exists := currentPeerMap[siteId]; !exists { + // New peer - add it using the add flow (with holepunch) + logger.Info("Sync: Adding new peer for site %d", siteId) + + // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it + holePunchManager.TriggerHolePunch() + + // TODO: do we need to send the message to the cloud to add the peer that way? + if err := peerManager.AddPeer(expectedSite); err != nil { + logger.Error("Sync: Failed to add peer %d: %v", siteId, err) + } else { + logger.Info("Sync: Successfully added peer for site %d", siteId) + } + } else { + // Existing peer - check if update is needed + currentSite := currentPeerMap[siteId] + needsUpdate := false + + // Check if any fields have changed + if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint { + needsUpdate = true + } + if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint { + needsUpdate = true + } + if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey { + needsUpdate = true + } + if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP { + needsUpdate = true + } + if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort { + needsUpdate = true + } + // Check remote subnets + if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) { + needsUpdate = true + } + // Check aliases + if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) { + needsUpdate = true + } + + if needsUpdate { + logger.Info("Sync: Updating peer for site 
%d", siteId) + + // Merge expected data with current data + siteConfig := currentSite + if expectedSite.Endpoint != "" { + siteConfig.Endpoint = expectedSite.Endpoint + } + if expectedSite.RelayEndpoint != "" { + siteConfig.RelayEndpoint = expectedSite.RelayEndpoint + } + if expectedSite.PublicKey != "" { + siteConfig.PublicKey = expectedSite.PublicKey + } + if expectedSite.ServerIP != "" { + siteConfig.ServerIP = expectedSite.ServerIP + } + if expectedSite.ServerPort != 0 { + siteConfig.ServerPort = expectedSite.ServerPort + } + if expectedSite.RemoteSubnets != nil { + siteConfig.RemoteSubnets = expectedSite.RemoteSubnets + } + if expectedSite.Aliases != nil { + siteConfig.Aliases = expectedSite.Aliases + } + + if err := peerManager.UpdatePeer(siteConfig); err != nil { + logger.Error("Sync: Failed to update peer %d: %v", siteId, err) + } else { + // If the endpoint changed, trigger holepunch to refresh NAT mappings + if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint { + logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId) + holePunchManager.TriggerHolePunch() + holePunchManager.ResetInterval() + } + logger.Info("Sync: Successfully updated peer for site %d", siteId) + } + } + } + } + + logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers)) + }) + olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { logger.Debug("Received update-peer message: %v", msg.Data) diff --git a/olm/util.go b/olm/util.go index 6bfd171..d138755 100644 --- a/olm/util.go +++ b/olm/util.go @@ -5,6 +5,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/network" + "github.com/fosrl/olm/peers" "github.com/fosrl/olm/websocket" ) @@ -53,3 +54,45 @@ func GetNetworkSettingsJSON() (string, error) { func GetNetworkSettingsIncrementor() int { return network.GetIncrementor() } + +// slicesEqual compares two string slices for 
equality (order-independent) +func slicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + // Create a map to count occurrences in slice a + counts := make(map[string]int) + for _, v := range a { + counts[v]++ + } + // Check if slice b has the same elements + for _, v := range b { + counts[v]-- + if counts[v] < 0 { + return false + } + } + return true +} + +// aliasesEqual compares two Alias slices for equality (order-independent) +func aliasesEqual(a, b []peers.Alias) bool { + if len(a) != len(b) { + return false + } + // Create a map to count occurrences in slice a (using alias+address as key) + counts := make(map[string]int) + for _, v := range a { + key := v.Alias + "|" + v.AliasAddress + counts[key]++ + } + // Check if slice b has the same elements + for _, v := range b { + key := v.Alias + "|" + v.AliasAddress + counts[key]-- + if counts[key] < 0 { + return false + } + } + return true +} From 148f5fde23ee2f3aff8cfbb452f99451bdf16305 Mon Sep 17 00:00:00 2001 From: Varun Narravula Date: Tue, 23 Dec 2025 15:33:04 -0800 Subject: [PATCH 04/52] fix(ci): add back missing docker build local image rule Former-commit-id: 6d2afb4c72f7956ccb9509e8aed018636070d1d7 --- .github/workflows/test.yml | 2 +- Makefile | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2349f3a..6fe7514 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,4 +22,4 @@ jobs: run: make go-build-release - name: Build Docker image - run: make docker-build-release + run: make docker-build diff --git a/Makefile b/Makefile index 8eed5c2..55ebf81 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,9 @@ all: local local: CGO_ENABLED=0 go build -o ./bin/olm +docker-build: + docker build -t fosrl/olm:latest . + docker-build-release: @if [ -z "$(tag)" ]; then \ echo "Error: tag is required. 
Usage: make docker-build-release tag="; \ From f8dc1342103a74fda006c05104eb03b5373acf9e Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Mon, 29 Dec 2025 17:28:12 -0500 Subject: [PATCH 05/52] add content-length header to status payload Former-commit-id: 8152d4133f1ae85b2632c48983aeb3ea68f0fd2a --- api/api.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/api/api.go b/api/api.go index 787f958..a7e2f24 100644 --- a/api/api.go +++ b/api/api.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "strconv" "sync" "time" @@ -358,7 +359,6 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { } s.statusMu.RLock() - defer s.statusMu.RUnlock() resp := StatusResponse{ Connected: s.isConnected, @@ -371,8 +371,18 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { NetworkSettings: network.GetSettings(), } + s.statusMu.RUnlock() + + data, err := json.Marshal(resp) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Length", strconv.Itoa(len(data))) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data) } // handleHealth handles the /health endpoint From 28910ce1880dc97a18be9ab5535cfb6e89db28a5 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 29 Dec 2025 17:49:58 -0500 Subject: [PATCH 06/52] Add stub Former-commit-id: ece4239aaa70318a0246c0b2e17b4e3e8d306e7d --- dns/override/dns_override_android.go | 18 ++++++++++++++++++ dns/override/dns_override_ios.go | 17 +++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 dns/override/dns_override_android.go create mode 100644 dns/override/dns_override_ios.go diff --git a/dns/override/dns_override_android.go b/dns/override/dns_override_android.go new file mode 100644 index 0000000..af1d946 --- /dev/null +++ b/dns/override/dns_override_android.go @@ -0,0 +1,18 @@ +//go:build android + +package olm + 
+import ( + "github.com/fosrl/olm/dns" +) + +// SetupDNSOverride is a no-op on Android +// Android handles DNS through the VpnService API at the Java/Kotlin layer +func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { + return nil +} + +// RestoreDNSOverride is a no-op on Android +func RestoreDNSOverride() error { + return nil +} \ No newline at end of file diff --git a/dns/override/dns_override_ios.go b/dns/override/dns_override_ios.go new file mode 100644 index 0000000..109d471 --- /dev/null +++ b/dns/override/dns_override_ios.go @@ -0,0 +1,17 @@ +//go:build ios + +package olm + +import ( + "github.com/fosrl/olm/dns" +) + +// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system +func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { + return nil +} + +// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system +func RestoreDNSOverride() error { + return nil +} \ No newline at end of file From c56696bab1ce19dc67563f4915521a5e82b60c89 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 30 Dec 2025 16:59:36 -0500 Subject: [PATCH 07/52] Use a different method on android Former-commit-id: adf4c21f7b280f50e5356325b202be2e554d9333 --- api/api.go | 12 ++++++++++++ olm/olm.go | 12 ++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/api/api.go b/api/api.go index a7e2f24..91d9f37 100644 --- a/api/api.go +++ b/api/api.go @@ -102,6 +102,14 @@ func NewAPISocket(socketPath string) *API { return s } +func NewAPIStub() *API { + s := &API{ + peerStatuses: make(map[int]*PeerStatus), + } + + return s +} + // SetHandlers sets the callback functions for handling API requests func (s *API) SetHandlers( onConnect func(ConnectionRequest) error, @@ -117,6 +125,10 @@ func (s *API) SetHandlers( // Start starts the HTTP server func (s *API) Start() error { + if s.socketPath == "" && s.addr == "" { + return fmt.Errorf("either socketPath or addr must be provided to start the API 
server") + } + mux := http.NewServeMux() mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) diff --git a/olm/olm.go b/olm/olm.go index f84ee4f..9cc1f51 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -111,6 +111,9 @@ func Init(ctx context.Context, config GlobalConfig) { apiServer = api.NewAPI(config.HTTPAddr) } else if config.SocketPath != "" { apiServer = api.NewAPISocket(config.SocketPath) + } else { + // this is so is not null but it cant be started without either the socket path or http addr + apiServer = api.NewAPIStub() } apiServer.SetVersion(config.Version) @@ -304,7 +307,12 @@ func StartTunnel(config TunnelConfig) { tdev, err = func() (tun.Device, error) { if config.FileDescriptorTun != 0 { - return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) + if runtime.GOOS == "android" { // otherwise we get a permission denied + theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(config.FileDescriptorTun)) + return theTun, err + } else { + return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) + } } var ifName = interfaceName if runtime.GOOS == "darwin" { // this is if we dont pass a fd @@ -811,7 +819,7 @@ func StartTunnel(config TunnelConfig) { Endpoint: handshakeData.ExitNode.Endpoint, RelayPort: relayPort, PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, + SiteIds: []int{siteId}, } added := holePunchManager.AddExitNode(exitNode) From cce87424906e6919355a7c1e7de0bd7f57afd53c Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 30 Dec 2025 21:38:07 -0500 Subject: [PATCH 08/52] Try to make the tun replacable Former-commit-id: 6be095888755c22f8cc6621688e75f7c040eaf57 --- device/middle_device.go | 640 +++++++++++++++++++++++++++++----------- device/tun_unix.go | 6 + dns/dns_proxy.go | 11 +- olm/olm.go | 53 +++- 4 files changed, 517 insertions(+), 193 deletions(-) diff --git a/device/middle_device.go b/device/middle_device.go index b031871..2a5d9b9 100644 --- a/device/middle_device.go +++ 
b/device/middle_device.go @@ -1,9 +1,12 @@ package device import ( + "io" "net/netip" "os" "sync" + "sync/atomic" + "time" "github.com/fosrl/newt/logger" "golang.zx2c4.com/wireguard/tun" @@ -18,14 +21,68 @@ type FilterRule struct { Handler PacketHandler } -// MiddleDevice wraps a TUN device with packet filtering capabilities -type MiddleDevice struct { +// closeAwareDevice wraps a tun.Device along with a flag +// indicating whether its Close method was called. +type closeAwareDevice struct { + isClosed atomic.Bool tun.Device - rules []FilterRule - mutex sync.RWMutex - readCh chan readResult - injectCh chan []byte - closed chan struct{} + closeEventCh chan struct{} + wg sync.WaitGroup + closeOnce sync.Once +} + +func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice { + return &closeAwareDevice{ + Device: tunDevice, + isClosed: atomic.Bool{}, + closeEventCh: make(chan struct{}), + } +} + +// redirectEvents redirects the Events() method of the underlying tun.Device +// to the given channel. +func (c *closeAwareDevice) redirectEvents(out chan tun.Event) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + select { + case ev, ok := <-c.Device.Events(): + if !ok { + return + } + + if ev == tun.EventDown { + continue + } + + select { + case out <- ev: + case <-c.closeEventCh: + return + } + case <-c.closeEventCh: + return + } + } + }() +} + +// Close calls the underlying Device's Close method +// after setting isClosed to true. +func (c *closeAwareDevice) Close() (err error) { + c.closeOnce.Do(func() { + c.isClosed.Store(true) + close(c.closeEventCh) + err = c.Device.Close() + c.wg.Wait() + }) + + return err +} + +func (c *closeAwareDevice) IsClosed() bool { + return c.isClosed.Load() } type readResult struct { @@ -36,58 +93,124 @@ type readResult struct { err error } +// MiddleDevice wraps a TUN device with packet filtering capabilities +// and supports swapping the underlying device. 
+type MiddleDevice struct { + devices []*closeAwareDevice + mu sync.Mutex + cond *sync.Cond + rules []FilterRule + rulesMutex sync.RWMutex + readCh chan readResult + injectCh chan []byte + closed atomic.Bool + events chan tun.Event +} + // NewMiddleDevice creates a new filtered TUN device wrapper func NewMiddleDevice(device tun.Device) *MiddleDevice { d := &MiddleDevice{ - Device: device, + devices: make([]*closeAwareDevice, 0), rules: make([]FilterRule, 0), - readCh: make(chan readResult), + readCh: make(chan readResult, 16), injectCh: make(chan []byte, 100), - closed: make(chan struct{}), + events: make(chan tun.Event, 16), } - go d.pump() + d.cond = sync.NewCond(&d.mu) + + if device != nil { + d.AddDevice(device) + } + return d } -func (d *MiddleDevice) pump() { +// AddDevice adds a new underlying TUN device, closing any previous one +func (d *MiddleDevice) AddDevice(device tun.Device) { + d.mu.Lock() + if d.closed.Load() { + d.mu.Unlock() + _ = device.Close() + return + } + + var toClose *closeAwareDevice + if len(d.devices) > 0 { + toClose = d.devices[len(d.devices)-1] + } + + cad := newCloseAwareDevice(device) + cad.redirectEvents(d.events) + + d.devices = []*closeAwareDevice{cad} + + // Start pump for the new device + go d.pump(cad) + + d.cond.Broadcast() + d.mu.Unlock() + + if toClose != nil { + logger.Debug("MiddleDevice: Closing previous device") + if err := toClose.Close(); err != nil { + logger.Debug("MiddleDevice: Error closing previous device: %v", err) + } + } +} + +func (d *MiddleDevice) pump(dev *closeAwareDevice) { const defaultOffset = 16 - batchSize := d.Device.BatchSize() - logger.Debug("MiddleDevice: pump started") + batchSize := dev.BatchSize() + logger.Debug("MiddleDevice: pump started for device") for { - // Check closed first with priority - select { - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel") + // Check if this device is closed + if dev.IsClosed() { + logger.Debug("MiddleDevice: pump exiting, 
device is closed") + return + } + + // Check if MiddleDevice itself is closed + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed") return - default: } // Allocate buffers for reading - // We allocate new buffers for each read to avoid race conditions - // since we pass them to the channel bufs := make([][]byte, batchSize) sizes := make([]int, batchSize) for i := range bufs { bufs[i] = make([]byte, 2048) // Standard MTU + headroom } - n, err := d.Device.Read(bufs, sizes, defaultOffset) + n, err := dev.Read(bufs, sizes, defaultOffset) - // Check closed again after read returns - select { - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)") + // Check if device was closed during read + if dev.IsClosed() { + logger.Debug("MiddleDevice: pump exiting, device closed during read") return - default: } - // Now try to send the result + // Check if MiddleDevice was closed during read + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read") + return + } + + // Try to send the result select { case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)") - return + default: + // Channel full, check if we should exit + if dev.IsClosed() || d.closed.Load() { + return + } + // Try again with blocking + select { + case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: + case <-dev.closeEventCh: + return + } } if err != nil { @@ -99,16 +222,21 @@ func (d *MiddleDevice) pump() { // InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN) func (d *MiddleDevice) InjectOutbound(packet []byte) { + if d.closed.Load() { + return + } select { case d.injectCh <- packet: - case <-d.closed: + default: + // Channel full, drop packet + logger.Debug("MiddleDevice: InjectOutbound 
dropping packet, channel full") } } // AddRule adds a packet filtering rule func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { - d.mutex.Lock() - defer d.mutex.Unlock() + d.rulesMutex.Lock() + defer d.rulesMutex.Unlock() d.rules = append(d.rules, FilterRule{ DestIP: destIP, Handler: handler, @@ -117,8 +245,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { // RemoveRule removes all rules for a given destination IP func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { - d.mutex.Lock() - defer d.mutex.Unlock() + d.rulesMutex.Lock() + defer d.rulesMutex.Unlock() newRules := make([]FilterRule, 0, len(d.rules)) for _, rule := range d.rules { if rule.DestIP != destIP { @@ -130,18 +258,113 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { // Close stops the device func (d *MiddleDevice) Close() error { - select { - case <-d.closed: - // Already closed - return nil - default: - logger.Debug("MiddleDevice: Closing, signaling closed channel") - close(d.closed) + if !d.closed.CompareAndSwap(false, true) { + return nil // already closed } - logger.Debug("MiddleDevice: Closing underlying TUN device") - err := d.Device.Close() - logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err) - return err + + d.mu.Lock() + devices := d.devices + d.devices = nil + d.cond.Broadcast() + d.mu.Unlock() + + var lastErr error + logger.Debug("MiddleDevice: Closing %d devices", len(devices)) + for _, device := range devices { + if err := device.Close(); err != nil { + logger.Debug("MiddleDevice: Error closing device: %v", err) + lastErr = err + } + } + + close(d.events) + return lastErr +} + +// Events returns the events channel +func (d *MiddleDevice) Events() <-chan tun.Event { + return d.events +} + +// File returns the underlying file descriptor +func (d *MiddleDevice) File() *os.File { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return nil + } + continue + } + + file := dev.File() + 
+ if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return file + } +} + +// MTU returns the MTU of the underlying device +func (d *MiddleDevice) MTU() (int, error) { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } + continue + } + + mtu, err := dev.MTU() + if err == nil { + return mtu, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return 0, err + } +} + +// Name returns the name of the underlying device +func (d *MiddleDevice) Name() (string, error) { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return "", io.EOF + } + continue + } + + name, err := dev.Name() + if err == nil { + return name, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return "", err + } +} + +// BatchSize returns the batch size +func (d *MiddleDevice) BatchSize() int { + dev := d.peekLast() + if dev == nil { + return 1 + } + return dev.BatchSize() } // extractDestIP extracts destination IP from packet (fast path) @@ -176,156 +399,231 @@ func extractDestIP(packet []byte) (netip.Addr, bool) { // Read intercepts packets going UP from the TUN device (towards WireGuard) func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { - // Check if already closed first (non-blocking) - select { - case <-d.closed: - logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)") - return 0, os.ErrClosed - default: - } - - // Now block waiting for data - select { - case res := <-d.readCh: - if res.err != nil { - logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err) - return 0, res.err + for { + if d.closed.Load() { + logger.Debug("MiddleDevice: Read returning io.EOF, device closed") + return 0, io.EOF } - // Copy packets from result to provided buffers - count := 0 - for i := 0; i < res.n && i < len(bufs); i++ { - // Handle offset mismatch if necessary - // We assume the 
pump used defaultOffset (16) - // If caller asks for different offset, we need to shift - src := res.bufs[i] - srcOffset := res.offset - srcSize := res.sizes[i] - - // Calculate where the packet data starts and ends in src - pktData := src[srcOffset : srcOffset+srcSize] - - // Ensure dest buffer is large enough - if len(bufs[i]) < offset+len(pktData) { - continue // Skip if buffer too small + // Wait for a device to be available + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF } - - copy(bufs[i][offset:], pktData) - sizes[i] = len(pktData) - count++ - } - n = count - - case pkt := <-d.injectCh: - if len(bufs) == 0 { - return 0, nil - } - if len(bufs[0]) < offset+len(pkt) { - return 0, nil // Buffer too small - } - copy(bufs[0][offset:], pkt) - sizes[0] = len(pkt) - n = 1 - - case <-d.closed: - logger.Debug("MiddleDevice: Read returning os.ErrClosed") - return 0, os.ErrClosed // Signal that device is closed - } - - d.mutex.RLock() - rules := d.rules - d.mutex.RUnlock() - - if len(rules) == 0 { - return n, nil - } - - // Process packets and filter out handled ones - writeIdx := 0 - for readIdx := 0; readIdx < n; readIdx++ { - packet := bufs[readIdx][offset : offset+sizes[readIdx]] - - destIP, ok := extractDestIP(packet) - if !ok { - // Can't parse, keep packet - if writeIdx != readIdx { - bufs[writeIdx] = bufs[readIdx] - sizes[writeIdx] = sizes[readIdx] - } - writeIdx++ continue } - // Check if packet matches any rule - handled := false - for _, rule := range rules { - if rule.DestIP == destIP { - if rule.Handler(packet) { - // Packet was handled and should be dropped - handled = true - break + // Now block waiting for data from readCh or injectCh + select { + case res := <-d.readCh: + if res.err != nil { + // Check if device was swapped + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err) + return 0, res.err + } + + // Copy packets from 
result to provided buffers + count := 0 + for i := 0; i < res.n && i < len(bufs); i++ { + src := res.bufs[i] + srcOffset := res.offset + srcSize := res.sizes[i] + + pktData := src[srcOffset : srcOffset+srcSize] + + if len(bufs[i]) < offset+len(pktData) { + continue + } + + copy(bufs[i][offset:], pktData) + sizes[i] = len(pktData) + count++ + } + n = count + + case pkt := <-d.injectCh: + if len(bufs) == 0 { + return 0, nil + } + if len(bufs[0]) < offset+len(pkt) { + return 0, nil + } + copy(bufs[0][offset:], pkt) + sizes[0] = len(pkt) + n = 1 + } + + // Apply filtering rules + d.rulesMutex.RLock() + rules := d.rules + d.rulesMutex.RUnlock() + + if len(rules) == 0 { + return n, nil + } + + // Process packets and filter out handled ones + writeIdx := 0 + for readIdx := 0; readIdx < n; readIdx++ { + packet := bufs[readIdx][offset : offset+sizes[readIdx]] + + destIP, ok := extractDestIP(packet) + if !ok { + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ + continue + } + + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + handled = true + break + } } } - } - if !handled { - // Keep packet - if writeIdx != readIdx { - bufs[writeIdx] = bufs[readIdx] - sizes[writeIdx] = sizes[readIdx] + if !handled { + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ } - writeIdx++ } - } - return writeIdx, err + return writeIdx, nil + } } // Write intercepts packets going DOWN to the TUN device (from WireGuard) func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) { - d.mutex.RLock() - rules := d.rules - d.mutex.RUnlock() + for { + if d.closed.Load() { + return 0, io.EOF + } - if len(rules) == 0 { - return d.Device.Write(bufs, offset) - } - - // Filter packets going down - filteredBufs := make([][]byte, 0, len(bufs)) - for _, buf := range bufs { - if len(buf) <= offset { + dev := d.peekLast() + if 
dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } continue } - packet := buf[offset:] - destIP, ok := extractDestIP(packet) - if !ok { - // Can't parse, keep packet - filteredBufs = append(filteredBufs, buf) - continue - } + d.rulesMutex.RLock() + rules := d.rules + d.rulesMutex.RUnlock() - // Check if packet matches any rule - handled := false - for _, rule := range rules { - if rule.DestIP == destIP { - if rule.Handler(packet) { - // Packet was handled and should be dropped - handled = true - break + var filteredBufs [][]byte + if len(rules) == 0 { + filteredBufs = bufs + } else { + filteredBufs = make([][]byte, 0, len(bufs)) + for _, buf := range bufs { + if len(buf) <= offset { + continue + } + + packet := buf[offset:] + destIP, ok := extractDestIP(packet) + if !ok { + filteredBufs = append(filteredBufs, buf) + continue + } + + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + handled = true + break + } + } + } + + if !handled { + filteredBufs = append(filteredBufs, buf) } } } - if !handled { - filteredBufs = append(filteredBufs, buf) + if len(filteredBufs) == 0 { + return len(bufs), nil } - } - if len(filteredBufs) == 0 { - return len(bufs), nil // All packets were handled - } + n, err := dev.Write(filteredBufs, offset) + if err == nil { + return n, nil + } - return d.Device.Write(filteredBufs, offset) + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } } + +func (d *MiddleDevice) waitForDevice() bool { + d.mu.Lock() + defer d.mu.Unlock() + + for len(d.devices) == 0 && !d.closed.Load() { + d.cond.Wait() + } + return !d.closed.Load() +} + +func (d *MiddleDevice) peekLast() *closeAwareDevice { + d.mu.Lock() + defer d.mu.Unlock() + + if len(d.devices) == 0 { + return nil + } + + return d.devices[len(d.devices)-1] +} + +// WriteToTun writes packets directly to the underlying TUN device, +// bypassing WireGuard. 
This is useful for sending packets that should +// appear to come from the TUN interface (e.g., DNS responses from a proxy). +// Unlike Write(), this does not go through packet filtering rules. +func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) { + for { + if d.closed.Load() { + return 0, io.EOF + } + + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } + continue + } + + n, err := dev.Write(bufs, offset) + if err == nil { + return n, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } +} \ No newline at end of file diff --git a/device/tun_unix.go b/device/tun_unix.go index c9bab60..22cec13 100644 --- a/device/tun_unix.go +++ b/device/tun_unix.go @@ -5,6 +5,7 @@ package device import ( "net" "os" + "runtime" "github.com/fosrl/newt/logger" "golang.org/x/sys/unix" @@ -13,6 +14,11 @@ import ( ) func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { + if runtime.GOOS == "android" { // otherwise we get a permission denied + theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd)) + return theTun, err + } + dupTunFd, err := unix.Dup(int(tunFd)) if err != nil { logger.Error("Unable to dup tun fd: %v", err) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 6d56379..748a5a9 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -12,7 +12,6 @@ import ( "github.com/fosrl/newt/util" "github.com/fosrl/olm/device" "github.com/miekg/dns" - "golang.zx2c4.com/wireguard/tun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -36,8 +35,7 @@ type DNSProxy struct { upstreamDNS []string tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally mtu int - tunDevice tun.Device // Direct reference to underlying TUN device for responses - middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering + middleDevice *device.MiddleDevice // Reference 
to MiddleDevice for packet filtering and TUN writes recordStore *DNSRecordStore // Local DNS records // Tunnel DNS fields - for sending queries over WireGuard @@ -53,7 +51,7 @@ type DNSProxy struct { } // NewDNSProxy creates a new DNS proxy -func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { +func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { proxyIP, err := PickIPFromSubnet(utilitySubnet) if err != nil { return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err) @@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in proxy := &DNSProxy{ proxyIP: proxyIP, mtu: mtu, - tunDevice: tunDevice, middleDevice: middleDevice, upstreamDNS: upstreamDns, tunnelDNS: tunnelDns, @@ -694,9 +691,9 @@ func (p *DNSProxy) runPacketSender() { pos += len(slice) } - // Write packet to TUN device + // Write packet to TUN device via MiddleDevice // offset=16 indicates packet data starts at position 16 in the buffer - _, err := p.tunDevice.Write([][]byte{buf}, offset) + _, err := p.middleDevice.WriteToTun([][]byte{buf}, offset) if err != nil { logger.Error("Failed to write DNS response to TUN: %v", err) } diff --git a/olm/olm.go b/olm/olm.go index 9cc1f51..a3bb694 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -35,6 +35,7 @@ var ( uapiListener net.Listener tdev tun.Device middleDev *olmDevice.MiddleDevice + interfaceName string dnsProxy *dns.DNSProxy apiServer *api.API olmClient *websocket.Client @@ -237,11 +238,11 @@ func StartTunnel(config TunnelConfig) { stopPing = make(chan struct{}) var ( - interfaceName = config.InterfaceName - id = config.ID - secret = config.Secret - userToken = config.UserToken + id = config.ID + secret = config.Secret + userToken = config.UserToken ) + interfaceName = 
config.InterfaceName apiServer.SetOrgID(config.OrgID) @@ -307,12 +308,7 @@ func StartTunnel(config TunnelConfig) { tdev, err = func() (tun.Device, error) { if config.FileDescriptorTun != 0 { - if runtime.GOOS == "android" { // otherwise we get a permission denied - theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(config.FileDescriptorTun)) - return theTun, err - } else { - return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) - } + return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) } var ifName = interfaceName if runtime.GOOS == "darwin" { // this is if we dont pass a fd @@ -329,11 +325,11 @@ func StartTunnel(config TunnelConfig) { return } - if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { - interfaceName = realInterfaceName - } + // if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? + interfaceName = realInterfaceName } + // } // Wrap TUN device with packet filter for DNS proxy middleDev = olmDevice.NewMiddleDevice(tdev) @@ -389,7 +385,7 @@ func StartTunnel(config TunnelConfig) { } // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) + dnsProxy, err = dns.NewDNSProxy(middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) if err != nil { logger.Error("Failed to create DNS proxy: %v", err) } @@ -956,6 +952,33 @@ func StartTunnel(config TunnelConfig) { logger.Info("Tunnel process context cancelled, cleaning up") } +func AddDevice(fd uint32) { + if middleDev == nil { + logger.Error("MiddleDevice is nil, cannot add device") + return + } + + if tunnelConfig.MTU == 0 { + logger.Error("No MTU configured, cannot create device") + return + } + + tdev, err := olmDevice.CreateTUNFromFD(fd, tunnelConfig.MTU) + + if err != nil { 
+ logger.Error("Failed to create TUN device: %v", err) + return + } + + // if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? + interfaceName = realInterfaceName + } + + // Here we replace the existing TUN device in the middle device with the new one + middleDev.AddDevice(tdev) +} + func Close() { // Restore original DNS configuration // we do this first to avoid any DNS issues if something else gets stuck From f08b17c7bd2043742f097742c5c801cd5e7a643c Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 31 Dec 2025 11:22:09 -0500 Subject: [PATCH 09/52] Middle device working but not closing Former-commit-id: c85fcc434ba4059a2952ecd0a3d54916f8bebc29 --- olm/olm.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index a3bb694..38d3324 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -952,22 +952,20 @@ func StartTunnel(config TunnelConfig) { logger.Info("Tunnel process context cancelled, cleaning up") } -func AddDevice(fd uint32) { +func AddDevice(fd uint32) error { if middleDev == nil { - logger.Error("MiddleDevice is nil, cannot add device") - return + return fmt.Errorf("middle device is not initialized") } if tunnelConfig.MTU == 0 { - logger.Error("No MTU configured, cannot create device") - return + // error + return fmt.Errorf("tunnel MTU is not set") } tdev, err := olmDevice.CreateTUNFromFD(fd, tunnelConfig.MTU) if err != nil { - logger.Error("Failed to create TUN device: %v", err) - return + return fmt.Errorf("failed to create TUN device from fd: %v", err) } // if config.FileDescriptorTun == 0 { @@ -977,6 +975,8 @@ func AddDevice(fd uint32) { // Here we replace the existing TUN device in the middle device with the new one middleDev.AddDevice(tdev) + + return nil } func Close() { From aeb908b68cb52c35a50925261ad6b8ad0836a093 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 31 Dec 2025 11:33:00 -0500 
Subject: [PATCH 10/52] Exiting the middle device works now? Former-commit-id: d76b3c366f4f97865d993d5c579dce6a79d6891a --- device/middle_device.go | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/device/middle_device.go b/device/middle_device.go index 2a5d9b9..7dfbec8 100644 --- a/device/middle_device.go +++ b/device/middle_device.go @@ -163,6 +163,13 @@ func (d *MiddleDevice) pump(dev *closeAwareDevice) { batchSize := dev.BatchSize() logger.Debug("MiddleDevice: pump started for device") + // Recover from panic if readCh is closed while we're trying to send + defer func() { + if r := recover(); r != nil { + logger.Debug("MiddleDevice: pump recovered from panic (channel closed)") + } + }() + for { // Check if this device is closed if dev.IsClosed() { @@ -197,7 +204,12 @@ func (d *MiddleDevice) pump(dev *closeAwareDevice) { return } - // Try to send the result + // Try to send the result - check closed state first to avoid sending on closed channel + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, device closed before send") + return + } + select { case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: default: @@ -225,6 +237,13 @@ func (d *MiddleDevice) InjectOutbound(packet []byte) { if d.closed.Load() { return } + // Use defer/recover to handle panic from sending on closed channel + // This can happen during shutdown race conditions + defer func() { + if r := recover(); r != nil { + logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)") + } + }() select { case d.injectCh <- packet: default: @@ -268,6 +287,8 @@ func (d *MiddleDevice) Close() error { d.cond.Broadcast() d.mu.Unlock() + // Close underlying devices first - this causes the pump goroutines to exit + // when their read operations return errors var lastErr error logger.Debug("MiddleDevice: Closing %d devices", len(devices)) for _, device := range devices { @@ -277,7 
+298,12 @@ func (d *MiddleDevice) Close() error { } } + // Now close channels to unblock any remaining readers + // The pump should have exited by now, but close channels to be safe + close(d.readCh) + close(d.injectCh) close(d.events) + return lastErr } @@ -416,7 +442,11 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err // Now block waiting for data from readCh or injectCh select { - case res := <-d.readCh: + case res, ok := <-d.readCh: + if !ok { + // Channel closed, device is shutting down + return 0, io.EOF + } if res.err != nil { // Check if device was swapped if dev.IsClosed() { @@ -446,7 +476,11 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err } n = count - case pkt := <-d.injectCh: + case pkt, ok := <-d.injectCh: + if !ok { + // Channel closed, device is shutting down + return 0, io.EOF + } if len(bufs) == 0 { return 0, nil } From 1b43f029a94fc3284880d28b8d28668fdca775a2 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 31 Dec 2025 15:42:51 -0500 Subject: [PATCH 11/52] Dont pass in dns proxy to override Former-commit-id: 51dd927f9b1a0d74bedaf1b0f34b046f4a937dbd --- dns/override/dns_override_android.go | 6 ++---- dns/override/dns_override_darwin.go | 9 ++------- dns/override/dns_override_ios.go | 6 ++---- dns/override/dns_override_unix.go | 19 +++++++------------ dns/override/dns_override_windows.go | 9 ++------- olm/olm.go | 6 ++++-- 6 files changed, 19 insertions(+), 36 deletions(-) diff --git a/dns/override/dns_override_android.go b/dns/override/dns_override_android.go index af1d946..d3fd78e 100644 --- a/dns/override/dns_override_android.go +++ b/dns/override/dns_override_android.go @@ -2,13 +2,11 @@ package olm -import ( - "github.com/fosrl/olm/dns" -) +import "net/netip" // SetupDNSOverride is a no-op on Android // Android handles DNS through the VpnService API at the Java/Kotlin layer -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { +func 
SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { return nil } diff --git a/dns/override/dns_override_darwin.go b/dns/override/dns_override_darwin.go index 6ccc3fb..c1c3789 100644 --- a/dns/override/dns_override_darwin.go +++ b/dns/override/dns_override_darwin.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on macOS // Uses scutil for DNS configuration -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error configurator, err = platform.NewDarwinDNSConfigurator() if err != nil { @@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/override/dns_override_ios.go b/dns/override/dns_override_ios.go index 109d471..6c95c71 100644 --- a/dns/override/dns_override_ios.go +++ b/dns/override/dns_override_ios.go @@ -2,12 +2,10 @@ package olm -import ( - "github.com/fosrl/olm/dns" -) +import "net/netip" // SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { return nil } diff --git a/dns/override/dns_override_unix.go b/dns/override/dns_override_unix.go index c3b31e8..12cb692 100644 --- a/dns/override/dns_override_unix.go +++ b/dns/override/dns_override_unix.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform 
"github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD // Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error // Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability @@ -32,7 +27,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName) if err == nil { logger.Info("Using systemd-resolved DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err) @@ -40,7 +35,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName) if err == nil { logger.Info("Using NetworkManager DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err) @@ -48,7 +43,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName) if err == nil { logger.Info("Using resolvconf DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create resolvconf configurator: %v, falling back", err) } @@ -60,11 +55,11 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { } logger.Info("Using file-based DNS configurator") - return setDNS(dnsProxy, configurator) + 
return setDNS(proxyIp, configurator) } // setDNS is a helper function to set DNS and log the results -func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { +func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error { // Get current DNS servers before changing currentDNS, err := conf.GetCurrentDNS() if err != nil { @@ -75,7 +70,7 @@ func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/override/dns_override_windows.go b/dns/override/dns_override_windows.go index a564079..16bbca1 100644 --- a/dns/override/dns_override_windows.go +++ b/dns/override/dns_override_windows.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on Windows // Uses registry-based configuration (automatically extracts interface GUID) -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error configurator, err = platform.NewWindowsDNSConfigurator(interfaceName) if err != nil { @@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/olm/olm.go b/olm/olm.go index 38d3324..4d12952 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -439,10 +439,12 @@ func StartTunnel(config TunnelConfig) { if config.OverrideDNS { // Set up DNS override to use our DNS proxy - if err := 
dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil { + if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy.GetProxyIP()); err != nil { logger.Error("Failed to setup DNS override: %v", err) return } + + network.SetDNSServers([]string{dnsProxy.GetProxyIP().String()}) } apiServer.SetRegistered(true) @@ -975,7 +977,7 @@ func AddDevice(fd uint32) error { // Here we replace the existing TUN device in the middle device with the new one middleDev.AddDevice(tdev) - + return nil } From 83edde34494264e0917f134086cb49ae05d72d82 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 31 Dec 2025 18:01:25 -0500 Subject: [PATCH 12/52] Fix build on darwin Former-commit-id: fbeb5be88d9a2c126c232e0df78efef4fa51ead8 --- device/tun_darwin.go | 44 ++++++++++++++++++++++++++++ device/{tun_unix.go => tun_linux.go} | 2 +- 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 device/tun_darwin.go rename device/{tun_unix.go => tun_linux.go} (98%) diff --git a/device/tun_darwin.go b/device/tun_darwin.go new file mode 100644 index 0000000..f763f74 --- /dev/null +++ b/device/tun_darwin.go @@ -0,0 +1,44 @@ +//go:build darwin + +package device + +import ( + "net" + "os" + + "github.com/fosrl/newt/logger" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" +) + +func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { + dupTunFd, err := unix.Dup(int(tunFd)) + if err != nil { + logger.Error("Unable to dup tun fd: %v", err) + return nil, err + } + + err = unix.SetNonblock(dupTunFd, true) + if err != nil { + unix.Close(dupTunFd) + return nil, err + } + + file := os.NewFile(uintptr(dupTunFd), "/dev/tun") + device, err := tun.CreateTUNFromFile(file, mtuInt) + if err != nil { + file.Close() + return nil, err + } + + return device, nil +} + +func UapiOpen(interfaceName string) (*os.File, error) { + return ipc.UAPIOpen(interfaceName) +} + +func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { 
+ return ipc.UAPIListen(interfaceName, fileUAPI) +} diff --git a/device/tun_unix.go b/device/tun_linux.go similarity index 98% rename from device/tun_unix.go rename to device/tun_linux.go index 22cec13..902f269 100644 --- a/device/tun_unix.go +++ b/device/tun_linux.go @@ -1,4 +1,4 @@ -//go:build !windows +//go:build linux package device From 1ed27fec1a09beaf1d4190d129878903e7547f0b Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Thu, 1 Jan 2026 17:38:01 -0500 Subject: [PATCH 13/52] set mtu to 0 on darwin Former-commit-id: fbe686961ed233f1e50cc4ccb46336e7e5938c8d --- device/tun_darwin.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/device/tun_darwin.go b/device/tun_darwin.go index f763f74..df87d53 100644 --- a/device/tun_darwin.go +++ b/device/tun_darwin.go @@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { } file := os.NewFile(uintptr(dupTunFd), "/dev/tun") - device, err := tun.CreateTUNFromFile(file, mtuInt) + device, err := tun.CreateTUNFromFile(file, 0) if err != nil { file.Close() return nil, err From 7b7eae617a2ac1b297e04a8d6c97597c48571e65 Mon Sep 17 00:00:00 2001 From: Varun Narravula Date: Wed, 31 Dec 2025 15:07:06 -0800 Subject: [PATCH 14/52] chore: format files using gofmt Former-commit-id: 5cfa0dfb9781d57bd369fa3f4631f9721f64a5a9 --- dns/dns_proxy.go | 10 +++--- dns/dns_records.go | 2 +- dns/dns_records_test.go | 68 ++++++++++++++++++++--------------------- dns/platform/darwin.go | 2 +- olm/olm.go | 2 +- 5 files changed, 42 insertions(+), 42 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 6d56379..6c9891a 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -34,18 +34,18 @@ type DNSProxy struct { ep *channel.Endpoint proxyIP netip.Addr upstreamDNS []string - tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally + tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally mtu int tunDevice tun.Device // 
Direct reference to underlying TUN device for responses middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering recordStore *DNSRecordStore // Local DNS records // Tunnel DNS fields - for sending queries over WireGuard - tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries) - tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries - tunnelEp *channel.Endpoint + tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries) + tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries + tunnelEp *channel.Endpoint tunnelActivePorts map[uint16]bool - tunnelPortsLock sync.Mutex + tunnelPortsLock sync.Mutex ctx context.Context cancel context.CancelFunc diff --git a/dns/dns_records.go b/dns/dns_records.go index ed57b77..5308b0e 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -322,4 +322,4 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool { } return matchWildcardInternal(pattern, domain, pi+1, di+1) -} \ No newline at end of file +} diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index 0bb18a1..f922afb 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -37,7 +37,7 @@ func TestWildcardMatching(t *testing.T) { domain: "autoco.internal.", expected: false, }, - + // Question mark wildcard tests { name: "host-0?.autoco.internal matches host-01.autoco.internal", @@ -63,7 +63,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-012.autoco.internal.", expected: false, }, - + // Combined wildcard tests { name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal", @@ -83,7 +83,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-01.autoco.internal.", expected: false, }, - + // Multiple asterisks { name: "*.*. 
autoco.internal matches any.thing.autoco.internal", @@ -97,7 +97,7 @@ func TestWildcardMatching(t *testing.T) { domain: "single.autoco.internal.", expected: false, }, - + // Asterisk in middle { name: "host-*.autoco.internal matches host-anything.autoco.internal", @@ -111,7 +111,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-.autoco.internal.", expected: true, }, - + // Multiple question marks { name: "host-??.autoco.internal matches host-01.autoco.internal", @@ -125,7 +125,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-1.autoco.internal.", expected: false, }, - + // Exact match (no wildcards) { name: "exact.autoco.internal matches exact.autoco.internal", @@ -139,7 +139,7 @@ func TestWildcardMatching(t *testing.T) { domain: "other.autoco.internal.", expected: false, }, - + // Edge cases { name: "* matches anything", @@ -154,7 +154,7 @@ func TestWildcardMatching(t *testing.T) { expected: true, }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := matchWildcard(tt.pattern, tt.domain) @@ -167,21 +167,21 @@ func TestWildcardMatching(t *testing.T) { func TestDNSRecordStoreWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard records wildcardIP := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", wildcardIP) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Add exact record exactIP := net.ParseIP("10.0.0.2") err = store.AddRecord("exact.autoco.internal", exactIP) if err != nil { t.Fatalf("Failed to add exact record: %v", err) } - + // Test exact match takes precedence ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -190,7 +190,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) { if !ips[0].Equal(exactIP) { t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) } - + // Test wildcard match ips = store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -199,7 +199,7 @@ func 
TestDNSRecordStoreWildcard(t *testing.T) { if !ips[0].Equal(wildcardIP) { t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) } - + // Test non-match (base domain) ips = store.GetRecords("autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -209,14 +209,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) { func TestDNSRecordStoreComplexWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add complex wildcard pattern ip1 := net.ParseIP("10.0.0.1") err := store.AddRecord("*.host-0?.autoco.internal", ip1) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Test matching domain ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -225,13 +225,13 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { if len(ips) > 0 && !ips[0].Equal(ip1) { t.Errorf("Expected IP %v, got %v", ip1, ips[0]) } - + // Test non-matching domain (missing prefix) ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) } - + // Test non-matching domain (wrong ? 
position) ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -241,23 +241,23 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { func TestDNSRecordStoreRemoveWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard record ip := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Verify it exists ips := store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP before removal, got %d", len(ips)) } - + // Remove wildcard record store.RemoveRecord("*.autoco.internal", nil) - + // Verify it's gone ips = store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -267,40 +267,40 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { func TestDNSRecordStoreMultipleWildcards(t *testing.T) { store := NewDNSRecordStore() - + // Add multiple wildcard patterns that don't overlap ip1 := net.ParseIP("10.0.0.1") ip2 := net.ParseIP("10.0.0.2") ip3 := net.ParseIP("10.0.0.3") - + err := store.AddRecord("*.prod.autoco.internal", ip1) if err != nil { t.Fatalf("Failed to add first wildcard: %v", err) } - + err = store.AddRecord("*.dev.autoco.internal", ip2) if err != nil { t.Fatalf("Failed to add second wildcard: %v", err) } - + // Add a broader wildcard that matches both err = store.AddRecord("*.autoco.internal", ip3) if err != nil { t.Fatalf("Failed to add third wildcard: %v", err) } - + // Test domain matching only the prod pattern and the broad pattern ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) } - + // Test domain matching only the dev pattern and the broad pattern ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) } - + // Test domain matching only the broad pattern ips = 
store.GetRecords("host.test.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -310,14 +310,14 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add IPv6 wildcard record ip := net.ParseIP("2001:db8::1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add IPv6 wildcard record: %v", err) } - + // Test wildcard match for IPv6 ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) if len(ips) != 1 { @@ -330,21 +330,21 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { func TestHasRecordWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard record ip := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Test HasRecord with wildcard match if !store.HasRecord("host.autoco.internal.", RecordTypeA) { t.Error("Expected HasRecord to return true for wildcard match") } - + // Test HasRecord with non-match if store.HasRecord("autoco.internal.", RecordTypeA) { t.Error("Expected HasRecord to return false for base domain") } -} \ No newline at end of file +} diff --git a/dns/platform/darwin.go b/dns/platform/darwin.go index 61cc81b..8054c57 100644 --- a/dns/platform/darwin.go +++ b/dns/platform/darwin.go @@ -416,4 +416,4 @@ func (d *DarwinDNSConfigurator) clearState() error { logger.Debug("Cleared DNS state file") return nil -} \ No newline at end of file +} diff --git a/olm/olm.go b/olm/olm.go index f84ee4f..02257b8 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -811,7 +811,7 @@ func StartTunnel(config TunnelConfig) { Endpoint: handshakeData.ExitNode.Endpoint, RelayPort: relayPort, PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, + SiteIds: []int{siteId}, } added := holePunchManager.AddExitNode(exitNode) From c565a46a6f7fea8715ce4d3f050a12fcafc596ec Mon Sep 17 00:00:00 2001 From: Varun Narravula 
Date: Wed, 31 Dec 2025 15:44:19 -0800 Subject: [PATCH 15/52] feat(logger): configure log file path through global options Former-commit-id: 577d89f4fb84d67a170b678bbe0fd844505686d9 --- olm/olm.go | 15 ++++++++++++--- olm/types.go | 3 ++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index 02257b8..98ae6fb 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -100,6 +100,17 @@ func Init(ctx context.Context, config GlobalConfig) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) + if config.LogFilePath != "" { + logFile, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + logger.Fatal("Failed to open log file: %v", err) + } + + // TODO: figure out how to close file, if set + logger.SetOutput(logFile) + return + } + logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() if err != nil { @@ -306,7 +317,7 @@ func StartTunnel(config TunnelConfig) { if config.FileDescriptorTun != 0 { return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) } - var ifName = interfaceName + ifName := interfaceName if runtime.GOOS == "darwin" { // this is if we dont pass a fd ifName, err = network.FindUnusedUTUN() if err != nil { @@ -315,7 +326,6 @@ func StartTunnel(config TunnelConfig) { } return tun.CreateTUN(ifName, config.MTU) }() - if err != nil { logger.Error("Failed to create TUN device: %v", err) return @@ -361,7 +371,6 @@ func StartTunnel(config TunnelConfig) { for { conn, err := uapiListener.Accept() if err != nil { - return } go dev.IpcHandle(conn) diff --git a/olm/types.go b/olm/types.go index b7153af..14cc044 100644 --- a/olm/types.go +++ b/olm/types.go @@ -14,7 +14,8 @@ type WgData struct { type GlobalConfig struct { // Logging - LogLevel string + LogLevel string + LogFilePath string // HTTP server EnableAPI bool From 5b637bb4caac627959868705b80907521295f24b Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 12 Jan
2026 12:20:59 -0800 Subject: [PATCH 16/52] Add expo backoff Former-commit-id: faae551aca7447b717e50137a237c78542cb14cf --- main.go | 1 + olm/olm.go | 12 +++++++ olm/types.go | 3 ++ peers/monitor/monitor.go | 61 +++++++++++++++++++++++++++------ peers/monitor/wgtester.go | 71 +++++++++++++++++++++++++++++++-------- 5 files changed, 123 insertions(+), 25 deletions(-) diff --git a/main.go b/main.go index f6c6973..5b6c15e 100644 --- a/main.go +++ b/main.go @@ -219,6 +219,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt Agent: "Olm CLI", OnExit: cancel, // Pass cancel function directly to trigger shutdown OnTerminated: cancel, + PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE } olm.Init(ctx, olmConfig) diff --git a/olm/olm.go b/olm/olm.go index 4d12952..03cf02b 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" "net" + "net/http" + _ "net/http/pprof" "os" "runtime" "strconv" @@ -101,6 +103,16 @@ func Init(ctx context.Context, config GlobalConfig) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) + // Start pprof server if enabled + if config.PprofAddr != "" { + go func() { + logger.Info("Starting pprof server on %s", config.PprofAddr) + if err := http.ListenAndServe(config.PprofAddr, nil); err != nil { + logger.Error("Failed to start pprof server: %v", err) + } + }() + } + logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() if err != nil { diff --git a/olm/types.go b/olm/types.go index b7153af..a43121f 100644 --- a/olm/types.go +++ b/olm/types.go @@ -23,6 +23,9 @@ type GlobalConfig struct { Version string Agent string + // Debugging + PprofAddr string // Address to serve pprof on (e.g., "localhost:6060") + // Callbacks OnRegistered func() OnConnected func() diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 27bc408..1ec267e 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ 
-61,6 +61,13 @@ type PeerMonitor struct { holepunchMaxAttempts int // max consecutive failures before triggering relay holepunchFailures map[int]int // siteID -> consecutive failure count + // Exponential backoff fields for holepunch monitor + holepunchMinInterval time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied + // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts rapidTestTimeout time.Duration // timeout for each rapid test attempt @@ -101,6 +108,12 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total apiServer: apiServer, wgConnectionStatus: make(map[int]bool), + // Exponential backoff settings for holepunch monitor + holepunchMinInterval: 2 * time.Second, + holepunchMaxInterval: 30 * time.Second, + holepunchBackoffMultiplier: 1.5, + holepunchStableCount: make(map[int]int), + holepunchCurrentInterval: 2 * time.Second, } if err := pm.initNetstack(); err != nil { @@ -172,6 +185,7 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st client.SetPacketInterval(pm.interval) client.SetTimeout(pm.timeout) client.SetMaxAttempts(pm.maxAttempts) + client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable pm.monitors[siteID] = client @@ -470,31 +484,50 @@ func (pm *PeerMonitor) stopHolepunchMonitor() { logger.Info("Stopped holepunch connection monitor") } -// runHolepunchMonitor runs the holepunch monitoring loop +// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff func (pm *PeerMonitor) runHolepunchMonitor() { - ticker := 
time.NewTicker(pm.holepunchInterval) - defer ticker.Stop() + pm.mutex.Lock() + pm.holepunchCurrentInterval = pm.holepunchMinInterval + pm.mutex.Unlock() - // Do initial check immediately - pm.checkHolepunchEndpoints() + timer := time.NewTimer(0) // Fire immediately for initial check + defer timer.Stop() for { select { case <-pm.holepunchStopChan: return - case <-ticker.C: - pm.checkHolepunchEndpoints() + case <-timer.C: + anyStatusChanged := pm.checkHolepunchEndpoints() + + pm.mutex.Lock() + if anyStatusChanged { + // Reset to minimum interval on any status change + pm.holepunchCurrentInterval = pm.holepunchMinInterval + } else { + // Apply exponential backoff when stable + newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier) + if newInterval > pm.holepunchMaxInterval { + newInterval = pm.holepunchMaxInterval + } + pm.holepunchCurrentInterval = newInterval + } + currentInterval := pm.holepunchCurrentInterval + pm.mutex.Unlock() + + timer.Reset(currentInterval) } } } // checkHolepunchEndpoints tests all holepunch endpoints -func (pm *PeerMonitor) checkHolepunchEndpoints() { +// Returns true if any endpoint's status changed +func (pm *PeerMonitor) checkHolepunchEndpoints() bool { pm.mutex.Lock() // Check if we're still running before doing any work if !pm.running { pm.mutex.Unlock() - return + return false } endpoints := make(map[int]string, len(pm.holepunchEndpoints)) for siteID, endpoint := range pm.holepunchEndpoints { @@ -504,6 +537,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { maxAttempts := pm.holepunchMaxAttempts pm.mutex.Unlock() + anyStatusChanged := false + for siteID, endpoint := range endpoints { // logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) result := pm.holepunchTester.TestEndpoint(endpoint, timeout) @@ -529,7 +564,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { pm.mutex.Unlock() // Log status changes - if !exists || previousStatus != result.Success { + 
statusChanged := !exists || previousStatus != result.Success + if statusChanged { + anyStatusChanged = true if result.Success { logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT) } else { @@ -562,7 +599,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { pm.mutex.Unlock() if !stillRunning { - return // Stop processing if shutdown is in progress + return anyStatusChanged // Stop processing if shutdown is in progress } if !result.Success && !isRelayed && failureCount >= maxAttempts { @@ -579,6 +616,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { } } } + + return anyStatusChanged } // GetHolepunchStatus returns the current holepunch status for all endpoints diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go index dac2008..21f788a 100644 --- a/peers/monitor/wgtester.go +++ b/peers/monitor/wgtester.go @@ -36,6 +36,12 @@ type Client struct { timeout time.Duration maxAttempts int dialer Dialer + + // Exponential backoff fields + minInterval time.Duration // Minimum interval (initial) + maxInterval time.Duration // Maximum interval (cap for backoff) + backoffMultiplier float64 // Multiplier for each stable check + stableCountToBackoff int // Number of stable checks before backing off } // Dialer is a function that creates a connection @@ -50,18 +56,23 @@ type ConnectionStatus struct { // NewClient creates a new connection test client func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ - serverAddr: serverAddr, - shutdownCh: make(chan struct{}), - packetInterval: 2 * time.Second, - timeout: 500 * time.Millisecond, // Timeout for individual packets - maxAttempts: 3, // Default max attempts - dialer: dialer, + serverAddr: serverAddr, + shutdownCh: make(chan struct{}), + packetInterval: 2 * time.Second, + minInterval: 2 * time.Second, + maxInterval: 30 * time.Second, + backoffMultiplier: 1.5, + stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing 
off + timeout: 500 * time.Millisecond, // Timeout for individual packets + maxAttempts: 3, // Default max attempts + dialer: dialer, }, nil } // SetPacketInterval changes how frequently packets are sent in monitor mode func (c *Client) SetPacketInterval(interval time.Duration) { c.packetInterval = interval + c.minInterval = interval } // SetTimeout changes the timeout for waiting for responses @@ -74,6 +85,16 @@ func (c *Client) SetMaxAttempts(attempts int) { c.maxAttempts = attempts } +// SetMaxInterval sets the maximum backoff interval +func (c *Client) SetMaxInterval(interval time.Duration) { + c.maxInterval = interval +} + +// SetBackoffMultiplier sets the multiplier for exponential backoff +func (c *Client) SetBackoffMultiplier(multiplier float64) { + c.backoffMultiplier = multiplier +} + // UpdateServerAddr updates the server address and resets the connection func (c *Client) UpdateServerAddr(serverAddr string) { c.connLock.Lock() @@ -138,6 +159,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { binary.BigEndian.PutUint32(packet[0:4], magicHeader) packet[4] = packetTypeRequest + // Reusable response buffer + responseBuffer := make([]byte, packetSize) + // Send multiple attempts as specified for attempt := 0; attempt < c.maxAttempts; attempt++ { select { @@ -157,20 +181,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { return false, 0 } - // logger.Debug("Attempting to send monitor packet to %s", c.serverAddr) _, err := c.conn.Write(packet) if err != nil { c.connLock.Unlock() logger.Info("Error sending packet: %v", err) continue } - // logger.Debug("Successfully sent monitor packet") // Set read deadline c.conn.SetReadDeadline(time.Now().Add(c.timeout)) // Wait for response - responseBuffer := make([]byte, packetSize) n, err := c.conn.Read(responseBuffer) c.connLock.Unlock() @@ -238,28 +259,50 @@ func (c *Client) StartMonitor(callback MonitorCallback) error { go func() { var lastConnected bool 
firstRun := true + stableCount := 0 + currentInterval := c.minInterval - ticker := time.NewTicker(c.packetInterval) - defer ticker.Stop() + timer := time.NewTimer(currentInterval) + defer timer.Stop() for { select { case <-c.shutdownCh: return - case <-ticker.C: + case <-timer.C: ctx, cancel := context.WithTimeout(context.Background(), c.timeout) connected, rtt := c.TestConnection(ctx) cancel() + statusChanged := connected != lastConnected + // Callback if status changed or it's the first check - if connected != lastConnected || firstRun { + if statusChanged || firstRun { callback(ConnectionStatus{ Connected: connected, RTT: rtt, }) lastConnected = connected firstRun = false + // Reset backoff on status change + stableCount = 0 + currentInterval = c.minInterval + } else { + // Status is stable, increment counter + stableCount++ + + // Apply exponential backoff after stable threshold + if stableCount >= c.stableCountToBackoff { + newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier) + if newInterval > c.maxInterval { + newInterval = c.maxInterval + } + currentInterval = newInterval + } } + + // Reset timer with current interval + timer.Reset(currentInterval) } } }() @@ -278,4 +321,4 @@ func (c *Client) StopMonitor() { close(c.shutdownCh) c.monitorRunning = false -} +} \ No newline at end of file From 20e0c18845e8062053ecb503de22dc13f5556f99 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 12 Jan 2026 12:29:42 -0800 Subject: [PATCH 17/52] Try to reduce cpu when idle Former-commit-id: ba91478b89832990cb366e2362a21c93e5ce698e --- dns/dns_proxy.go | 76 ++++++++++++++++------------------------ peers/monitor/monitor.go | 67 +++++++++++++++-------------------- 2 files changed, 60 insertions(+), 83 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 748a5a9..d010bc6 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -599,12 +599,12 @@ func (p *DNSProxy) runTunnelPacketSender() { defer p.wg.Done() logger.Debug("DNS tunnel packet 
sender goroutine started") - ticker := time.NewTicker(1 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-p.ctx.Done(): + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := p.tunnelEp.ReadContext(p.ctx) + if pkt == nil { + // Context was cancelled or endpoint closed logger.Debug("DNS tunnel packet sender exiting") // Drain any remaining packets for { @@ -615,36 +615,28 @@ func (p *DNSProxy) runTunnelPacketSender() { pkt.DecRef() } return - case <-ticker.C: - // Try to read packets - for i := 0; i < 10; i++ { - pkt := p.tunnelEp.Read() - if pkt == nil { - break - } - - // Extract packet data - slices := pkt.AsSlices() - if len(slices) > 0 { - var totalSize int - for _, slice := range slices { - totalSize += len(slice) - } - - buf := make([]byte, totalSize) - pos := 0 - for _, slice := range slices { - copy(buf[pos:], slice) - pos += len(slice) - } - - // Inject into MiddleDevice (outbound to WG) - p.middleDevice.InjectOutbound(buf) - } - - pkt.DecRef() - } } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + p.middleDevice.InjectOutbound(buf) + } + + pkt.DecRef() } } @@ -657,18 +649,12 @@ func (p *DNSProxy) runPacketSender() { const offset = 16 for { - select { - case <-p.ctx.Done(): - return - default: - } - - // Read packets from netstack endpoint - pkt := p.ep.Read() + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := p.ep.ReadContext(p.ctx) if pkt == nil { - // No packet available, small sleep to avoid busy loop - time.Sleep(1 * time.Millisecond) - 
continue + // Context was cancelled or endpoint closed + return } // Extract packet data as slices diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 1ec267e..45dd090 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -42,7 +42,7 @@ type PeerMonitor struct { stack *stack.Stack ep *channel.Endpoint activePorts map[uint16]bool - portsLock sync.Mutex + portsLock sync.RWMutex nsCtx context.Context nsCancel context.CancelFunc nsWg sync.WaitGroup @@ -809,9 +809,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool { } // Check if we are listening on this port - pm.portsLock.Lock() + pm.portsLock.RLock() active := pm.activePorts[uint16(port)] - pm.portsLock.Unlock() + pm.portsLock.RUnlock() if !active { return false @@ -842,13 +842,12 @@ func (pm *PeerMonitor) runPacketSender() { defer pm.nsWg.Done() logger.Debug("PeerMonitor: Packet sender goroutine started") - // Use a ticker to periodically check for packets without blocking indefinitely - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-pm.nsCtx.Done(): + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := pm.ep.ReadContext(pm.nsCtx) + if pkt == nil { + // Context was cancelled or endpoint closed logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets") // Drain any remaining packets before exiting for { @@ -860,36 +859,28 @@ func (pm *PeerMonitor) runPacketSender() { } logger.Debug("PeerMonitor: Packet sender goroutine exiting") return - case <-ticker.C: - // Try to read packets in batches - for i := 0; i < 10; i++ { - pkt := pm.ep.Read() - if pkt == nil { - break - } - - // Extract packet data - slices := pkt.AsSlices() - if len(slices) > 0 { - var totalSize int - for _, slice := range slices { - totalSize += len(slice) - } - - buf := make([]byte, totalSize) - pos := 0 - for _, slice := range slices 
{ - copy(buf[pos:], slice) - pos += len(slice) - } - - // Inject into MiddleDevice (outbound to WG) - pm.middleDev.InjectOutbound(buf) - } - - pkt.DecRef() - } } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + pm.middleDev.InjectOutbound(buf) + } + + pkt.DecRef() } } From 9c0b4fcd5f1dd78e498b73be073871deeaa3d6bd Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 13 Jan 2026 11:51:51 -0800 Subject: [PATCH 18/52] Fix error checking Former-commit-id: 231808476b1087357629b4765285f30900844441 --- websocket/client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/websocket/client.go b/websocket/client.go index f620f8a..74b0401 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -414,7 +415,8 @@ func (c *Client) connectWithRetry() { err := c.establishConnection() if err != nil { // Check if this is an auth error (401/403) - if authErr, ok := err.(*AuthError); ok { + var authErr *AuthError + if errors.As(err, &authErr) { logger.Error("Authentication failed: %v. 
Terminating tunnel and retrying...", authErr) // Trigger auth error callback if set (this should terminate the tunnel) if c.onAuthError != nil { From dada0cc1242a4d0921987b079576d14ac8e21366 Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Tue, 13 Jan 2026 14:30:02 -0800 Subject: [PATCH 19/52] add low power state for testing Former-commit-id: 996fe59999c64e63ec74a33c6b2f792d7d9130d4 --- olm/olm.go | 188 ++++++++++++++++++++++++++++++++++++++- peers/manager.go | 7 ++ peers/monitor/monitor.go | 29 ++++-- 3 files changed, 218 insertions(+), 6 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index 774a3cb..21295be 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -53,6 +53,11 @@ var ( updateRegister func(newData interface{}) stopPing chan struct{} peerManager *peers.PeerManager + // Power mode management + currentPowerMode string + originalPeerInterval time.Duration + originalHolepunchMinInterval time.Duration + originalHolepunchMaxInterval time.Duration ) // initTunnelInfo creates the shared UDP socket and holepunch manager. 
@@ -112,7 +117,7 @@ func Init(ctx context.Context, config GlobalConfig) { } }() } - + if config.LogFilePath != "" { logFile, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) if err != nil { @@ -432,6 +437,18 @@ func StartTunnel(config TunnelConfig) { APIServer: apiServer, }) + // Capture original intervals for power mode management + if peerManager != nil { + peerMonitor := peerManager.GetPeerMonitor() + if peerMonitor != nil { + originalPeerInterval = 2 * time.Second // Default peer interval + originalHolepunchMinInterval, originalHolepunchMaxInterval = peerMonitor.GetHolepunchIntervals() + } + } + + // Initialize power mode to normal + currentPowerMode = "normal" + for i := range wgData.Sites { site := wgData.Sites[i] var siteEndpoint string @@ -1156,3 +1173,172 @@ func SwitchOrg(orgID string) error { return nil } + +// SetPowerMode switches between normal and low power modes +// In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes +// In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored +func SetPowerMode(mode string) error { + // Validate mode + if mode != "normal" && mode != "low" { + return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode) + } + + // If already in the requested mode, return early + if currentPowerMode == mode { + logger.Debug("Already in %s power mode", mode) + return nil + } + + logger.Info("Switching to %s power mode", mode) + + if mode == "low" { + // Low Power Mode: Close websocket and reduce monitoring frequency + + // Close websocket connection - this stops: + // - WebSocket ping monitor (via pingMonitor() goroutine) + // - Application ping messages (via keepSendingPing() goroutine) + if olmClient != nil { + logger.Info("Closing websocket connection for low power mode") + if err := olmClient.Close(); err != nil { + logger.Error("Error closing websocket: %v", err) + } + } + + // Stop 
application ping goroutine + if stopPing != nil { + select { + case <-stopPing: + // Channel already closed + default: + close(stopPing) + } + } + + // Stop peer monitoring + if peerManager != nil { + peerManager.Stop() + } + + // Store original intervals if not already stored + if originalPeerInterval == 0 && peerManager != nil { + peerMonitor := peerManager.GetPeerMonitor() + if peerMonitor != nil { + originalPeerInterval = 2 * time.Second // Default peer interval + originalHolepunchMinInterval, originalHolepunchMaxInterval = peerMonitor.GetHolepunchIntervals() + } + } + + // Set monitoring intervals to 10 minutes + if peerManager != nil { + peerMonitor := peerManager.GetPeerMonitor() + if peerMonitor != nil { + lowPowerInterval := 10 * time.Minute + peerMonitor.SetInterval(lowPowerInterval) + peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval) + logger.Info("Set monitoring intervals to 10 minutes for low power mode") + } + } + + // Restart peer monitoring with new intervals (but websocket remains closed) + if peerManager != nil { + peerManager.Start() + } + + currentPowerMode = "low" + logger.Info("Switched to low power mode") + + } else { + // Normal Power Mode: Restore intervals and reconnect websocket + + // Restore monitoring intervals to original values + if peerManager != nil { + peerMonitor := peerManager.GetPeerMonitor() + if peerMonitor != nil { + // Restore peer interval + if originalPeerInterval == 0 { + originalPeerInterval = 2 * time.Second // Default if not captured + } + peerMonitor.SetInterval(originalPeerInterval) + + // Restore holepunch intervals + if originalHolepunchMinInterval == 0 { + originalHolepunchMinInterval = 2 * time.Second // Default if not captured + } + if originalHolepunchMaxInterval == 0 { + originalHolepunchMaxInterval = 30 * time.Second // Default if not captured + } + peerMonitor.SetHolepunchInterval(originalHolepunchMinInterval, originalHolepunchMaxInterval) + logger.Info("Restored monitoring intervals to 
normal (peer: %v, holepunch: %v-%v)", + originalPeerInterval, originalHolepunchMinInterval, originalHolepunchMaxInterval) + } + } + + // Restart peer monitoring with restored intervals + if peerManager != nil { + peerManager.Start() + } + + // Reconnect websocket - this restarts: + // - WebSocket ping monitor + // - Application ping messages (via OnConnect callback) + // Note: Since websocket client's Close() permanently closes the done channel, + // we need to create a new client instance and re-register handlers + if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" { + logger.Info("Reconnecting websocket for normal power mode") + + // Close old client if it exists + if olmClient != nil { + olmClient.Close() + } + + // Recreate stopPing channel for application pings + stopPing = make(chan struct{}) + + // Create a new websocket client + var ( + id = tunnelConfig.ID + secret = tunnelConfig.Secret + userToken = tunnelConfig.UserToken + ) + + olm, err := websocket.NewClient( + id, + secret, + userToken, + tunnelConfig.OrgID, + tunnelConfig.Endpoint, + tunnelConfig.PingIntervalDuration, + tunnelConfig.PingTimeoutDuration, + ) + if err != nil { + logger.Error("Failed to create new websocket client: %v", err) + return fmt.Errorf("failed to create new websocket client: %w", err) + } + + // Store the new client + olmClient = olm + + // Re-register essential handlers (simplified - only the most critical ones) + // The full handler registration happens in StartTunnel, so this is just for reconnection + olm.OnConnect(func() error { + logger.Info("Websocket Reconnected") + apiServer.SetConnectionStatus(true) + go keepSendingPing(olm) + return nil + }) + + // Connect to the WebSocket server + if err := olm.Connect(); err != nil { + logger.Error("Failed to reconnect websocket: %v", err) + return fmt.Errorf("failed to reconnect websocket: %w", err) + } + } else { + logger.Warn("Cannot reconnect websocket: tunnel config not available") + } + + 
currentPowerMode = "normal" + logger.Info("Switched to normal power mode") + } + + return nil +} diff --git a/peers/manager.go b/peers/manager.go index af781e5..56f3707 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -84,6 +84,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) { return peer, ok } +// GetPeerMonitor returns the internal peer monitor instance +func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor { + pm.mu.RLock() + defer pm.mu.RUnlock() + return pm.peerMonitor +} + func (pm *PeerManager) GetAllPeers() []SiteConfig { pm.mu.RLock() defer pm.mu.RUnlock() diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 45dd090..2bb0c80 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -62,11 +62,11 @@ type PeerMonitor struct { holepunchFailures map[int]int // siteID -> consecutive failure count // Exponential backoff fields for holepunch monitor - holepunchMinInterval time.Duration // Minimum interval (initial) - holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) - holepunchBackoffMultiplier float64 // Multiplier for each stable check - holepunchStableCount map[int]int // siteID -> consecutive stable status count - holepunchCurrentInterval time.Duration // Current interval with backoff applied + holepunchMinInterval time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts @@ -167,6 +167,25 @@ func (pm *PeerMonitor) SetMaxAttempts(attempts int) { } } +// SetHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) 
SetHolepunchInterval(minInterval, maxInterval time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.holepunchMinInterval = minInterval + pm.holepunchMaxInterval = maxInterval + // Reset current interval to the new minimum + pm.holepunchCurrentInterval = minInterval +} + +// GetHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + return pm.holepunchMinInterval, pm.holepunchMaxInterval +} + // AddPeer adds a new peer to monitor func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error { pm.mutex.Lock() From 15e96a779cc3ff109e4c4cf6a46ae0cdbd359ec9 Mon Sep 17 00:00:00 2001 From: Varun Narravula Date: Mon, 5 Jan 2026 01:41:54 -0800 Subject: [PATCH 20/52] refactor(olm): convert global state into an olm instance Former-commit-id: b755f77d95ecc9d645806fa33a11d91261cfd059 --- api/api.go | 29 +- main.go | 12 +- olm/connect.go | 223 +++++++++ olm/data.go | 197 ++++++++ olm/olm.go | 976 ++++++++------------------------------- olm/peer.go | 195 ++++++++ olm/{util.go => ping.go} | 6 +- olm/types.go | 2 +- 8 files changed, 841 insertions(+), 799 deletions(-) create mode 100644 olm/connect.go create mode 100644 olm/data.go create mode 100644 olm/peer.go rename olm/{util.go => ping.go} (89%) diff --git a/api/api.go b/api/api.go index 91d9f37..a6ac9cd 100644 --- a/api/api.go +++ b/api/api.go @@ -63,23 +63,26 @@ type StatusResponse struct { // API represents the HTTP server and its state type API struct { - addr string - socketPath string - listener net.Listener - server *http.Server + addr string + socketPath string + listener net.Listener + server *http.Server + onConnect func(ConnectionRequest) error onSwitchOrg func(SwitchOrgRequest) error onDisconnect func() error onExit func() error + statusMu sync.RWMutex peerStatuses map[int]*PeerStatus connectedAt 
time.Time isConnected bool isRegistered bool isTerminated bool - version string - agent string - orgID string + + version string + agent string + orgID string } // NewAPI creates a new HTTP server that listens on a TCP address @@ -173,7 +176,7 @@ func (s *API) Stop() error { // Close the server first, which will also close the listener gracefully if s.server != nil { - s.server.Close() + _ = s.server.Close() } // Clean up socket file if using Unix socket @@ -358,7 +361,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusAccepted) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "connection request accepted", }) } @@ -406,7 +409,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "ok", }) } @@ -423,7 +426,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { // Return a success response first w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "shutdown initiated", }) @@ -472,7 +475,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "org switch request accepted", }) } @@ -506,7 +509,7 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - 
json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "disconnect initiated", }) } diff --git a/main.go b/main.go index 5b6c15e..2bf8dcd 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" - "github.com/fosrl/olm/olm" + olmpkg "github.com/fosrl/olm/olm" ) func main() { @@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt } // Create a new olm.Config struct and copy values from the main config - olmConfig := olm.GlobalConfig{ + olmConfig := olmpkg.OlmConfig{ LogLevel: config.LogLevel, EnableAPI: config.EnableAPI, HTTPAddr: config.HTTPAddr, @@ -222,13 +222,17 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE } - olm.Init(ctx, olmConfig) + olm, err := olmpkg.Init(ctx, olmConfig) + if err != nil { + logger.Fatal("Failed to initialize olm: %v", err) + } + if err := olm.StartApi(); err != nil { logger.Fatal("Failed to start API server: %v", err) } if config.ID != "" && config.Secret != "" && config.Endpoint != "" { - tunnelConfig := olm.TunnelConfig{ + tunnelConfig := olmpkg.TunnelConfig{ Endpoint: config.Endpoint, ID: config.ID, Secret: config.Secret, diff --git a/olm/connect.go b/olm/connect.go new file mode 100644 index 0000000..568c731 --- /dev/null +++ b/olm/connect.go @@ -0,0 +1,223 @@ +package olm + +import ( + "encoding/json" + "fmt" + "os" + "runtime" + "strconv" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + olmDevice "github.com/fosrl/olm/device" + "github.com/fosrl/olm/dns" + dnsOverride "github.com/fosrl/olm/dns/override" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" +) + +func (o *Olm) handleConnect(msg websocket.WSMessage) { + logger.Debug("Received message: %v", 
msg.Data) + + var wgData WgData + + if o.connected { + logger.Info("Already connected. Ignoring new connection request.") + return + } + + if o.stopRegister != nil { + o.stopRegister() + o.stopRegister = nil + } + + if o.updateRegister != nil { + o.updateRegister = nil + } + + // if there is an existing tunnel then close it + if o.dev != nil { + logger.Info("Got new message. Closing existing tunnel!") + o.dev.Close() + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &wgData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + o.tdev, err = func() (tun.Device, error) { + if o.tunnelConfig.FileDescriptorTun != 0 { + return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU) + } + ifName := o.tunnelConfig.InterfaceName + if runtime.GOOS == "darwin" { // this is if we dont pass a fd + ifName, err = network.FindUnusedUTUN() + if err != nil { + return nil, err + } + } + return tun.CreateTUN(ifName, o.tunnelConfig.MTU) + }() + if err != nil { + logger.Error("Failed to create TUN device: %v", err) + return + } + + // if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? 
+ o.tunnelConfig.InterfaceName = realInterfaceName + } + // } + + // Wrap TUN device with packet filter for DNS proxy + o.middleDev = olmDevice.NewMiddleDevice(o.tdev) + + wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") + // Use filtered device instead of raw TUN device + o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger)) + + if o.tunnelConfig.EnableUAPI { + fileUAPI, err := func() (*os.File, error) { + if o.tunnelConfig.FileDescriptorUAPI != 0 { + fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) + } + return os.NewFile(uintptr(fd), ""), nil + } + return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName) + }() + if err != nil { + logger.Error("UAPI listen error: %v", err) + os.Exit(1) + return + } + + o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI) + if err != nil { + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) + } + + go func() { + for { + conn, err := o.uapiListener.Accept() + if err != nil { + return + } + go o.dev.IpcHandle(conn) + } + }() + logger.Info("UAPI listener started") + } + + if err = o.dev.Up(); err != nil { + logger.Error("Failed to bring up WireGuard device: %v", err) + } + + // Extract interface IP (strip CIDR notation if present) + interfaceIP := wgData.TunnelIP + if strings.Contains(interfaceIP, "/") { + interfaceIP = strings.Split(interfaceIP, "/")[0] + } + + // Create and start DNS proxy + o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP) + if err != nil { + logger.Error("Failed to create DNS proxy: %v", err) + } + + if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil { + logger.Error("Failed to o.tunnelConfigure interface: %v", err) + } + + if 
network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet + logger.Error("Failed to add route for utility subnet: %v", err) + } + + // Create peer manager with integrated peer monitoring + o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ + Device: o.dev, + DNSProxy: o.dnsProxy, + InterfaceName: o.tunnelConfig.InterfaceName, + PrivateKey: o.privateKey, + MiddleDev: o.middleDev, + LocalIP: interfaceIP, + SharedBind: o.sharedBind, + WSClient: o.olmClient, + APIServer: o.apiServer, + }) + + for i := range wgData.Sites { + site := wgData.Sites[i] + var siteEndpoint string + // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer + if site.RelayEndpoint != "" { + siteEndpoint = site.RelayEndpoint + } else { + siteEndpoint = site.Endpoint + } + + o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) + + if err := o.peerManager.AddPeer(site); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Configured peer %s", site.PublicKey) + } + + o.peerManager.Start() + + if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime + logger.Error("Failed to start DNS proxy: %v", err) + } + + if o.tunnelConfig.OverrideDNS { + // Set up DNS override to use our DNS proxy + if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { + logger.Error("Failed to setup DNS override: %v", err) + return + } + + network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()}) + } + + o.apiServer.SetRegistered(true) + + o.connected = true + + // Invoke onConnected callback if configured + if o.olmConfig.OnConnected != nil { + go o.olmConfig.OnConnected() + } + + logger.Info("WireGuard device created.") +} + +func (o *Olm) handleTerminate(msg websocket.WSMessage) { + logger.Info("Received terminate message") + 
o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearPeerStatuses() + + network.ClearNetworkSettings() + + o.Close() + + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() + } +} diff --git a/olm/data.go b/olm/data.go new file mode 100644 index 0000000..9c8d33f --- /dev/null +++ b/olm/data.go @@ -0,0 +1,197 @@ +package olm + +import ( + "encoding/json" + "time" + + "github.com/fosrl/newt/holepunch" + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" +) + +func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) { + logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var addSubnetsData peers.PeerAdd + if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { + logger.Error("Error unmarshaling add-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId) + return + } + + // Add new subnets + for _, subnet := range addSubnetsData.RemoteSubnets { + if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases + for _, alias := range addSubnetsData.Aliases { + if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } +} + +func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) { + logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeSubnetsData 
peers.RemovePeerData + if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { + logger.Error("Error unmarshaling remove-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) + return + } + + // Remove subnets + for _, subnet := range removeSubnetsData.RemoteSubnets { + if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Remove aliases + for _, alias := range removeSubnetsData.Aliases { + if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } +} + +func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { + logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateSubnetsData peers.UpdatePeerData + if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { + logger.Error("Error unmarshaling update-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId) + return + } + + // Add new subnets BEFORE removing old ones to preserve shared subnets + // This ensures that if an old and new subnet are the same on different peers, + // the route won't be temporarily removed + for _, subnet := range updateSubnetsData.NewRemoteSubnets { + if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Remove old subnets after new ones are 
added + for _, subnet := range updateSubnetsData.OldRemoteSubnets { + if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases BEFORE removing old ones to preserve shared IP addresses + // This ensures that if an old and new alias share the same IP, the IP won't be + // temporarily removed from the allowed IPs list + for _, alias := range updateSubnetsData.NewAliases { + if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } + + // Remove old aliases after new ones are added + for _, alias := range updateSubnetsData.OldAliases { + if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } + + logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) +} + +func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { + logger.Debug("Received peer-handshake message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling handshake data: %v", err) + return + } + + var handshakeData struct { + SiteId int `json:"siteId"` + ExitNode struct { + PublicKey string `json:"publicKey"` + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + } `json:"exitNode"` + } + + if err := json.Unmarshal(jsonData, &handshakeData); err != nil { + logger.Error("Error unmarshaling handshake data: %v", err) + return + } + + // Get existing peer from PeerManager + _, exists := o.peerManager.GetPeer(handshakeData.SiteId) + if exists { + logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) + return + } + + relayPort := handshakeData.ExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + siteId 
:= handshakeData.SiteId + exitNode := holepunch.ExitNode{ + Endpoint: handshakeData.ExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: handshakeData.ExitNode.PublicKey, + SiteIds: []int{siteId}, + } + + added := o.holePunchManager.AddExitNode(exitNode) + if added { + logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) + } else { + logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) + } + + o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt + o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud + + // Send handshake acknowledgment back to server with retry + o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + }, 1*time.Second) + + logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) +} diff --git a/olm/olm.go b/olm/olm.go index 774a3cb..6d8f7a5 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -2,15 +2,11 @@ package olm import ( "context" - "encoding/json" "fmt" "net" "net/http" _ "net/http/pprof" "os" - "runtime" - "strconv" - "strings" "time" "github.com/fosrl/newt/bind" @@ -30,41 +26,49 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -var ( - privateKey wgtypes.Key - connected bool - dev *device.Device - uapiListener net.Listener - tdev tun.Device - middleDev *olmDevice.MiddleDevice - interfaceName string +type Olm struct { + privateKey wgtypes.Key + logFile *os.File + + connected bool + tunnelRunning bool + + uapiListener net.Listener + dev *device.Device + tdev tun.Device + middleDev *olmDevice.MiddleDevice + sharedBind *bind.SharedBind + dnsProxy *dns.DNSProxy apiServer *api.API olmClient *websocket.Client - tunnelCancel context.CancelFunc - tunnelRunning bool - sharedBind *bind.SharedBind holePunchManager *holepunch.Manager - globalConfig GlobalConfig - 
tunnelConfig TunnelConfig - globalCtx context.Context - stopRegister func() - stopPeerSend func() - updateRegister func(newData interface{}) - stopPing chan struct{} peerManager *peers.PeerManager -) + + olmCtx context.Context + tunnelCancel context.CancelFunc + + olmConfig OlmConfig + tunnelConfig TunnelConfig + + stopRegister func() + stopPeerSend func() + updateRegister func(newData any) + + stopPing chan struct{} +} // initTunnelInfo creates the shared UDP socket and holepunch manager. // This is used during initial tunnel setup and when switching organizations. -func initTunnelInfo(clientID string) error { - var err error - privateKey, err = wgtypes.GeneratePrivateKey() +func (o *Olm) initTunnelInfo(clientID string) error { + privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { logger.Error("Failed to generate private key: %v", err) return err } + o.privateKey = privateKey + sourcePort, err := util.FindAvailableUDPPort(49152, 65535) if err != nil { return fmt.Errorf("failed to find available UDP port: %w", err) @@ -80,27 +84,26 @@ func initTunnelInfo(clientID string) error { return fmt.Errorf("failed to create UDP socket: %w", err) } - sharedBind, err = bind.New(udpConn) + sharedBind, err := bind.New(udpConn) if err != nil { - udpConn.Close() + _ = udpConn.Close() return fmt.Errorf("failed to create shared bind: %w", err) } + o.sharedBind = sharedBind + // Add a reference for the hole punch senders (creator already has one reference for WireGuard) sharedBind.AddRef() logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) // Create the holepunch manager - holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) + o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) return nil } -func Init(ctx context.Context, config GlobalConfig) { - globalConfig = config - globalCtx = ctx - +func Init(ctx context.Context, 
config OlmConfig) (*Olm, error) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) // Start pprof server if enabled @@ -112,25 +115,27 @@ func Init(ctx context.Context, config GlobalConfig) { } }() } - + + var logFile *os.File if config.LogFilePath != "" { - logFile, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + file, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) if err != nil { logger.Fatal("Failed to open log file: %v", err) + return nil, err } - // TODO: figure out how to close file, if set - logger.SetOutput(logFile) - return + logger.SetOutput(file) + logFile = file } logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() if err != nil { logger.Fatal("Insufficient permissions to create native TUN interface: %v", err) - return + return nil, err } + var apiServer *api.API if config.HTTPAddr != "" { apiServer = api.NewAPI(config.HTTPAddr) } else if config.SocketPath != "" { @@ -143,18 +148,24 @@ func Init(ctx context.Context, config GlobalConfig) { apiServer.SetVersion(config.Version) apiServer.SetAgent(config.Agent) - // Set up API handlers - apiServer.SetHandlers( + newOlm := &Olm{ + logFile: logFile, + olmCtx: ctx, + apiServer: apiServer, + olmConfig: config, + } + + newOlm.registerAPICallbacks() + + return newOlm, nil +} + +func (o *Olm) registerAPICallbacks() { + o.apiServer.SetHandlers( // onConnect func(req api.ConnectionRequest) error { logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) - // Stop any existing tunnel before starting a new one - if olmClient != nil { - logger.Info("Stopping existing tunnel before starting new connection") - StopTunnel() - } - tunnelConfig := TunnelConfig{ Endpoint: req.Endpoint, ID: req.ID, @@ -208,7 +219,7 @@ func Init(ctx context.Context, config GlobalConfig) { // Start the tunnel process with the new credentials if tunnelConfig.ID != 
"" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" { logger.Info("Starting tunnel with new credentials") - go StartTunnel(tunnelConfig) + go o.StartTunnel(tunnelConfig) } return nil @@ -216,66 +227,64 @@ func Init(ctx context.Context, config GlobalConfig) { // onSwitchOrg func(req api.SwitchOrgRequest) error { logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) - return SwitchOrg(req.OrgID) + return o.SwitchOrg(req.OrgID) }, // onDisconnect func() error { logger.Info("Processing disconnect request via API") - return StopTunnel() + return o.StopTunnel() }, // onExit func() error { logger.Info("Processing shutdown request via API") - Close() - if globalConfig.OnExit != nil { - globalConfig.OnExit() + o.Close() + if o.olmConfig.OnExit != nil { + o.olmConfig.OnExit() } return nil }, ) } -func StartTunnel(config TunnelConfig) { - if tunnelRunning { +func (o *Olm) StartTunnel(config TunnelConfig) { + if o.tunnelRunning { logger.Info("Tunnel already running") return } - tunnelRunning = true // Also set it here in case it is called externally - tunnelConfig = config + o.tunnelRunning = true // Also set it here in case it is called externally + o.tunnelConfig = config // Reset terminated status when tunnel starts - apiServer.SetTerminated(false) + o.apiServer.SetTerminated(false) // debug print out the whole config logger.Debug("Starting tunnel with config: %+v", config) // Create a cancellable context for this tunnel process - tunnelCtx, cancel := context.WithCancel(globalCtx) - tunnelCancel = cancel - defer func() { - tunnelCancel = nil - }() + tunnelCtx, cancel := context.WithCancel(o.olmCtx) + o.tunnelCancel = cancel // Recreate channels for this tunnel session - stopPing = make(chan struct{}) + o.stopPing = make(chan struct{}) var ( id = config.ID secret = config.Secret userToken = config.UserToken ) - interfaceName = config.InterfaceName - apiServer.SetOrgID(config.OrgID) + o.tunnelConfig.InterfaceName = config.InterfaceName - 
// Create a new olm client using the provided credentials - olm, err := websocket.NewClient( - id, // Use provided ID - secret, // Use provided secret - userToken, // Use provided user token OPTIONAL + o.apiServer.SetOrgID(config.OrgID) + + // Create a new olmClient client using the provided credentials + olmClient, err := websocket.NewClient( + id, + secret, + userToken, config.OrgID, - config.Endpoint, // Use provided endpoint + config.Endpoint, config.PingIntervalDuration, config.PingTimeoutDuration, ) @@ -284,638 +293,70 @@ func StartTunnel(config TunnelConfig) { return } - // Store the client reference globally - olmClient = olm - // Create shared UDP socket and holepunch manager - if err := initTunnelInfo(id); err != nil { + if err := o.initTunnelInfo(id); err != nil { logger.Error("%v", err) return } - olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - - var wgData WgData - - if connected { - logger.Info("Already connected. Ignoring new connection request.") - return - } - - if stopRegister != nil { - stopRegister() - stopRegister = nil - } - - if updateRegister != nil { - updateRegister = nil - } - - // if there is an existing tunnel then close it - if dev != nil { - logger.Info("Got new message. 
Closing existing tunnel!") - dev.Close() - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &wgData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - tdev, err = func() (tun.Device, error) { - if config.FileDescriptorTun != 0 { - return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) - } - ifName := interfaceName - if runtime.GOOS == "darwin" { // this is if we dont pass a fd - ifName, err = network.FindUnusedUTUN() - if err != nil { - return nil, err - } - } - return tun.CreateTUN(ifName, config.MTU) - }() - if err != nil { - logger.Error("Failed to create TUN device: %v", err) - return - } - - // if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? - interfaceName = realInterfaceName - } - // } - - // Wrap TUN device with packet filter for DNS proxy - middleDev = olmDevice.NewMiddleDevice(tdev) - - wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") - // Use filtered device instead of raw TUN device - dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger)) - - if config.EnableUAPI { - fileUAPI, err := func() (*os.File, error) { - if config.FileDescriptorUAPI != 0 { - fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32) - if err != nil { - return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) - } - return os.NewFile(uintptr(fd), ""), nil - } - return olmDevice.UapiOpen(interfaceName) - }() - if err != nil { - logger.Error("UAPI listen error: %v", err) - os.Exit(1) - return - } - - uapiListener, err = olmDevice.UapiListen(interfaceName, fileUAPI) - if err != nil { - logger.Error("Failed to listen on uapi socket: %v", err) - os.Exit(1) - } - - go func() { - for { - conn, err := uapiListener.Accept() - if err != nil { - 
return - } - go dev.IpcHandle(conn) - } - }() - logger.Info("UAPI listener started") - } - - if err = dev.Up(); err != nil { - logger.Error("Failed to bring up WireGuard device: %v", err) - } - - // Extract interface IP (strip CIDR notation if present) - interfaceIP := wgData.TunnelIP - if strings.Contains(interfaceIP, "/") { - interfaceIP = strings.Split(interfaceIP, "/")[0] - } - - // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) - if err != nil { - logger.Error("Failed to create DNS proxy: %v", err) - } - - if err = network.ConfigureInterface(interfaceName, wgData.TunnelIP, config.MTU); err != nil { - logger.Error("Failed to configure interface: %v", err) - } - - if network.AddRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet - logger.Error("Failed to add route for utility subnet: %v", err) - } - - // Create peer manager with integrated peer monitoring - peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ - Device: dev, - DNSProxy: dnsProxy, - InterfaceName: interfaceName, - PrivateKey: privateKey, - MiddleDev: middleDev, - LocalIP: interfaceIP, - SharedBind: sharedBind, - WSClient: olm, - APIServer: apiServer, - }) - - for i := range wgData.Sites { - site := wgData.Sites[i] - var siteEndpoint string - // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer - if site.RelayEndpoint != "" { - siteEndpoint = site.RelayEndpoint - } else { - siteEndpoint = site.Endpoint - } - - apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) - - if err := peerManager.AddPeer(site); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - logger.Info("Configured peer %s", site.PublicKey) - } - - peerManager.Start() - - if err := dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime - 
logger.Error("Failed to start DNS proxy: %v", err) - } - - if config.OverrideDNS { - // Set up DNS override to use our DNS proxy - if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy.GetProxyIP()); err != nil { - logger.Error("Failed to setup DNS override: %v", err) - return - } - - network.SetDNSServers([]string{dnsProxy.GetProxyIP().String()}) - } - - apiServer.SetRegistered(true) - - connected = true - - // Invoke onConnected callback if configured - if globalConfig.OnConnected != nil { - go globalConfig.OnConnected() - } - - logger.Info("WireGuard device created.") - }) - - olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateData peers.SiteConfig - if err := json.Unmarshal(jsonData, &updateData); err != nil { - logger.Error("Error unmarshaling update data: %v", err) - return - } - - // Get existing peer from PeerManager - existingPeer, exists := peerManager.GetPeer(updateData.SiteId) - if !exists { - logger.Warn("Peer with site ID %d not found", updateData.SiteId) - return - } - - // Create updated site config by merging with existing data - siteConfig := existingPeer - - if updateData.Endpoint != "" { - siteConfig.Endpoint = updateData.Endpoint - } - if updateData.RelayEndpoint != "" { - siteConfig.RelayEndpoint = updateData.RelayEndpoint - } - if updateData.PublicKey != "" { - siteConfig.PublicKey = updateData.PublicKey - } - if updateData.ServerIP != "" { - siteConfig.ServerIP = updateData.ServerIP - } - if updateData.ServerPort != 0 { - siteConfig.ServerPort = updateData.ServerPort - } - if updateData.RemoteSubnets != nil { - siteConfig.RemoteSubnets = updateData.RemoteSubnets - } - - if err := peerManager.UpdatePeer(siteConfig); err != nil { - logger.Error("Failed to update peer: %v", err) - return - } - - // If the endpoint 
changed, trigger holepunch to refresh NAT mappings - if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { - logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) - holePunchManager.TriggerHolePunch() - holePunchManager.ResetInterval() - } - - // Update successful - logger.Info("Successfully updated peer for site %d", updateData.SiteId) - }) - - // Handler for adding a new peer - olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-peer message: %v", msg.Data) - - if stopPeerSend != nil { - stopPeerSend() - stopPeerSend = nil - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var siteConfig peers.SiteConfig - if err := json.Unmarshal(jsonData, &siteConfig); err != nil { - logger.Error("Error unmarshaling add data: %v", err) - return - } - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it - - if err := peerManager.AddPeer(siteConfig); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - // Add successful - logger.Info("Successfully added peer for site %d", siteConfig.SiteId) - }) - - // Handler for removing a peer - olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeData peers.PeerRemove - if err := json.Unmarshal(jsonData, &removeData); err != nil { - logger.Error("Error unmarshaling remove data: %v", err) - return - } - - if err := peerManager.RemovePeer(removeData.SiteId); err != nil { - logger.Error("Failed to remove peer: %v", err) - return - } - - // Remove any exit nodes associated with this peer from 
hole punching - if holePunchManager != nil { - removed := holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) - if removed > 0 { - logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) - } - } - - // Remove successful - logger.Info("Successfully removed peer for site %d", removeData.SiteId) - }) - - // Handler for adding remote subnets to a peer - olm.RegisterHandler("olm/wg/peer/data/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var addSubnetsData peers.PeerAdd - if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { - logger.Error("Error unmarshaling add-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(addSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId) - return - } - - // Add new subnets - for _, subnet := range addSubnetsData.RemoteSubnets { - if err := peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases - for _, alias := range addSubnetsData.Aliases { - if err := peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for removing remote subnets from a peer - olm.RegisterHandler("olm/wg/peer/data/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeSubnetsData peers.RemovePeerData - if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { - 
logger.Error("Error unmarshaling remove-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(removeSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) - return - } - - // Remove subnets - for _, subnet := range removeSubnetsData.RemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Remove aliases - for _, alias := range removeSubnetsData.Aliases { - if err := peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for updating remote subnets of a peer (remove old, add new in one operation) - olm.RegisterHandler("olm/wg/peer/data/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateSubnetsData peers.UpdatePeerData - if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { - logger.Error("Error unmarshaling update-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(updateSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", updateSubnetsData.SiteId) - return - } - - // Add new subnets BEFORE removing old ones to preserve shared subnets - // This ensures that if an old and new subnet are the same on different peers, - // the route won't be temporarily removed - for _, subnet := range updateSubnetsData.NewRemoteSubnets { - if err := peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Remove old subnets after new ones 
are added - for _, subnet := range updateSubnetsData.OldRemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases BEFORE removing old ones to preserve shared IP addresses - // This ensures that if an old and new alias share the same IP, the IP won't be - // temporarily removed from the allowed IPs list - for _, alias := range updateSubnetsData.NewAliases { - if err := peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - - // Remove old aliases after new ones are added - for _, alias := range updateSubnetsData.OldAliases { - if err := peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - - logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) - }) - - olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { - logger.Debug("Received relay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.RelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) - - 
peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort) - }) - - olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) { - logger.Debug("Received unrelay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.UnRelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.Endpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false) - - peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) - }) + // Handlers for managing connection status + olmClient.RegisterHandler("olm/wg/connect", o.handleConnect) + olmClient.RegisterHandler("olm/terminate", o.handleTerminate) + + // Handlers for managing peers + olmClient.RegisterHandler("olm/wg/peer/add", o.handleWgPeerAdd) + olmClient.RegisterHandler("olm/wg/peer/remove", o.handleWgPeerRemove) + olmClient.RegisterHandler("olm/wg/peer/update", o.handleWgPeerUpdate) + olmClient.RegisterHandler("olm/wg/peer/relay", o.handleWgPeerRelay) + olmClient.RegisterHandler("olm/wg/peer/unrelay", o.handleWgPeerUnrelay) + + // Handlers for managing remote subnets to a peer + olmClient.RegisterHandler("olm/wg/peer/data/add", o.handleWgPeerAddData) + olmClient.RegisterHandler("olm/wg/peer/data/remove", o.handleWgPeerRemoveData) + olmClient.RegisterHandler("olm/wg/peer/data/update", o.handleWgPeerUpdateData) // Handler for peer handshake - adds exit node to holepunch rotation and 
notifies server - olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) { - logger.Debug("Received peer-handshake message: %v", msg.Data) + olmClient.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling handshake data: %v", err) - return - } - - var handshakeData struct { - SiteId int `json:"siteId"` - ExitNode struct { - PublicKey string `json:"publicKey"` - Endpoint string `json:"endpoint"` - RelayPort uint16 `json:"relayPort"` - } `json:"exitNode"` - } - - if err := json.Unmarshal(jsonData, &handshakeData); err != nil { - logger.Error("Error unmarshaling handshake data: %v", err) - return - } - - // Get existing peer from PeerManager - _, exists := peerManager.GetPeer(handshakeData.SiteId) - if exists { - logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) - return - } - - relayPort := handshakeData.ExitNode.RelayPort - if relayPort == 0 { - relayPort = 21820 // default relay port - } - - siteId := handshakeData.SiteId - exitNode := holepunch.ExitNode{ - Endpoint: handshakeData.ExitNode.Endpoint, - RelayPort: relayPort, - PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, - } - - added := holePunchManager.AddExitNode(exitNode) - if added { - logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) - } else { - logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) - } - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud - - // Send handshake acknowledgment back to server with retry - stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second) - - logger.Info("Initiated handshake for site %d with exit 
node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) - }) - - olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() - network.ClearNetworkSettings() - Close() - - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() - } - }) - - olm.RegisterHandler("pong", func(msg websocket.WSMessage) { - logger.Debug("Received pong message") - }) - - olm.OnConnect(func() error { + olmClient.OnConnect(func() error { logger.Info("Websocket Connected") - apiServer.SetConnectionStatus(true) + o.apiServer.SetConnectionStatus(true) - if connected { + if o.connected { logger.Debug("Already connected, skipping registration") return nil } - publicKey := privateKey.PublicKey() + publicKey := o.privateKey.PublicKey() // delay for 500ms to allow for time for the hp to get processed time.Sleep(500 * time.Millisecond) - if stopRegister == nil { + if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) - stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + o.stopRegister, o.updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]any{ "publicKey": publicKey.String(), "relay": !config.Holepunch, - "olmVersion": globalConfig.Version, - "olmAgent": globalConfig.Agent, + "olmVersion": o.olmConfig.Version, + "olmAgent": o.olmConfig.Agent, "orgId": config.OrgID, "userToken": userToken, }, 1*time.Second) // Invoke onRegistered callback if configured - if globalConfig.OnRegistered != nil { - go globalConfig.OnRegistered() + if o.olmConfig.OnRegistered != nil { + go o.olmConfig.OnRegistered() } } - go keepSendingPing(olm) + go o.keepSendingPing(olmClient) return nil }) - olm.OnTokenUpdate(func(token string, exitNodes 
[]websocket.ExitNode) { - holePunchManager.SetToken(token) + olmClient.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { + o.holePunchManager.SetToken(token) logger.Debug("Got exit nodes for hole punching: %v", exitNodes) @@ -939,141 +380,113 @@ func StartTunnel(config TunnelConfig) { // Start hole punching using the manager logger.Info("Starting hole punch for %d exit nodes", len(exitNodes)) - if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { + if err := o.holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { logger.Warn("Failed to start hole punch: %v", err) } }) - olm.OnAuthError(func(statusCode int, message string) { + olmClient.OnAuthError(func(statusCode int, message string) { logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message) - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() + o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() - Close() + o.Close() - if globalConfig.OnAuthError != nil { - go globalConfig.OnAuthError(statusCode, message) + if o.olmConfig.OnAuthError != nil { + go o.olmConfig.OnAuthError(statusCode, message) } - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() } }) // Connect to the WebSocket server - if err := olm.Connect(); err != nil { + if err := olmClient.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) return } - defer olm.Close() + defer func() { _ = olmClient.Close() }() + + o.olmClient = olmClient // Wait for context cancellation <-tunnelCtx.Done() logger.Info("Tunnel process context cancelled, cleaning up") } -func AddDevice(fd uint32) error { - if middleDev == nil { - return fmt.Errorf("middle device is not 
initialized") - } - - if tunnelConfig.MTU == 0 { - // error - return fmt.Errorf("tunnel MTU is not set") - } - - tdev, err := olmDevice.CreateTUNFromFD(fd, tunnelConfig.MTU) - - if err != nil { - return fmt.Errorf("failed to create TUN device from fd: %v", err) - } - - // if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? - interfaceName = realInterfaceName - } - - // Here we replace the existing TUN device in the middle device with the new one - middleDev.AddDevice(tdev) - - return nil -} - -func Close() { +func (o *Olm) Close() { // Restore original DNS configuration // we do this first to avoid any DNS issues if something else gets stuck if err := dnsOverride.RestoreDNSOverride(); err != nil { logger.Error("Failed to restore DNS: %v", err) } - // Stop hole punch manager - if holePunchManager != nil { - holePunchManager.Stop() - holePunchManager = nil + if o.holePunchManager != nil { + o.holePunchManager.Stop() + o.holePunchManager = nil } - if stopPing != nil { - select { - case <-stopPing: - // Channel already closed - default: - close(stopPing) - } + if o.stopPing != nil { + close(o.stopPing) + o.stopPing = nil } - if stopRegister != nil { - stopRegister() - stopRegister = nil + if o.stopRegister != nil { + o.stopRegister() + o.stopRegister = nil } - if updateRegister != nil { - updateRegister = nil + // Close() also calls Stop() internally + if o.peerManager != nil { + o.peerManager.Close() + o.peerManager = nil } - if peerManager != nil { - peerManager.Close() // Close() also calls Stop() internally - peerManager = nil + if o.uapiListener != nil { + _ = o.uapiListener.Close() + o.uapiListener = nil } - if uapiListener != nil { - uapiListener.Close() - uapiListener = nil + if o.logFile != nil { + _ = o.logFile.Close() + o.logFile = nil } // Stop DNS proxy first - it uses the middleDev for packet filtering - logger.Debug("Stopping DNS proxy") - if 
dnsProxy != nil { - dnsProxy.Stop() - dnsProxy = nil + if o.dnsProxy != nil { + logger.Debug("Stopping DNS proxy") + o.dnsProxy.Stop() + o.dnsProxy = nil } // Close MiddleDevice first - this closes the TUN and signals the closed channel // This unblocks the pump goroutine and allows WireGuard's TUN reader to exit - logger.Debug("Closing MiddleDevice") - if middleDev != nil { - middleDev.Close() - middleDev = nil + // Note: o.tdev is closed by o.middleDev.Close() since middleDev wraps it + if o.middleDev != nil { + logger.Debug("Closing MiddleDevice") + _ = o.middleDev.Close() + o.middleDev = nil } - // Note: tdev is closed by middleDev.Close() since middleDev wraps it - tdev = nil // Now close WireGuard device - its TUN reader should have exited by now - logger.Debug("Closing WireGuard device") - if dev != nil { - dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference - dev = nil + // This will call sharedBind.Close() which releases WireGuard's reference + if o.dev != nil { + logger.Debug("Closing WireGuard device") + o.dev.Close() + o.dev = nil } - // Release the hole punch reference to the shared bind - if sharedBind != nil { - // Release hole punch reference (WireGuard already released its reference via dev.Close()) - logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount()) - sharedBind.Release() - sharedBind = nil + // Release the hole punch reference to the shared bind (WireGuard already + // released its reference via dev.Close()) + if o.sharedBind != nil { + logger.Debug("Releasing shared bind (refcount before release: %d)", o.sharedBind.GetRefCount()) + _ = o.sharedBind.Release() logger.Info("Released shared UDP bind") + o.sharedBind = nil } logger.Info("Olm service stopped") @@ -1081,78 +494,85 @@ func Close() { // StopTunnel stops just the tunnel process and websocket connection // without shutting down the entire application -func StopTunnel() error { +func (o *Olm) StopTunnel() error { 
logger.Info("Stopping tunnel process") + if !o.tunnelRunning { + logger.Debug("Tunnel not running, nothing to stop") + return nil + } + // Cancel the tunnel context if it exists - if tunnelCancel != nil { - tunnelCancel() + if o.tunnelCancel != nil { + o.tunnelCancel() // Give it a moment to clean up time.Sleep(200 * time.Millisecond) } // Close the websocket connection - if olmClient != nil { - olmClient.Close() - olmClient = nil + if o.olmClient != nil { + _ = o.olmClient.Close() + o.olmClient = nil } - Close() + o.Close() // Reset the connected state - connected = false - tunnelRunning = false + o.connected = false + o.tunnelRunning = false // Update API server status - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) network.ClearNetworkSettings() - apiServer.ClearPeerStatuses() + o.apiServer.ClearPeerStatuses() logger.Info("Tunnel process stopped") return nil } -func StopApi() error { - if apiServer != nil { - err := apiServer.Stop() +func (o *Olm) StopApi() error { + if o.apiServer != nil { + err := o.apiServer.Stop() if err != nil { return fmt.Errorf("failed to stop API server: %w", err) } } + return nil } -func StartApi() error { - if apiServer != nil { - err := apiServer.Start() +func (o *Olm) StartApi() error { + if o.apiServer != nil { + err := o.apiServer.Start() if err != nil { return fmt.Errorf("failed to start API server: %w", err) } } + return nil } -func GetStatus() api.StatusResponse { - return apiServer.GetStatus() +func (o *Olm) GetStatus() api.StatusResponse { + return o.apiServer.GetStatus() } -func SwitchOrg(orgID string) error { +func (o *Olm) SwitchOrg(orgID string) error { logger.Info("Processing org switch request to orgId: %s", orgID) // stop the tunnel - if err := StopTunnel(); err != nil { + if err := o.StopTunnel(); err != nil { return fmt.Errorf("failed to stop existing tunnel: %w", err) } // Update the org ID in the API server and global 
config - apiServer.SetOrgID(orgID) + o.apiServer.SetOrgID(orgID) - tunnelConfig.OrgID = orgID + o.tunnelConfig.OrgID = orgID // Restart the tunnel with the same config but new org ID - go StartTunnel(tunnelConfig) + go o.StartTunnel(o.tunnelConfig) return nil } diff --git a/olm/peer.go b/olm/peer.go new file mode 100644 index 0000000..8acec42 --- /dev/null +++ b/olm/peer.go @@ -0,0 +1,195 @@ +package olm + +import ( + "encoding/json" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" +) + +func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { + logger.Debug("Received add-peer message: %v", msg.Data) + + if o.stopPeerSend != nil { + o.stopPeerSend() + o.stopPeerSend = nil + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var siteConfig peers.SiteConfig + if err := json.Unmarshal(jsonData, &siteConfig); err != nil { + logger.Error("Error unmarshaling add data: %v", err) + return + } + + _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it + + if err := o.peerManager.AddPeer(siteConfig); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Successfully added peer for site %d", siteConfig.SiteId) +} + +func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { + logger.Debug("Received remove-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeData peers.PeerRemove + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return + } + + if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil { + logger.Error("Failed to remove peer: %v", err) + return + } + + // Remove 
any exit nodes associated with this peer from hole punching + if o.holePunchManager != nil { + removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) + if removed > 0 { + logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) + } + } + + logger.Info("Successfully removed peer for site %d", removeData.SiteId) +} + +func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) { + logger.Debug("Received update-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateData peers.SiteConfig + if err := json.Unmarshal(jsonData, &updateData); err != nil { + logger.Error("Error unmarshaling update data: %v", err) + return + } + + // Get existing peer from PeerManager + existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId) + if !exists { + logger.Warn("Peer with site ID %d not found", updateData.SiteId) + return + } + + // Create updated site config by merging with existing data + siteConfig := existingPeer + + if updateData.Endpoint != "" { + siteConfig.Endpoint = updateData.Endpoint + } + if updateData.RelayEndpoint != "" { + siteConfig.RelayEndpoint = updateData.RelayEndpoint + } + if updateData.PublicKey != "" { + siteConfig.PublicKey = updateData.PublicKey + } + if updateData.ServerIP != "" { + siteConfig.ServerIP = updateData.ServerIP + } + if updateData.ServerPort != 0 { + siteConfig.ServerPort = updateData.ServerPort + } + if updateData.RemoteSubnets != nil { + siteConfig.RemoteSubnets = updateData.RemoteSubnets + } + + if err := o.peerManager.UpdatePeer(siteConfig); err != nil { + logger.Error("Failed to update peer: %v", err) + return + } + + // If the endpoint changed, trigger holepunch to refresh NAT mappings + if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { + logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT 
mappings", updateData.SiteId) + _ = o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetInterval() + } + + logger.Info("Successfully updated peer for site %d", updateData.SiteId) +} + +func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) { + logger.Debug("Received relay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if o.peerManager == nil { + logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData peers.RelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling relay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) + if err != nil { + logger.Error("Failed to resolve primary relay endpoint: %v", err) + return + } + + // Update HTTP server to mark this peer as using relay + o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) + + o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort) +} + +func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { + logger.Debug("Received unrelay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if o.peerManager == nil { + logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData peers.UnRelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling relay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.Endpoint) + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + // Update HTTP 
server to mark this peer as using relay + o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false) + + o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) +} diff --git a/olm/util.go b/olm/ping.go similarity index 89% rename from olm/util.go rename to olm/ping.go index 6bfd171..bbeee9a 100644 --- a/olm/util.go +++ b/olm/ping.go @@ -9,7 +9,7 @@ import ( ) func sendPing(olm *websocket.Client) error { - err := olm.SendMessage("olm/ping", map[string]interface{}{ + err := olm.SendMessage("olm/ping", map[string]any{ "timestamp": time.Now().Unix(), "userToken": olm.GetConfig().UserToken, }) @@ -21,7 +21,7 @@ func sendPing(olm *websocket.Client) error { return nil } -func keepSendingPing(olm *websocket.Client) { +func (o *Olm) keepSendingPing(olm *websocket.Client) { // Send ping immediately on startup if err := sendPing(olm); err != nil { logger.Error("Failed to send initial ping: %v", err) @@ -35,7 +35,7 @@ func keepSendingPing(olm *websocket.Client) { for { select { - case <-stopPing: + case <-o.stopPing: logger.Info("Stopping ping messages") return case <-ticker.C: diff --git a/olm/types.go b/olm/types.go index 9187860..77c0b5f 100644 --- a/olm/types.go +++ b/olm/types.go @@ -12,7 +12,7 @@ type WgData struct { UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses } -type GlobalConfig struct { +type OlmConfig struct { // Logging LogLevel string LogFilePath string From 1ecb97306f3acb0b7c9419bf15dc80dbc8c8323c Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 13 Jan 2026 21:38:37 -0800 Subject: [PATCH 21/52] Add back AddDevice function Former-commit-id: cae0ffa2e151d157f485ae6e52f6069a2f883fc0 --- olm/olm.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/olm/olm.go b/olm/olm.go index 6d8f7a5..2db3630 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -576,3 +576,29 @@ func (o *Olm) SwitchOrg(orgID string) error { return nil } + +func (o *Olm) AddDevice(fd uint32) error { 
+ if o.middleDev == nil { + return fmt.Errorf("middle device is not initialized") + } + + if o.tunnelConfig.MTU == 0 { + return fmt.Errorf("tunnel MTU is not set") + } + + tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU) + if err != nil { + return fmt.Errorf("failed to create TUN device from fd: %v", err) + } + + // Update interface name if available + if realInterfaceName, err2 := tdev.Name(); err2 == nil { + o.tunnelConfig.InterfaceName = realInterfaceName + } + + // Replace the existing TUN device in the middle device with the new one + o.middleDev.AddDevice(tdev) + + logger.Info("Added device from file descriptor %d", fd) + return nil +} From 2ab979058820c3b7da74d6f3c4c9b3edb3622b24 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 11:12:10 -0800 Subject: [PATCH 22/52] Reduce the pings Former-commit-id: 5c6ad1ea75f85558195791e3129e745a70b7fa54 --- olm/ping.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/olm/ping.go b/olm/ping.go index bbeee9a..fd7706a 100644 --- a/olm/ping.go +++ b/olm/ping.go @@ -9,6 +9,7 @@ import ( ) func sendPing(olm *websocket.Client) error { + logger.Debug("Sending ping message") err := olm.SendMessage("olm/ping", map[string]any{ "timestamp": time.Now().Unix(), "userToken": olm.GetConfig().UserToken, @@ -30,7 +31,7 @@ func (o *Olm) keepSendingPing(olm *websocket.Client) { } // Set up ticker for one minute intervals - ticker := time.NewTicker(1 * time.Minute) + ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { From c86df2c0416d593c2d60c26205d023e67e7a073f Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 11:58:12 -0800 Subject: [PATCH 23/52] Refactor operation Former-commit-id: 4f09d122bb53f3a32f73624edb2ca7d1c26b3175 --- olm/connect.go | 2 +- olm/data.go | 4 +- olm/olm.go | 115 ++++++++++++++------------------------------ olm/ping.go | 56 --------------------- websocket/client.go | 58 ++++++++++++++++------ 5 files changed, 82 insertions(+), 153 deletions(-) delete 
mode 100644 olm/ping.go diff --git a/olm/connect.go b/olm/connect.go index 568c731..a610ea4 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -154,7 +154,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { MiddleDev: o.middleDev, LocalIP: interfaceIP, SharedBind: o.sharedBind, - WSClient: o.olmClient, + WSClient: o.websocket, APIServer: o.apiServer, }) diff --git a/olm/data.go b/olm/data.go index 9c8d33f..93e64d0 100644 --- a/olm/data.go +++ b/olm/data.go @@ -189,9 +189,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud // Send handshake acknowledgment back to server with retry - o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ "siteId": handshakeData.SiteId, - }, 1*time.Second) + }, 1*time.Second, 10) logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) } diff --git a/olm/olm.go b/olm/olm.go index 15e3a6a..63b53a7 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -41,7 +41,7 @@ type Olm struct { dnsProxy *dns.DNSProxy apiServer *api.API - olmClient *websocket.Client + websocket *websocket.Client holePunchManager *holepunch.Manager peerManager *peers.PeerManager // Power mode management @@ -57,10 +57,11 @@ type Olm struct { tunnelConfig TunnelConfig stopRegister func() - stopPeerSend func() updateRegister func(newData any) - stopPing chan struct{} + stopServerPing func() + + stopPeerSend func() } // initTunnelInfo creates the shared UDP socket and holepunch manager. 
@@ -270,9 +271,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { tunnelCtx, cancel := context.WithCancel(o.olmCtx) o.tunnelCancel = cancel - // Recreate channels for this tunnel session - o.stopPing = make(chan struct{}) - var ( id = config.ID secret = config.Secret @@ -328,6 +326,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetConnectionStatus(true) + // restart the ping if we need to + if o.stopServerPing == nil { + o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{ + "timestamp": time.Now().Unix(), + "userToken": olmClient.GetConfig().UserToken, + }, 30*time.Second, -1) // -1 means dont time out with the max attempts + } + if o.connected { logger.Debug("Already connected, skipping registration") return nil @@ -347,7 +353,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { "olmAgent": o.olmConfig.Agent, "orgId": config.OrgID, "userToken": userToken, - }, 1*time.Second) + }, 1*time.Second, 10) // Invoke onRegistered callback if configured if o.olmConfig.OnRegistered != nil { @@ -355,8 +361,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } } - go o.keepSendingPing(olmClient) - return nil }) @@ -416,7 +420,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } defer func() { _ = olmClient.Close() }() - o.olmClient = olmClient + o.websocket = olmClient // Wait for context cancellation <-tunnelCtx.Done() @@ -435,9 +439,9 @@ func (o *Olm) Close() { o.holePunchManager = nil } - if o.stopPing != nil { - close(o.stopPing) - o.stopPing = nil + if o.stopServerPing != nil { + o.stopServerPing() + o.stopServerPing = nil } if o.stopRegister != nil { @@ -515,9 +519,9 @@ func (o *Olm) StopTunnel() error { } // Close the websocket connection - if o.olmClient != nil { - _ = o.olmClient.Close() - o.olmClient = nil + if o.websocket != nil { + _ = o.websocket.Close() + o.websocket = nil } o.Close() @@ -602,25 +606,13 @@ func (o *Olm) SetPowerMode(mode string) error { if mode == "low" { // Low Power Mode: Close websocket 
and reduce monitoring frequency - if o.olmClient != nil { + if o.websocket != nil { logger.Info("Closing websocket connection for low power mode") - if err := o.olmClient.Close(); err != nil { + if err := o.websocket.Close(); err != nil { logger.Error("Error closing websocket: %v", err) } } - if o.stopPing != nil { - select { - case <-o.stopPing: - default: - close(o.stopPing) - } - } - - if o.peerManager != nil { - o.peerManager.Stop() - } - if o.originalPeerInterval == 0 && o.peerManager != nil { peerMonitor := o.peerManager.GetPeerMonitor() if peerMonitor != nil { @@ -639,10 +631,6 @@ func (o *Olm) SetPowerMode(mode string) error { } } - if o.peerManager != nil { - o.peerManager.Start() - } - o.currentPowerMode = "low" logger.Info("Switched to low power mode") @@ -669,60 +657,19 @@ func (o *Olm) SetPowerMode(mode string) error { } } - if o.peerManager != nil { - o.peerManager.Start() - } + logger.Info("Reconnecting websocket for normal power mode") - if o.tunnelConfig.ID != "" && o.tunnelConfig.Secret != "" && o.tunnelConfig.Endpoint != "" { - logger.Info("Reconnecting websocket for normal power mode") - - if o.olmClient != nil { - o.olmClient.Close() - } - - o.stopPing = make(chan struct{}) - - var ( - id = o.tunnelConfig.ID - secret = o.tunnelConfig.Secret - userToken = o.tunnelConfig.UserToken - ) - - olm, err := websocket.NewClient( - id, - secret, - userToken, - o.tunnelConfig.OrgID, - o.tunnelConfig.Endpoint, - o.tunnelConfig.PingIntervalDuration, - o.tunnelConfig.PingTimeoutDuration, - ) - if err != nil { - logger.Error("Failed to create new websocket client: %v", err) - return fmt.Errorf("failed to create new websocket client: %w", err) - } - - o.olmClient = olm - - olm.OnConnect(func() error { - logger.Info("Websocket Reconnected") - o.apiServer.SetConnectionStatus(true) - go o.keepSendingPing(olm) - return nil - }) - - if err := olm.Connect(); err != nil { + if o.websocket != nil { + if err := o.websocket.Connect(); err != nil { logger.Error("Failed to 
reconnect websocket: %v", err) return fmt.Errorf("failed to reconnect websocket: %w", err) } - } else { - logger.Warn("Cannot reconnect websocket: tunnel config not available") } o.currentPowerMode = "normal" logger.Info("Switched to normal power mode") } - + return nil } @@ -749,6 +696,14 @@ func (o *Olm) AddDevice(fd uint32) error { o.middleDev.AddDevice(tdev) logger.Info("Added device from file descriptor %d", fd) - + return nil } + +func GetNetworkSettingsJSON() (string, error) { + return network.GetJSON() +} + +func GetNetworkSettingsIncrementor() int { + return network.GetIncrementor() +} diff --git a/olm/ping.go b/olm/ping.go deleted file mode 100644 index fd7706a..0000000 --- a/olm/ping.go +++ /dev/null @@ -1,56 +0,0 @@ -package olm - -import ( - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/network" - "github.com/fosrl/olm/websocket" -) - -func sendPing(olm *websocket.Client) error { - logger.Debug("Sending ping message") - err := olm.SendMessage("olm/ping", map[string]any{ - "timestamp": time.Now().Unix(), - "userToken": olm.GetConfig().UserToken, - }) - if err != nil { - logger.Error("Failed to send ping message: %v", err) - return err - } - logger.Debug("Sent ping message") - return nil -} - -func (o *Olm) keepSendingPing(olm *websocket.Client) { - // Send ping immediately on startup - if err := sendPing(olm); err != nil { - logger.Error("Failed to send initial ping: %v", err) - } else { - logger.Info("Sent initial ping message") - } - - // Set up ticker for one minute intervals - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case <-o.stopPing: - logger.Info("Stopping ping messages") - return - case <-ticker.C: - if err := sendPing(olm); err != nil { - logger.Error("Failed to send periodic ping: %v", err) - } - } - } -} - -func GetNetworkSettingsJSON() (string, error) { - return network.GetJSON() -} - -func GetNetworkSettingsIncrementor() int { - return network.GetIncrementor() -} diff --git 
a/websocket/client.go b/websocket/client.go index 1c5afaf..34eea35 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -77,6 +77,7 @@ type Client struct { handlersMux sync.RWMutex reconnectInterval time.Duration isConnected bool + isDisconnected bool // Flag to track if client is intentionally disconnected reconnectMux sync.RWMutex pingInterval time.Duration pingTimeout time.Duration @@ -173,6 +174,9 @@ func (c *Client) GetConfig() *Config { // Connect establishes the WebSocket connection func (c *Client) Connect() error { + if c.isDisconnected { + c.isDisconnected = false + } go c.connectWithRetry() return nil } @@ -205,9 +209,25 @@ func (c *Client) Close() error { return nil } +// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later. +func (c *Client) Disconnect() error { + c.isDisconnected = true + c.setConnected(false) + + if c.conn != nil { + c.writeMux.Lock() + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + c.writeMux.Unlock() + err := c.conn.Close() + c.conn = nil + return err + } + return nil +} + // SendMessage sends a message through the WebSocket connection func (c *Client) SendMessage(messageType string, data interface{}) error { - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return fmt.Errorf("not connected") } @@ -223,7 +243,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { return c.conn.WriteJSON(msg) } -func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) { +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) { stopChan := make(chan struct{}) updateChan := make(chan interface{}) var dataMux sync.Mutex @@ -231,30 +251,32 @@ func (c *Client) SendMessageInterval(messageType 
string, data interface{}, inter go func() { count := 0 - maxAttempts := 10 - err := c.SendMessage(messageType, currentData) // Send immediately - if err != nil { - logger.Error("Failed to send initial message: %v", err) + send := func() { + if c.isDisconnected || c.conn == nil { + return + } + err := c.SendMessage(messageType, currentData) + if err != nil { + logger.Error("Failed to send message: %v", err) + } + count++ } - count++ + + send() // Send immediately ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: - if count >= maxAttempts { + if maxAttempts != -1 && count >= maxAttempts { logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } dataMux.Lock() - err = c.SendMessage(messageType, currentData) + send() dataMux.Unlock() - if err != nil { - logger.Error("Failed to send message: %v", err) - } - count++ case newData := <-updateChan: dataMux.Lock() // Merge newData into currentData if both are maps @@ -277,6 +299,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter case <-stopChan: return } + // Suspend sending if disconnected + for c.isDisconnected { + select { + case <-stopChan: + return + case <-time.After(500 * time.Millisecond): + } + } } }() return func() { @@ -587,7 +617,7 @@ func (c *Client) pingMonitor() { case <-c.done: return case <-ticker.C: - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return } c.writeMux.Lock() From 3470da76fccb0f8748a8a791be8d03001bc557e0 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 12:32:29 -0800 Subject: [PATCH 24/52] Update resetting intervals Former-commit-id: 303c2dc0b78336f6d9aafad87ff3854ed8461ab7 --- olm/olm.go | 111 ++++++++++++++++++++++++--------------- olm/types.go | 2 + peers/manager.go | 33 ++++++++++-- peers/monitor/monitor.go | 63 +++++++++++++++------- peers/peer.go | 22 +++++++- 5 files changed, 165 insertions(+), 66 deletions(-) diff --git 
a/olm/olm.go b/olm/olm.go index 63b53a7..6a0a26f 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -7,6 +7,7 @@ import ( "net/http" _ "net/http/pprof" "os" + "sync" "time" "github.com/fosrl/newt/bind" @@ -46,9 +47,9 @@ type Olm struct { peerManager *peers.PeerManager // Power mode management currentPowerMode string - originalPeerInterval time.Duration - originalHolepunchMinInterval time.Duration - originalHolepunchMaxInterval time.Duration + powerModeMu sync.Mutex + wakeUpTimer *time.Timer + wakeUpDebounce time.Duration olmCtx context.Context tunnelCancel context.CancelFunc @@ -133,6 +134,10 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.SetOutput(file) logFile = file } + + if config.WakeUpDebounce == 0 { + config.WakeUpDebounce = 3 * time.Second + } logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() @@ -589,22 +594,38 @@ func (o *Olm) SwitchOrg(orgID string) error { // SetPowerMode switches between normal and low power modes // In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes // In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored +// Wake-up has a 3-second debounce to prevent rapid flip-flopping; sleep is immediate func (o *Olm) SetPowerMode(mode string) error { // Validate mode if mode != "normal" && mode != "low" { return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode) } + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + // If already in the requested mode, return early if o.currentPowerMode == mode { + // Cancel any pending wake-up timer if we're already in normal mode + if mode == "normal" && o.wakeUpTimer != nil { + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } logger.Debug("Already in %s power mode", mode) return nil } - logger.Info("Switching to %s power mode", mode) - if mode == "low" { - // Low Power Mode: Close websocket and reduce 
monitoring frequency + // Low Power Mode: Cancel any pending wake-up and immediately go to sleep + + // Cancel pending wake-up timer if any + if o.wakeUpTimer != nil { + logger.Debug("Cancelling pending wake-up timer") + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } + + logger.Info("Switching to low power mode") if o.websocket != nil { logger.Info("Closing websocket connection for low power mode") @@ -613,14 +634,6 @@ func (o *Olm) SetPowerMode(mode string) error { } } - if o.originalPeerInterval == 0 && o.peerManager != nil { - peerMonitor := o.peerManager.GetPeerMonitor() - if peerMonitor != nil { - o.originalPeerInterval = 2 * time.Second - o.originalHolepunchMinInterval, o.originalHolepunchMaxInterval = peerMonitor.GetHolepunchIntervals() - } - } - if o.peerManager != nil { peerMonitor := o.peerManager.GetPeerMonitor() if peerMonitor != nil { @@ -629,45 +642,61 @@ func (o *Olm) SetPowerMode(mode string) error { peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval) logger.Info("Set monitoring intervals to 10 minutes for low power mode") } + o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable } o.currentPowerMode = "low" logger.Info("Switched to low power mode") } else { - // Normal Power Mode: Restore intervals and reconnect websocket + // Normal Power Mode: Start debounce timer before actually waking up - if o.peerManager != nil { - peerMonitor := o.peerManager.GetPeerMonitor() - if peerMonitor != nil { - if o.originalPeerInterval == 0 { - o.originalPeerInterval = 2 * time.Second - } - peerMonitor.SetInterval(o.originalPeerInterval) - - if o.originalHolepunchMinInterval == 0 { - o.originalHolepunchMinInterval = 2 * time.Second - } - if o.originalHolepunchMaxInterval == 0 { - o.originalHolepunchMaxInterval = 30 * time.Second - } - peerMonitor.SetHolepunchInterval(o.originalHolepunchMinInterval, o.originalHolepunchMaxInterval) - logger.Info("Restored monitoring intervals to normal (peer: %v, holepunch: %v-%v)", - o.originalPeerInterval, 
o.originalHolepunchMinInterval, o.originalHolepunchMaxInterval) - } + // If there's already a pending wake-up timer, don't start another + if o.wakeUpTimer != nil { + logger.Debug("Wake-up already pending, ignoring duplicate request") + return nil } - logger.Info("Reconnecting websocket for normal power mode") + logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce) - if o.websocket != nil { - if err := o.websocket.Connect(); err != nil { - logger.Error("Failed to reconnect websocket: %v", err) - return fmt.Errorf("failed to reconnect websocket: %w", err) + o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() { + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + + // Clear the timer reference + o.wakeUpTimer = nil + + // Double-check we're still in low power mode (could have changed) + if o.currentPowerMode == "normal" { + logger.Debug("Already in normal mode after debounce, skipping wake-up") + return } - } - o.currentPowerMode = "normal" - logger.Info("Switched to normal power mode") + logger.Info("Debounce complete, switching to normal power mode") + + // Restore intervals and reconnect websocket + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.ResetHolepunchInterval() + peerMonitor.ResetInterval() + } + + o.peerManager.UpdateAllPeersPersistentKeepalive(5) + } + + logger.Info("Reconnecting websocket for normal power mode") + + if o.websocket != nil { + if err := o.websocket.Connect(); err != nil { + logger.Error("Failed to reconnect websocket: %v", err) + return + } + } + + o.currentPowerMode = "normal" + logger.Info("Switched to normal power mode") + }) } return nil diff --git a/olm/types.go b/olm/types.go index 77c0b5f..397eab9 100644 --- a/olm/types.go +++ b/olm/types.go @@ -23,6 +23,8 @@ type OlmConfig struct { SocketPath string Version string Agent string + + WakeUpDebounce time.Duration // Debugging PprofAddr string // Address to serve pprof on (e.g., 
"localhost:6060") diff --git a/peers/manager.go b/peers/manager.go index 56f3707..0566775 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -50,6 +50,8 @@ type PeerManager struct { // key is the CIDR string, value is a set of siteIds that want this IP allowedIPClaims map[string]map[int]bool APIServer *api.API + + PersistentKeepalive int } // NewPeerManager creates a new PeerManager with an internal PeerMonitor @@ -127,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -166,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { return nil } +// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once +// without recreating them. Returns a map of siteId to error for any peers that failed to update. 
+func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error { + pm.mu.RLock() + defer pm.mu.RUnlock() + + pm.PersistentKeepalive = interval + + errors := make(map[int]error) + + for siteId, peer := range pm.peers { + err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval) + if err != nil { + errors[siteId] = err + } + } + + if len(errors) == 0 { + return nil + } + return errors +} + func (pm *PeerManager) RemovePeer(siteId int) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -245,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error { ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId) wgConfig := promotedPeer wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } @@ -321,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -331,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId) promotedWgConfig := promotedPeer promotedWgConfig.AllowedIps = promotedOwnedIPs - if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to 
update promoted peer %d: %v", promotedPeerId, err) } } diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 2bb0c80..3ac4b54 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -28,13 +28,14 @@ import ( // PeerMonitor handles monitoring the connection status to multiple WireGuard peers type PeerMonitor struct { - monitors map[int]*Client - mutex sync.Mutex - running bool - interval time.Duration - timeout time.Duration - maxAttempts int - wsClient *websocket.Client + monitors map[int]*Client + mutex sync.Mutex + running bool + defaultInterval time.Duration + interval time.Duration + timeout time.Duration + maxAttempts int + wsClient *websocket.Client // Netstack fields middleDev *middleDevice.MiddleDevice @@ -50,7 +51,6 @@ type PeerMonitor struct { // Holepunch testing fields sharedBind *bind.SharedBind holepunchTester *holepunch.HolepunchTester - holepunchInterval time.Duration holepunchTimeout time.Duration holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchStatus map[int]bool // siteID -> connected status @@ -62,11 +62,13 @@ type PeerMonitor struct { holepunchFailures map[int]int // siteID -> consecutive failure count // Exponential backoff fields for holepunch monitor - holepunchMinInterval time.Duration // Minimum interval (initial) - holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) - holepunchBackoffMultiplier float64 // Multiplier for each stable check - holepunchStableCount map[int]int // siteID -> consecutive stable status count - holepunchCurrentInterval time.Duration // Current interval with backoff applied + defaultHolepunchMinInterval time.Duration // Minimum interval (initial) + defaultHolepunchMaxInterval time.Duration + holepunchMinInterval time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount 
map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts @@ -85,6 +87,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), + defaultInterval: 2 * time.Second, interval: 2 * time.Second, // Default check interval (faster) timeout: 3 * time.Second, maxAttempts: 3, @@ -95,7 +98,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe nsCtx: ctx, nsCancel: cancel, sharedBind: sharedBind, - holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds holepunchTimeout: 2 * time.Second, // Faster timeout holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), @@ -109,11 +111,13 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe apiServer: apiServer, wgConnectionStatus: make(map[int]bool), // Exponential backoff settings for holepunch monitor - holepunchMinInterval: 2 * time.Second, - holepunchMaxInterval: 30 * time.Second, - holepunchBackoffMultiplier: 1.5, - holepunchStableCount: make(map[int]int), - holepunchCurrentInterval: 2 * time.Second, + defaultHolepunchMinInterval: 2 * time.Second, + defaultHolepunchMaxInterval: 30 * time.Second, + holepunchMinInterval: 2 * time.Second, + holepunchMaxInterval: 30 * time.Second, + holepunchBackoffMultiplier: 1.5, + holepunchStableCount: make(map[int]int), + holepunchCurrentInterval: 2 * time.Second, } if err := pm.initNetstack(); err != nil { @@ -141,6 +145,18 @@ func (pm *PeerMonitor) SetInterval(interval time.Duration) { } } +func (pm *PeerMonitor) ResetInterval() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.interval = pm.defaultInterval + + // Update interval for all existing monitors + for _, client := 
range pm.monitors { + client.SetPacketInterval(pm.defaultInterval) + } +} + // SetTimeout changes the timeout for waiting for responses func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { pm.mutex.Lock() @@ -186,6 +202,15 @@ func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Du return pm.holepunchMinInterval, pm.holepunchMaxInterval } +func (pm *PeerMonitor) ResetHolepunchInterval() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.holepunchMinInterval = pm.defaultHolepunchMinInterval + pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval + pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval +} + // AddPeer adds a new peer to monitor func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error { pm.mutex.Lock() diff --git a/peers/peer.go b/peers/peer.go index 9370b9d..8211fa4 100644 --- a/peers/peer.go +++ b/peers/peer.go @@ -11,7 +11,7 @@ import ( ) // ConfigurePeer sets up or updates a peer within the WireGuard device -func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error { +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error { var endpoint string if relay && siteConfig.RelayEndpoint != "" { endpoint = formatEndpoint(siteConfig.RelayEndpoint) @@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes } configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) - configBuilder.WriteString("persistent_keepalive_interval=5\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive)) config := configBuilder.String() logger.Debug("Configuring peer with config: %s", config) @@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [ return nil } +// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without 
recreating it +func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error { + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("update_only=true\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval)) + + config := configBuilder.String() + logger.Debug("Updating persistent keepalive for peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err) + } + + return nil +} + func formatEndpoint(endpoint string) string { if strings.Contains(endpoint, ":") { return endpoint From 3ba171452488a9a944687617a509009edce84a83 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 16:38:40 -0800 Subject: [PATCH 25/52] Power state getting set correctly Former-commit-id: 0895156efd764c365b9196e55dcb1199b3ec9b1c --- go.mod | 2 + go.sum | 2 - olm/data.go | 2 +- olm/olm.go | 56 ++++++----- olm/peer.go | 2 +- peers/monitor/monitor.go | 198 +++++++++++++++++++------------------- peers/monitor/wgtester.go | 109 +++++++++++++-------- websocket/client.go | 61 +++++++----- 8 files changed, 239 insertions(+), 193 deletions(-) diff --git a/go.mod b/go.mod index 4f42df6..0d6bbcb 100644 --- a/go.mod +++ b/go.mod @@ -30,3 +30,5 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect ) + +replace github.com/fosrl/newt => ../newt diff --git a/go.sum b/go.sum index a543b5a..f6ca61a 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/fosrl/newt v1.8.0 h1:wIRCO2shhCpkFzsbNbb4g2LC7mPzIpp2ialNveBMJy4= -github.com/fosrl/newt v1.8.0/go.mod 
h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= diff --git a/olm/data.go b/olm/data.go index 93e64d0..fe0b36a 100644 --- a/olm/data.go +++ b/olm/data.go @@ -186,7 +186,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { } o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud // Send handshake acknowledgment back to server with retry o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ diff --git a/olm/olm.go b/olm/olm.go index 6a0a26f..3f197ae 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -46,10 +46,10 @@ type Olm struct { holePunchManager *holepunch.Manager peerManager *peers.PeerManager // Power mode management - currentPowerMode string - powerModeMu sync.Mutex - wakeUpTimer *time.Timer - wakeUpDebounce time.Duration + currentPowerMode string + powerModeMu sync.Mutex + wakeUpTimer *time.Timer + wakeUpDebounce time.Duration olmCtx context.Context tunnelCancel context.CancelFunc @@ -134,7 +134,7 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.SetOutput(file) logFile = file } - + if config.WakeUpDebounce == 0 { config.WakeUpDebounce = 3 * time.Second } @@ -628,23 +628,28 @@ func (o *Olm) SetPowerMode(mode string) error { logger.Info("Switching to low power mode") if o.websocket != nil { - logger.Info("Closing websocket connection for low power mode") - if err := o.websocket.Close(); err != nil { - logger.Error("Error closing websocket: %v", err) + 
logger.Info("Disconnecting websocket for low power mode") + if err := o.websocket.Disconnect(); err != nil { + logger.Error("Error disconnecting websocket: %v", err) } } + lowPowerInterval := 10 * time.Minute + if o.peerManager != nil { peerMonitor := o.peerManager.GetPeerMonitor() if peerMonitor != nil { - lowPowerInterval := 10 * time.Minute - peerMonitor.SetInterval(lowPowerInterval) - peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval) + peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval) + peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval) logger.Info("Set monitoring intervals to 10 minutes for low power mode") } o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable } + if o.holePunchManager != nil { + o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval) + } + o.currentPowerMode = "low" logger.Info("Switched to low power mode") @@ -673,20 +678,8 @@ func (o *Olm) SetPowerMode(mode string) error { } logger.Info("Debounce complete, switching to normal power mode") - - // Restore intervals and reconnect websocket - if o.peerManager != nil { - peerMonitor := o.peerManager.GetPeerMonitor() - if peerMonitor != nil { - peerMonitor.ResetHolepunchInterval() - peerMonitor.ResetInterval() - } - - o.peerManager.UpdateAllPeersPersistentKeepalive(5) - } - + logger.Info("Reconnecting websocket for normal power mode") - if o.websocket != nil { if err := o.websocket.Connect(); err != nil { logger.Error("Failed to reconnect websocket: %v", err) @@ -694,6 +687,21 @@ func (o *Olm) SetPowerMode(mode string) error { } } + // Restore intervals and reconnect websocket + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.ResetPeerHolepunchInterval() + peerMonitor.ResetPeerInterval() + } + + o.peerManager.UpdateAllPeersPersistentKeepalive(5) + } + + if o.holePunchManager != nil { + o.holePunchManager.ResetServerHolepunchInterval() + } + 
o.currentPowerMode = "normal" logger.Info("Switched to normal power mode") }) diff --git a/olm/peer.go b/olm/peer.go index 8acec42..9bc842e 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -123,7 +123,7 @@ func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) { if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) _ = o.holePunchManager.TriggerHolePunch() - o.holePunchManager.ResetInterval() + o.holePunchManager.ResetServerHolepunchInterval() } logger.Info("Successfully updated peer for site %d", updateData.SiteId) diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 3ac4b54..387b82f 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -28,14 +28,12 @@ import ( // PeerMonitor handles monitoring the connection status to multiple WireGuard peers type PeerMonitor struct { - monitors map[int]*Client - mutex sync.Mutex - running bool - defaultInterval time.Duration - interval time.Duration + monitors map[int]*Client + mutex sync.Mutex + running bool timeout time.Duration - maxAttempts int - wsClient *websocket.Client + maxAttempts int + wsClient *websocket.Client // Netstack fields middleDev *middleDevice.MiddleDevice @@ -54,7 +52,8 @@ type PeerMonitor struct { holepunchTimeout time.Duration holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchStatus map[int]bool // siteID -> connected status - holepunchStopChan chan struct{} + holepunchStopChan chan struct{} + holepunchUpdateChan chan struct{} // Relay tracking fields relayedPeers map[int]bool // siteID -> whether the peer is currently relayed @@ -87,8 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), - defaultInterval: 2 * time.Second, - interval: 2 * time.Second, // 
Default check interval (faster) timeout: 3 * time.Second, maxAttempts: 3, wsClient: wsClient, @@ -118,6 +115,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe holepunchBackoffMultiplier: 1.5, holepunchStableCount: make(map[int]int), holepunchCurrentInterval: 2 * time.Second, + holepunchUpdateChan: make(chan struct{}, 1), } if err := pm.initNetstack(); err != nil { @@ -133,82 +131,76 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe } // SetInterval changes how frequently peers are checked -func (pm *PeerMonitor) SetInterval(interval time.Duration) { +func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.interval = interval - // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetPacketInterval(interval) + client.SetPacketInterval(minInterval, maxInterval) } + + logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval) } -func (pm *PeerMonitor) ResetInterval() { +func (pm *PeerMonitor) ResetPeerInterval() { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.interval = pm.defaultInterval - // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetPacketInterval(pm.defaultInterval) + client.ResetPacketInterval() } } -// SetTimeout changes the timeout for waiting for responses -func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { +// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) { pm.mutex.Lock() - defer pm.mutex.Unlock() - - pm.timeout = timeout - - // Update timeout for all existing monitors - for _, client := range pm.monitors { - client.SetTimeout(timeout) - } -} - -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (pm *PeerMonitor) SetMaxAttempts(attempts 
int) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - pm.maxAttempts = attempts - - // Update max attempts for all existing monitors - for _, client := range pm.monitors { - client.SetMaxAttempts(attempts) - } -} - -// SetHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring -func (pm *PeerMonitor) SetHolepunchInterval(minInterval, maxInterval time.Duration) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - pm.holepunchMinInterval = minInterval pm.holepunchMaxInterval = maxInterval // Reset current interval to the new minimum pm.holepunchCurrentInterval = minInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval) + + // Signal the goroutine to apply the new interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } } -// GetHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring -func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Duration) { +// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() return pm.holepunchMinInterval, pm.holepunchMaxInterval } -func (pm *PeerMonitor) ResetHolepunchInterval() { +func (pm *PeerMonitor) ResetPeerHolepunchInterval() { pm.mutex.Lock() - defer pm.mutex.Unlock() - pm.holepunchMinInterval = pm.defaultHolepunchMinInterval pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval) + + // Signal the goroutine to apply the new 
interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } } // AddPeer adds a new peer to monitor @@ -226,11 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st return err } - client.SetPacketInterval(pm.interval) - client.SetTimeout(pm.timeout) - client.SetMaxAttempts(pm.maxAttempts) - client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable - pm.monitors[siteID] = client pm.holepunchEndpoints[siteID] = holepunchEndpoint @@ -541,6 +528,15 @@ func (pm *PeerMonitor) runHolepunchMonitor() { select { case <-pm.holepunchStopChan: return + case <-pm.holepunchUpdateChan: + // Interval settings changed, reset to minimum + pm.mutex.Lock() + pm.holepunchCurrentInterval = pm.holepunchMinInterval + currentInterval := pm.holepunchCurrentInterval + pm.mutex.Unlock() + + timer.Reset(currentInterval) + logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) case <-timer.C: anyStatusChanged := pm.checkHolepunchEndpoints() @@ -584,7 +580,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool { anyStatusChanged := false for siteID, endpoint := range endpoints { - // logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) + logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint) result := pm.holepunchTester.TestEndpoint(endpoint, timeout) pm.mutex.Lock() @@ -733,55 +729,55 @@ func (pm *PeerMonitor) Close() { logger.Debug("PeerMonitor: Cleanup complete") } -// TestPeer tests connectivity to a specific peer -func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { - pm.mutex.Lock() - client, exists := pm.monitors[siteID] - pm.mutex.Unlock() +// // TestPeer tests connectivity to a specific peer +// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { +// pm.mutex.Lock() +// client, exists := pm.monitors[siteID] +// 
pm.mutex.Unlock() - if !exists { - return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) - } +// if !exists { +// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) +// } - ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - defer cancel() +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// defer cancel() - connected, rtt := client.TestConnection(ctx) - return connected, rtt, nil -} +// connected, rtt := client.TestPeerConnection(ctx) +// return connected, rtt, nil +// } -// TestAllPeers tests connectivity to all peers -func (pm *PeerMonitor) TestAllPeers() map[int]struct { - Connected bool - RTT time.Duration -} { - pm.mutex.Lock() - peers := make(map[int]*Client, len(pm.monitors)) - for siteID, client := range pm.monitors { - peers[siteID] = client - } - pm.mutex.Unlock() +// // TestAllPeers tests connectivity to all peers +// func (pm *PeerMonitor) TestAllPeers() map[int]struct { +// Connected bool +// RTT time.Duration +// } { +// pm.mutex.Lock() +// peers := make(map[int]*Client, len(pm.monitors)) +// for siteID, client := range pm.monitors { +// peers[siteID] = client +// } +// pm.mutex.Unlock() - results := make(map[int]struct { - Connected bool - RTT time.Duration - }) - for siteID, client := range peers { - ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - connected, rtt := client.TestConnection(ctx) - cancel() +// results := make(map[int]struct { +// Connected bool +// RTT time.Duration +// }) +// for siteID, client := range peers { +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// connected, rtt := client.TestPeerConnection(ctx) +// cancel() - results[siteID] = struct { - Connected bool - RTT time.Duration - }{ - Connected: connected, - RTT: rtt, - } - } +// results[siteID] = struct { +// Connected bool +// RTT 
time.Duration +// }{ +// Connected: connected, +// RTT: rtt, +// } +// } - return results -} +// return results +// } // initNetstack initializes the gvisor netstack func (pm *PeerMonitor) initNetstack() error { diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go index 21f788a..f06759a 100644 --- a/peers/monitor/wgtester.go +++ b/peers/monitor/wgtester.go @@ -32,16 +32,19 @@ type Client struct { monitorLock sync.Mutex connLock sync.Mutex // Protects connection operations shutdownCh chan struct{} + updateCh chan struct{} packetInterval time.Duration timeout time.Duration maxAttempts int dialer Dialer // Exponential backoff fields - minInterval time.Duration // Minimum interval (initial) - maxInterval time.Duration // Maximum interval (cap for backoff) - backoffMultiplier float64 // Multiplier for each stable check - stableCountToBackoff int // Number of stable checks before backing off + defaultMinInterval time.Duration // Default minimum interval (initial) + defaultMaxInterval time.Duration // Default maximum interval (cap for backoff) + minInterval time.Duration // Minimum interval (initial) + maxInterval time.Duration // Maximum interval (cap for backoff) + backoffMultiplier float64 // Multiplier for each stable check + stableCountToBackoff int // Number of stable checks before backing off } // Dialer is a function that creates a connection @@ -56,43 +59,59 @@ type ConnectionStatus struct { // NewClient creates a new connection test client func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ - serverAddr: serverAddr, - shutdownCh: make(chan struct{}), - packetInterval: 2 * time.Second, - minInterval: 2 * time.Second, - maxInterval: 30 * time.Second, - backoffMultiplier: 1.5, - stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off - timeout: 500 * time.Millisecond, // Timeout for individual packets - maxAttempts: 3, // Default max attempts - dialer: dialer, + serverAddr: serverAddr, + 
shutdownCh: make(chan struct{}), + updateCh: make(chan struct{}, 1), + packetInterval: 2 * time.Second, + defaultMinInterval: 2 * time.Second, + defaultMaxInterval: 30 * time.Second, + minInterval: 2 * time.Second, + maxInterval: 30 * time.Second, + backoffMultiplier: 1.5, + stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off + timeout: 500 * time.Millisecond, // Timeout for individual packets + maxAttempts: 3, // Default max attempts + dialer: dialer, }, nil } // SetPacketInterval changes how frequently packets are sent in monitor mode -func (c *Client) SetPacketInterval(interval time.Duration) { - c.packetInterval = interval - c.minInterval = interval +func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) { + c.monitorLock.Lock() + c.packetInterval = minInterval + c.minInterval = minInterval + c.maxInterval = maxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() + + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } -// SetTimeout changes the timeout for waiting for responses -func (c *Client) SetTimeout(timeout time.Duration) { - c.timeout = timeout -} +func (c *Client) ResetPacketInterval() { + c.monitorLock.Lock() + c.packetInterval = c.defaultMinInterval + c.minInterval = c.defaultMinInterval + c.maxInterval = c.defaultMaxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (c *Client) SetMaxAttempts(attempts int) { - c.maxAttempts = attempts -} - -// SetMaxInterval sets the maximum backoff interval -func (c *Client) SetMaxInterval(interval time.Duration) { - c.maxInterval = interval -} - -// SetBackoffMultiplier sets the multiplier for exponential backoff -func (c *Client) 
SetBackoffMultiplier(multiplier float64) { - c.backoffMultiplier = multiplier + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } // UpdateServerAddr updates the server address and resets the connection @@ -146,9 +165,10 @@ func (c *Client) ensureConnection() error { return nil } -// TestConnection checks if the connection to the server is working +// TestPeerConnection checks if the connection to the server is working // Returns true if connected, false otherwise -func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { +func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) { + logger.Debug("wgtester: testing connection to peer %s", c.serverAddr) if err := c.ensureConnection(); err != nil { logger.Warn("Failed to ensure connection: %v", err) return false, 0 @@ -232,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - return c.TestConnection(ctx) + return c.TestPeerConnection(ctx) } // MonitorCallback is the function type for connection status change callbacks @@ -269,9 +289,20 @@ func (c *Client) StartMonitor(callback MonitorCallback) error { select { case <-c.shutdownCh: return + case <-c.updateCh: + // Interval settings changed, reset to minimum + c.monitorLock.Lock() + currentInterval = c.minInterval + c.monitorLock.Unlock() + + // Reset backoff state + stableCount = 0 + + timer.Reset(currentInterval) + logger.Debug("Packet interval updated, reset to %v", currentInterval) case <-timer.C: ctx, cancel := context.WithTimeout(context.Background(), c.timeout) - connected, rtt := c.TestConnection(ctx) + connected, rtt := c.TestPeerConnection(ctx) cancel() statusChanged := 
connected != lastConnected @@ -321,4 +352,4 @@ func (c *Client) StopMonitor() { close(c.shutdownCh) c.monitorRunning = false -} \ No newline at end of file +} diff --git a/websocket/client.go b/websocket/client.go index 34eea35..f040aa4 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -236,7 +236,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { Data: data, } - logger.Debug("Sending message: %s, data: %+v", messageType, data) + logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data) c.writeMux.Lock() defer c.writeMux.Unlock() @@ -258,7 +258,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter } err := c.SendMessage(messageType, currentData) if err != nil { - logger.Error("Failed to send message: %v", err) + logger.Error("websocket: Failed to send message: %v", err) } count++ } @@ -271,7 +271,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter select { case <-ticker.C: if maxAttempts != -1 && count >= maxAttempts { - logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) + logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } dataMux.Lock() @@ -353,7 +353,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { tlsConfig = &tls.Config{} } tlsConfig.InsecureSkipVerify = true - logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } tokenData := map[string]interface{}{ @@ -382,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { req.Header.Set("X-CSRF-Token", "x-csrf-protection") // print out the request for debugging - logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) + logger.Debug("websocket: 
Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) // Make the request client := &http.Client{} @@ -399,7 +399,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) // Return AuthError for 401/403 status codes if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { @@ -415,7 +415,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - logger.Error("Failed to decode token response.") + logger.Error("websocket: Failed to decode token response.") return "", nil, fmt.Errorf("failed to decode token response: %w", err) } @@ -427,7 +427,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { return "", nil, fmt.Errorf("received empty token from server") } - logger.Debug("Received token: %s", tokenResp.Data.Token) + logger.Debug("websocket: Received token: %s", tokenResp.Data.Token) return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil } @@ -442,7 +442,7 @@ func (c *Client) connectWithRetry() { if err != nil { // Check if this is an auth error (401/403) if authErr, ok := err.(*AuthError); ok { - logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr) + logger.Error("websocket: Authentication failed: %v. Terminating tunnel and retrying...", authErr) // Trigger auth error callback if set (this should terminate the tunnel) if c.onAuthError != nil { c.onAuthError(authErr.StatusCode, authErr.Message) @@ -452,7 +452,7 @@ func (c *Client) connectWithRetry() { continue } // For other errors (5xx, network issues), continue retrying - logger.Error("Failed to connect: %v. 
Retrying in %v...", err, c.reconnectInterval) + logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) time.Sleep(c.reconnectInterval) continue } @@ -505,7 +505,7 @@ func (c *Client) establishConnection() error { // Use new TLS configuration method if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { - logger.Info("Setting up TLS configuration for WebSocket connection") + logger.Info("websocket: Setting up TLS configuration for WebSocket connection") tlsConfig, err := c.setupTLS() if err != nil { return fmt.Errorf("failed to setup TLS configuration: %w", err) @@ -519,7 +519,7 @@ func (c *Client) establishConnection() error { dialer.TLSClientConfig = &tls.Config{} } dialer.TLSClientConfig.InsecureSkipVerify = true - logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } conn, _, err := dialer.Dial(u.String(), nil) @@ -537,7 +537,7 @@ func (c *Client) establishConnection() error { if c.onConnect != nil { if err := c.onConnect(); err != nil { - logger.Error("OnConnect callback failed: %v", err) + logger.Error("websocket: OnConnect callback failed: %v", err) } } @@ -550,9 +550,9 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Handle new separate certificate configuration if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" { - logger.Info("Loading separate certificate files for mTLS") - logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile) - logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile) + logger.Info("websocket: Loading separate certificate files for mTLS") + logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile) + logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile) // Load client certificate and key 
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile) @@ -563,7 +563,7 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Load CA certificates for remote validation if specified if len(c.tlsConfig.CAFiles) > 0 { - logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles) + logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles) caCertPool := x509.NewCertPool() for _, caFile := range c.tlsConfig.CAFiles { caCert, err := os.ReadFile(caFile) @@ -589,13 +589,13 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Fallback to existing PKCS12 implementation for backward compatibility if c.tlsConfig.PKCS12File != "" { - logger.Info("Loading PKCS12 certificate for mTLS (deprecated)") + logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)") return c.setupPKCS12TLS() } // Legacy fallback using config.TlsClientCert if c.config.TlsClientCert != "" { - logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)") + logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)") return loadClientCertificate(c.config.TlsClientCert) } @@ -630,7 +630,7 @@ func (c *Client) pingMonitor() { // Expected during shutdown return default: - logger.Error("Ping failed: %v", err) + logger.Error("websocket: Ping failed: %v", err) c.reconnect() return } @@ -663,18 +663,23 @@ func (c *Client) readPumpWithDisconnectDetection() { var msg WSMessage err := c.conn.ReadJSON(&msg) if err != nil { - // Check if we're shutting down before logging error + // Check if we're shutting down or explicitly disconnected before logging error select { case <-c.done: // Expected during shutdown, don't log as error - logger.Debug("WebSocket connection closed during shutdown") + logger.Debug("websocket: connection closed during shutdown") return default: + // Check if explicitly disconnected + if c.isDisconnected { + logger.Debug("websocket: connection closed: client was explicitly disconnected") + 
return + } // Unexpected error during normal operation if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { - logger.Error("WebSocket read error: %v", err) + logger.Error("websocket: read error: %v", err) } else { - logger.Debug("WebSocket connection closed: %v", err) + logger.Debug("websocket: connection closed: %v", err) } return // triggers reconnect via defer } @@ -696,6 +701,12 @@ func (c *Client) reconnect() { c.conn = nil } + // Don't reconnect if explicitly disconnected + if c.isDisconnected { + logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected") + return + } + // Only reconnect if we're not shutting down select { case <-c.done: @@ -713,7 +724,7 @@ func (c *Client) setConnected(status bool) { // LoadClientCertificate Helper method to load client certificates (PKCS12 format) func loadClientCertificate(p12Path string) (*tls.Config, error) { - logger.Info("Loading tls-client-cert %s", p12Path) + logger.Info("websocket: Loading tls-client-cert %s", p12Path) // Read the PKCS12 file p12Data, err := os.ReadFile(p12Path) if err != nil { From 17b75bf58f48345ac10bef2b486006cd2c9aa481 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 16:51:04 -0800 Subject: [PATCH 26/52] Dont get token each time Former-commit-id: 07dfc651f19a767a44d880fd27200bfc91a54cc7 --- websocket/client.go | 45 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/websocket/client.go b/websocket/client.go index f040aa4..b50cf31 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -88,6 +88,10 @@ type Client struct { clientType string // Type of client (e.g., "newt", "olm") tlsConfig TLSConfig configNeedsSave bool // Flag to track if config needs to be saved + token string // Cached authentication token + exitNodes []ExitNode // Cached exit nodes from token response + tokenMux sync.RWMutex // Protects token and 
exitNodes + forceNewToken bool // Flag to force fetching a new token on next connection } type ClientOption func(*Client) @@ -462,15 +466,25 @@ func (c *Client) connectWithRetry() { } func (c *Client) establishConnection() error { - // Get token for authentication - token, exitNodes, err := c.getToken() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } - - if c.onTokenUpdate != nil { - c.onTokenUpdate(token, exitNodes) + // Get token for authentication - reuse cached token unless forced to get new one + c.tokenMux.Lock() + needNewToken := c.token == "" || c.forceNewToken + if needNewToken { + token, exitNodes, err := c.getToken() + if err != nil { + c.tokenMux.Unlock() + return fmt.Errorf("failed to get token: %w", err) + } + c.token = token + c.exitNodes = exitNodes + c.forceNewToken = false + + if c.onTokenUpdate != nil { + c.onTokenUpdate(token, exitNodes) + } } + token := c.token + c.tokenMux.Unlock() // Parse the base URL to determine protocol and hostname baseURL, err := url.Parse(c.baseURL) @@ -522,8 +536,20 @@ func (c *Client) establishConnection() error { logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } - conn, _, err := dialer.Dial(u.String(), nil) + conn, resp, err := dialer.Dial(u.String(), nil) if err != nil { + // Check if this is an unauthorized error (401) + if resp != nil && resp.StatusCode == http.StatusUnauthorized { + logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized") + // Force getting a new token on next reconnect attempt + c.tokenMux.Lock() + c.forceNewToken = true + c.tokenMux.Unlock() + return &AuthError{ + StatusCode: http.StatusUnauthorized, + Message: "WebSocket connection unauthorized", + } + } return fmt.Errorf("failed to connect to WebSocket: %w", err) } @@ -675,6 +701,7 @@ func (c *Client) readPumpWithDisconnectDetection() { logger.Debug("websocket: connection closed: client was explicitly disconnected") return } + 
// Unexpected error during normal operation if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { logger.Error("websocket: read error: %v", err) From 69952ee5c5fd122a7cba9321dc435fad649af345 Mon Sep 17 00:00:00 2001 From: Varun Narravula Date: Thu, 8 Jan 2026 20:37:29 -0800 Subject: [PATCH 27/52] feat(api): add fingerprint + posture fields to client state Former-commit-id: 566084683ab5b12ae026cb68e399c1d4f4144b8e --- api/api.go | 42 ++++++++++++++++++++++++++++---- olm/olm.go | 67 ++++++++++++++++++++++++++++++++++------------------ olm/types.go | 3 +++ 3 files changed, 85 insertions(+), 27 deletions(-) diff --git a/api/api.go b/api/api.go index a6ac9cd..442162e 100644 --- a/api/api.go +++ b/api/api.go @@ -61,6 +61,11 @@ type StatusResponse struct { NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"` } +type MetadataChangeRequest struct { + Fingerprint map[string]any `json:"fingerprint"` + Postures map[string]any `json:"postures"` +} + // API represents the HTTP server and its state type API struct { addr string @@ -68,10 +73,11 @@ type API struct { listener net.Listener server *http.Server - onConnect func(ConnectionRequest) error - onSwitchOrg func(SwitchOrgRequest) error - onDisconnect func() error - onExit func() error + onConnect func(ConnectionRequest) error + onSwitchOrg func(SwitchOrgRequest) error + onMetadataChange func(MetadataChangeRequest) error + onDisconnect func() error + onExit func() error statusMu sync.RWMutex peerStatuses map[int]*PeerStatus @@ -117,6 +123,7 @@ func NewAPIStub() *API { func (s *API) SetHandlers( onConnect func(ConnectionRequest) error, onSwitchOrg func(SwitchOrgRequest) error, + onMetadataChange func(MetadataChangeRequest) error, onDisconnect func() error, onExit func() error, ) { @@ -136,6 +143,7 @@ func (s *API) Start() error { mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) 
mux.HandleFunc("/switch-org", s.handleSwitchOrg) + mux.HandleFunc("/metadata", s.handleMetadataChange) mux.HandleFunc("/disconnect", s.handleDisconnect) mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/health", s.handleHealth) @@ -514,6 +522,32 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { }) } +// handleMetadataChange handles the /metadata endpoint +func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req MetadataChangeRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + logger.Info("Received metadata change request via API: %v", req) + + _ = s.onMetadataChange(req) + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "metadata updated", + }) +} + func (s *API) GetStatus() StatusResponse { return StatusResponse{ Connected: s.isConnected, diff --git a/olm/olm.go b/olm/olm.go index 2db3630..de3f5a7 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -7,6 +7,7 @@ import ( "net/http" _ "net/http/pprof" "os" + "sync" "time" "github.com/fosrl/newt/bind" @@ -51,6 +52,11 @@ type Olm struct { olmConfig OlmConfig tunnelConfig TunnelConfig + // Metadata to send alongside pings + fingerprint map[string]any + postures map[string]any + metaMu sync.Mutex + stopRegister func() stopPeerSend func() updateRegister func(newData any) @@ -229,6 +235,20 @@ func (o *Olm) registerAPICallbacks() { logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) return o.SwitchOrg(req.OrgID) }, + // onMetadataChange + func(req api.MetadataChangeRequest) error { + logger.Info("Received change metadata request via API") + + if req.Fingerprint != 
nil { + o.SetFingerprint(req.Fingerprint) + } + + if req.Postures != nil { + o.SetPostures(req.Postures) + } + + return nil + }, // onDisconnect func() error { logger.Info("Processing disconnect request via API") @@ -404,6 +424,19 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } }) + fingerprint := config.InitialFingerprint + if fingerprint == nil { + fingerprint = make(map[string]any) + } + + postures := config.InitialPostures + if postures == nil { + postures = make(map[string]any) + } + + o.SetFingerprint(fingerprint) + o.SetPostures(postures) + // Connect to the WebSocket server if err := olmClient.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) @@ -577,28 +610,16 @@ func (o *Olm) SwitchOrg(orgID string) error { return nil } -func (o *Olm) AddDevice(fd uint32) error { - if o.middleDev == nil { - return fmt.Errorf("middle device is not initialized") - } +func (o *Olm) SetFingerprint(data map[string]any) { + o.metaMu.Lock() + defer o.metaMu.Unlock() - if o.tunnelConfig.MTU == 0 { - return fmt.Errorf("tunnel MTU is not set") - } - - tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU) - if err != nil { - return fmt.Errorf("failed to create TUN device from fd: %v", err) - } - - // Update interface name if available - if realInterfaceName, err2 := tdev.Name(); err2 == nil { - o.tunnelConfig.InterfaceName = realInterfaceName - } - - // Replace the existing TUN device in the middle device with the new one - o.middleDev.AddDevice(tdev) - - logger.Info("Added device from file descriptor %d", fd) - return nil + o.fingerprint = data +} + +func (o *Olm) SetPostures(data map[string]any) { + o.metaMu.Lock() + defer o.metaMu.Unlock() + + o.postures = data } diff --git a/olm/types.go b/olm/types.go index 77c0b5f..28e2260 100644 --- a/olm/types.go +++ b/olm/types.go @@ -67,5 +67,8 @@ type TunnelConfig struct { OverrideDNS bool TunnelDNS bool + InitialFingerprint map[string]any + InitialPostures map[string]any + DisableRelay bool } 
From 4b6999e06aacfd8e57cfa7f662dfc8c5913262a9 Mon Sep 17 00:00:00 2001 From: Varun Narravula Date: Thu, 8 Jan 2026 20:40:41 -0800 Subject: [PATCH 28/52] feat(ping): send fingerprint and posture checks as part of ping/register Former-commit-id: 70a7e83291cd8890bbf6217a9b4d819005c867f1 --- olm/olm.go | 14 ++++++++------ olm/ping.go | 12 +++++++----- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index de3f5a7..0810025 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -356,12 +356,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) { if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) o.stopRegister, o.updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]any{ - "publicKey": publicKey.String(), - "relay": !config.Holepunch, - "olmVersion": o.olmConfig.Version, - "olmAgent": o.olmConfig.Agent, - "orgId": config.OrgID, - "userToken": userToken, + "publicKey": publicKey.String(), + "relay": !config.Holepunch, + "olmVersion": o.olmConfig.Version, + "olmAgent": o.olmConfig.Agent, + "orgId": config.OrgID, + "userToken": userToken, + "fingerprint": o.fingerprint, + "postures": o.postures, }, 1*time.Second) // Invoke onRegistered callback if configured diff --git a/olm/ping.go b/olm/ping.go index bbeee9a..0d5235d 100644 --- a/olm/ping.go +++ b/olm/ping.go @@ -8,10 +8,12 @@ import ( "github.com/fosrl/olm/websocket" ) -func sendPing(olm *websocket.Client) error { +func (o *Olm) sendPing(olm *websocket.Client) error { err := olm.SendMessage("olm/ping", map[string]any{ - "timestamp": time.Now().Unix(), - "userToken": olm.GetConfig().UserToken, + "timestamp": time.Now().Unix(), + "userToken": olm.GetConfig().UserToken, + "fingerprint": o.fingerprint, + "postures": o.postures, }) if err != nil { logger.Error("Failed to send ping message: %v", err) @@ -23,7 +25,7 @@ func sendPing(olm *websocket.Client) error { func (o *Olm) 
keepSendingPing(olm *websocket.Client) { // Send ping immediately on startup - if err := sendPing(olm); err != nil { + if err := o.sendPing(olm); err != nil { logger.Error("Failed to send initial ping: %v", err) } else { logger.Info("Sent initial ping message") @@ -39,7 +41,7 @@ func (o *Olm) keepSendingPing(olm *websocket.Client) { logger.Info("Stopping ping messages") return case <-ticker.C: - if err := sendPing(olm); err != nil { + if err := o.sendPing(olm); err != nil { logger.Error("Failed to send periodic ping: %v", err) } } From 9dcc0796a653eff6fc9548153c8a166393a7438b Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 15 Jan 2026 14:20:12 -0800 Subject: [PATCH 29/52] Small clean up and move ping to client.go Former-commit-id: af33218792fb9faf32249368ee08cfaedfeecc00 --- olm/data.go | 6 +++--- olm/olm.go | 15 --------------- olm/types.go | 6 +++++- websocket/client.go | 7 +++++-- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/olm/data.go b/olm/data.go index 80a52fc..cf7448a 100644 --- a/olm/data.go +++ b/olm/data.go @@ -216,15 +216,15 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { return } - var wgData WgData - if err := json.Unmarshal(jsonData, &wgData); err != nil { + var syncData SyncData + if err := json.Unmarshal(jsonData, &syncData); err != nil { logger.Error("Error unmarshaling sync data: %v", err) return } // Build a map of expected peers from the incoming data expectedPeers := make(map[int]peers.SiteConfig) - for _, site := range wgData.Sites { + for _, site := range syncData.Sites { expectedPeers[site.SiteId] = site } diff --git a/olm/olm.go b/olm/olm.go index 85dcbe6..9582232 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -60,8 +60,6 @@ type Olm struct { stopRegister func() updateRegister func(newData any) - stopServerPing func() - stopPeerSend func() } @@ -332,14 +330,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetConnectionStatus(true) - // restart the ping if we need to - if o.stopServerPing == nil { - 
o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{ - "timestamp": time.Now().Unix(), - "userToken": olmClient.GetConfig().UserToken, - }, 30*time.Second, -1) // -1 means dont time out with the max attempts - } - if o.connected { logger.Debug("Already connected, skipping registration") return nil @@ -445,11 +435,6 @@ func (o *Olm) Close() { o.holePunchManager = nil } - if o.stopServerPing != nil { - o.stopServerPing() - o.stopServerPing = nil - } - if o.stopRegister != nil { o.stopRegister() o.stopRegister = nil diff --git a/olm/types.go b/olm/types.go index 397eab9..804f8e5 100644 --- a/olm/types.go +++ b/olm/types.go @@ -12,6 +12,10 @@ type WgData struct { UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses } +type SyncData struct { + Sites []peers.SiteConfig `json:"sites"` +} + type OlmConfig struct { // Logging LogLevel string @@ -23,7 +27,7 @@ type OlmConfig struct { SocketPath string Version string Agent string - + WakeUpDebounce time.Duration // Debugging diff --git a/websocket/client.go b/websocket/client.go index ba70494..7877e6d 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -657,8 +657,11 @@ func (c *Client) pingMonitor() { c.configVersionMux.RUnlock() pingMsg := WSMessage{ - Type: "ping", - Data: map[string]interface{}{}, + Type: "olm/ping", + Data: map[string]any{ + "timestamp": time.Now().Unix(), + "userToken": c.config.UserToken, + }, ConfigVersion: configVersion, } From e047330ffd1894c12f170750b623b1cb535a7a9c Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 15 Jan 2026 16:36:11 -0800 Subject: [PATCH 30/52] Handle and test config version bugs Former-commit-id: 285f8ce530cdc3be995c21294ea7c76e7f057da3 --- olm/data.go | 2 +- olm/olm.go | 73 ++++++++++++++++++++++++++++++++------------- websocket/client.go | 11 ++++--- 3 files changed, 58 insertions(+), 28 deletions(-) diff --git a/olm/data.go b/olm/data.go index cf7448a..eff46f4 100644 --- a/olm/data.go +++ 
b/olm/data.go @@ -198,7 +198,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { // Handler for syncing peer configuration - reconciles expected state with actual state func (o *Olm) handleSync(msg websocket.WSMessage) { - logger.Debug("Received sync message: %v", msg.Data) + logger.Debug("++++++++++++++++++++++++++++Received sync message: %v", msg.Data) if !o.connected { logger.Warn("Not connected, ignoring sync request") diff --git a/olm/olm.go b/olm/olm.go index 9582232..22a936f 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -7,7 +7,9 @@ import ( "net/http" _ "net/http/pprof" "os" + "os/signal" "sync" + "syscall" "time" "github.com/fosrl/newt/bind" @@ -275,6 +277,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.tunnelCancel = cancel var ( + err error id = config.ID secret = config.Secret userToken = config.UserToken @@ -284,8 +287,8 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetOrgID(config.OrgID) - // Create a new olmClient client using the provided credentials - olmClient, err := websocket.NewClient( + // Create a new o.websocket client using the provided credentials + o.websocket, err = websocket.NewClient( id, secret, userToken, @@ -306,26 +309,26 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } // Handlers for managing connection status - olmClient.RegisterHandler("olm/wg/connect", o.handleConnect) - olmClient.RegisterHandler("olm/terminate", o.handleTerminate) + o.websocket.RegisterHandler("olm/wg/connect", o.handleConnect) + o.websocket.RegisterHandler("olm/terminate", o.handleTerminate) // Handlers for managing peers - olmClient.RegisterHandler("olm/wg/peer/add", o.handleWgPeerAdd) - olmClient.RegisterHandler("olm/wg/peer/remove", o.handleWgPeerRemove) - olmClient.RegisterHandler("olm/wg/peer/update", o.handleWgPeerUpdate) - olmClient.RegisterHandler("olm/wg/peer/relay", o.handleWgPeerRelay) - olmClient.RegisterHandler("olm/wg/peer/unrelay", o.handleWgPeerUnrelay) + 
o.websocket.RegisterHandler("olm/wg/peer/add", o.handleWgPeerAdd) + o.websocket.RegisterHandler("olm/wg/peer/remove", o.handleWgPeerRemove) + o.websocket.RegisterHandler("olm/wg/peer/update", o.handleWgPeerUpdate) + o.websocket.RegisterHandler("olm/wg/peer/relay", o.handleWgPeerRelay) + o.websocket.RegisterHandler("olm/wg/peer/unrelay", o.handleWgPeerUnrelay) // Handlers for managing remote subnets to a peer - olmClient.RegisterHandler("olm/wg/peer/data/add", o.handleWgPeerAddData) - olmClient.RegisterHandler("olm/wg/peer/data/remove", o.handleWgPeerRemoveData) - olmClient.RegisterHandler("olm/wg/peer/data/update", o.handleWgPeerUpdateData) + o.websocket.RegisterHandler("olm/wg/peer/data/add", o.handleWgPeerAddData) + o.websocket.RegisterHandler("olm/wg/peer/data/remove", o.handleWgPeerRemoveData) + o.websocket.RegisterHandler("olm/wg/peer/data/update", o.handleWgPeerUpdateData) // Handler for peer handshake - adds exit node to holepunch rotation and notifies server - olmClient.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) - olmClient.RegisterHandler("olm/sync", o.handleSync) + o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) + o.websocket.RegisterHandler("olm/sync", o.handleSync) - olmClient.OnConnect(func() error { + o.websocket.OnConnect(func() error { logger.Info("Websocket Connected") o.apiServer.SetConnectionStatus(true) @@ -342,7 +345,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) - o.stopRegister, o.updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]any{ + o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{ "publicKey": publicKey.String(), "relay": !config.Holepunch, "olmVersion": o.olmConfig.Version, @@ -360,7 +363,7 @@ func (o *Olm) StartTunnel(config 
TunnelConfig) { return nil }) - olmClient.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { + o.websocket.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { o.holePunchManager.SetToken(token) logger.Debug("Got exit nodes for hole punching: %v", exitNodes) @@ -390,7 +393,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } }) - olmClient.OnAuthError(func(statusCode int, message string) { + o.websocket.OnAuthError(func(statusCode int, message string) { logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message) o.apiServer.SetTerminated(true) o.apiServer.SetConnectionStatus(false) @@ -410,13 +413,41 @@ func (o *Olm) StartTunnel(config TunnelConfig) { }) // Connect to the WebSocket server - if err := olmClient.Connect(); err != nil { + if err := o.websocket.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) return } - defer func() { _ = olmClient.Close() }() + defer func() { _ = o.websocket.Close() }() - o.websocket = olmClient + // Setup SIGHUP signal handler for testing (toggles power state) + // THIS SHOULD ONLY BE USED AND ON IN A DEV MODE + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGHUP) + + go func() { + powerMode := "normal" + for { + select { + case <-sigChan: + + logger.Info("SIGHUP received, toggling power mode") + if powerMode == "normal" { + powerMode = "low" + if err := o.SetPowerMode("low"); err != nil { + logger.Error("Failed to set low power mode: %v", err) + } + } else { + powerMode = "normal" + if err := o.SetPowerMode("normal"); err != nil { + logger.Error("Failed to set normal power mode: %v", err) + } + } + + case <-tunnelCtx.Done(): + return + } + } + }() // Wait for context cancellation <-tunnelCtx.Done() diff --git a/websocket/client.go b/websocket/client.go index 7877e6d..8bcbeb3 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -665,6 +665,8 @@ func (c *Client) pingMonitor() { ConfigVersion: configVersion, } 
+ logger.Debug("++++++++++++++++++++++++++++websocket: Sending ping: %+v", pingMsg) + c.writeMux.Lock() err := c.conn.WriteJSON(pingMsg) c.writeMux.Unlock() @@ -695,9 +697,8 @@ func (c *Client) GetConfigVersion() int { func (c *Client) setConfigVersion(version int) { c.configVersionMux.Lock() defer c.configVersionMux.Unlock() - if version > c.configVersion { - c.configVersion = version - } + logger.Debug("++++++++++++++++++++++++++++websocket: setting config version to %d", version) + c.configVersion = version } // readPumpWithDisconnectDetection reads messages and triggers reconnect on error @@ -748,9 +749,7 @@ func (c *Client) readPumpWithDisconnectDetection() { } // Update config version from incoming message - if msg.ConfigVersion > 0 { - c.setConfigVersion(msg.ConfigVersion) - } + c.setConfigVersion(msg.ConfigVersion) c.handlersMux.RLock() if handler, ok := c.handlers[msg.Type]; ok { From bd8031651e9fe66f365f41b3ed8eeb127dd14df7 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 15 Jan 2026 21:25:53 -0800 Subject: [PATCH 31/52] Message syncing works Former-commit-id: 1650624a553208c577b32368f5b93a77322ab922 --- olm/data.go | 139 ++++++++++++++++++++++---------------------- olm/peer.go | 63 ++++++++++++++++++++ olm/types.go | 10 +++- websocket/client.go | 26 +++++++++ 4 files changed, 169 insertions(+), 69 deletions(-) diff --git a/olm/data.go b/olm/data.go index eff46f4..1cd29fa 100644 --- a/olm/data.go +++ b/olm/data.go @@ -135,67 +135,6 @@ func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) } -func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { - logger.Debug("Received peer-handshake message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling handshake data: %v", err) - return - } - - var handshakeData struct { - SiteId int `json:"siteId"` - ExitNode struct { - PublicKey 
string `json:"publicKey"` - Endpoint string `json:"endpoint"` - RelayPort uint16 `json:"relayPort"` - } `json:"exitNode"` - } - - if err := json.Unmarshal(jsonData, &handshakeData); err != nil { - logger.Error("Error unmarshaling handshake data: %v", err) - return - } - - // Get existing peer from PeerManager - _, exists := o.peerManager.GetPeer(handshakeData.SiteId) - if exists { - logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) - return - } - - relayPort := handshakeData.ExitNode.RelayPort - if relayPort == 0 { - relayPort = 21820 // default relay port - } - - siteId := handshakeData.SiteId - exitNode := holepunch.ExitNode{ - Endpoint: handshakeData.ExitNode.Endpoint, - RelayPort: relayPort, - PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, - } - - added := o.holePunchManager.AddExitNode(exitNode) - if added { - logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) - } else { - logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) - } - - o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud - - // Send handshake acknowledgment back to server with retry - o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second, 10) - - logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) -} - // Handler for syncing peer configuration - reconciles expected state with actual state func (o *Olm) handleSync(msg websocket.WSMessage) { logger.Debug("++++++++++++++++++++++++++++Received sync message: %v", msg.Data) @@ -222,6 +161,9 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { return } + // Sync exit nodes for hole punching + o.syncExitNodes(syncData.ExitNodes) + 
// Build a map of expected peers from the incoming data expectedPeers := make(map[int]peers.SiteConfig) for _, site := range syncData.Sites { @@ -259,15 +201,21 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { // New peer - add it using the add flow (with holepunch) logger.Info("Sync: Adding new peer for site %d", siteId) - // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it o.holePunchManager.TriggerHolePunch() - // TODO: do we need to send the message to the cloud to add the peer that way? - if err := o.peerManager.AddPeer(expectedSite); err != nil { - logger.Error("Sync: Failed to add peer %d: %v", siteId, err) - } else { - logger.Info("Sync: Successfully added peer for site %d", siteId) - } + // // TODO: do we need to send the message to the cloud to add the peer that way? + // if err := o.peerManager.AddPeer(expectedSite); err != nil { + // logger.Error("Sync: Failed to add peer %d: %v", siteId, err) + // } else { + // logger.Info("Sync: Successfully added peer for site %d", siteId) + // } + + // add the peer via the server + // this is important because newt needs to get triggered as well to add the peer once the hp is complete + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": expectedSite.SiteId, + }, 1*time.Second, 10) + } else { // Existing peer - check if update is needed currentSite := currentPeerMap[siteId] @@ -342,3 +290,58 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers)) } + +// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager +func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) { + if o.holePunchManager == nil { + logger.Warn("Hole punch manager not initialized, skipping exit node sync") + return + } + + // Build a map of 
expected exit nodes by endpoint + expectedExitNodeMap := make(map[string]SyncExitNode) + for _, exitNode := range expectedExitNodes { + expectedExitNodeMap[exitNode.Endpoint] = exitNode + } + + // Get current exit nodes from hole punch manager + currentExitNodes := o.holePunchManager.GetExitNodes() + currentExitNodeMap := make(map[string]holepunch.ExitNode) + for _, exitNode := range currentExitNodes { + currentExitNodeMap[exitNode.Endpoint] = exitNode + } + + // Find exit nodes to remove (in current but not in expected) + for endpoint := range currentExitNodeMap { + if _, exists := expectedExitNodeMap[endpoint]; !exists { + logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint) + o.holePunchManager.RemoveExitNode(endpoint) + } + } + + // Find exit nodes to add (in expected but not in current) + for endpoint, expectedExitNode := range expectedExitNodeMap { + if _, exists := currentExitNodeMap[endpoint]; !exists { + logger.Info("Sync: Adding new exit node %s", endpoint) + + relayPort := expectedExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + hpExitNode := holepunch.ExitNode{ + Endpoint: expectedExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: expectedExitNode.PublicKey, + SiteIds: expectedExitNode.SiteIds, + } + + if o.holePunchManager.AddExitNode(hpExitNode) { + logger.Info("Sync: Successfully added exit node %s", endpoint) + } + o.holePunchManager.TriggerHolePunch() + } + } + + logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap)) +} diff --git a/olm/peer.go b/olm/peer.go index 9bc842e..56e298d 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -2,7 +2,9 @@ package olm import ( "encoding/json" + "time" + "github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/util" "github.com/fosrl/olm/peers" @@ -193,3 +195,64 @@ func (o *Olm) handleWgPeerUnrelay(msg 
websocket.WSMessage) { o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) } + +func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { + logger.Debug("Received peer-handshake message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling handshake data: %v", err) + return + } + + var handshakeData struct { + SiteId int `json:"siteId"` + ExitNode struct { + PublicKey string `json:"publicKey"` + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + } `json:"exitNode"` + } + + if err := json.Unmarshal(jsonData, &handshakeData); err != nil { + logger.Error("Error unmarshaling handshake data: %v", err) + return + } + + // Get existing peer from PeerManager + _, exists := o.peerManager.GetPeer(handshakeData.SiteId) + if exists { + logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) + return + } + + relayPort := handshakeData.ExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + siteId := handshakeData.SiteId + exitNode := holepunch.ExitNode{ + Endpoint: handshakeData.ExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: handshakeData.ExitNode.PublicKey, + SiteIds: []int{siteId}, + } + + added := o.holePunchManager.AddExitNode(exitNode) + if added { + logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) + } else { + logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) + } + + o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud + + // Send handshake acknowledgment back to server with retry + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + }, 1*time.Second, 10) + + logger.Info("Initiated handshake for site %d with exit 
node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) +} diff --git a/olm/types.go b/olm/types.go index 491ed19..2e56ad7 100644 --- a/olm/types.go +++ b/olm/types.go @@ -13,7 +13,15 @@ type WgData struct { } type SyncData struct { - Sites []peers.SiteConfig `json:"sites"` + Sites []peers.SiteConfig `json:"sites"` + ExitNodes []SyncExitNode `json:"exitNodes"` +} + +type SyncExitNode struct { + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + PublicKey string `json:"publicKey"` + SiteIds []int `json:"siteIds"` } type OlmConfig struct { diff --git a/websocket/client.go b/websocket/client.go index 8bcbeb3..4a1099e 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -96,6 +96,9 @@ type Client struct { exitNodes []ExitNode // Cached exit nodes from token response tokenMux sync.RWMutex // Protects token and exitNodes forceNewToken bool // Flag to force fetching a new token on next connection + processingMessage bool // Flag to track if a message is currently being processed + processingMux sync.RWMutex // Protects processingMessage + processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete } type ClientOption func(*Client) @@ -222,6 +225,9 @@ func (c *Client) Disconnect() error { c.isDisconnected = true c.setConnected(false) + // Wait for any message currently being processed to complete + c.processingWg.Wait() + if c.conn != nil { c.writeMux.Lock() c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) @@ -651,6 +657,14 @@ func (c *Client) pingMonitor() { if c.isDisconnected || c.conn == nil { return } + // Skip ping if a message is currently being processed + c.processingMux.RLock() + isProcessing := c.processingMessage + c.processingMux.RUnlock() + if isProcessing { + logger.Debug("websocket: Skipping ping, message is being processed") + continue + } // Send application-level ping with config version c.configVersionMux.RLock() configVersion := 
c.configVersion @@ -753,7 +767,19 @@ func (c *Client) readPumpWithDisconnectDetection() { c.handlersMux.RLock() if handler, ok := c.handlers[msg.Type]; ok { + // Mark that we're processing a message + c.processingMux.Lock() + c.processingMessage = true + c.processingMux.Unlock() + c.processingWg.Add(1) + handler(msg) + + // Mark that we're done processing + c.processingWg.Done() + c.processingMux.Lock() + c.processingMessage = false + c.processingMux.Unlock() } c.handlersMux.RUnlock() } From e1a687407eec5f3b9d2c8c6ec3936b1ad3380678 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 15 Jan 2026 21:59:18 -0800 Subject: [PATCH 32/52] Set the ping inteval to 30 seconds Former-commit-id: 737ffca15d746204423ac5b4a98f5a7e8be783f9 --- olm/olm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olm/olm.go b/olm/olm.go index 97bd4b7..e1d9a7f 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -313,7 +313,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { userToken, config.OrgID, config.Endpoint, - config.PingIntervalDuration, + 30, // 30 seconds config.PingTimeoutDuration, ) if err != nil { From eafd8161596f3b1177620f88d2b38050437071d7 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 16 Jan 2026 12:02:02 -0800 Subject: [PATCH 33/52] Clean up log messages Former-commit-id: 0231591f366ffc3fa118622d810b0e24ec4357b0 --- olm/data.go | 2 +- websocket/client.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/olm/data.go b/olm/data.go index 1cd29fa..35798c6 100644 --- a/olm/data.go +++ b/olm/data.go @@ -137,7 +137,7 @@ func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { // Handler for syncing peer configuration - reconciles expected state with actual state func (o *Olm) handleSync(msg websocket.WSMessage) { - logger.Debug("++++++++++++++++++++++++++++Received sync message: %v", msg.Data) + logger.Debug("Received sync message: %v", msg.Data) if !o.connected { logger.Warn("Not connected, ignoring sync request") diff --git a/websocket/client.go 
b/websocket/client.go index 4a1099e..024d915 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -679,7 +679,7 @@ func (c *Client) pingMonitor() { ConfigVersion: configVersion, } - logger.Debug("++++++++++++++++++++++++++++websocket: Sending ping: %+v", pingMsg) + logger.Debug("websocket: Sending ping: %+v", pingMsg) c.writeMux.Lock() err := c.conn.WriteJSON(pingMsg) @@ -711,7 +711,7 @@ func (c *Client) GetConfigVersion() int { func (c *Client) setConfigVersion(version int) { c.configVersionMux.Lock() defer c.configVersionMux.Unlock() - logger.Debug("++++++++++++++++++++++++++++websocket: setting config version to %d", version) + logger.Debug("websocket: setting config version to %d", version) c.configVersion = version } From 71044165d027b5049372e2186baeb55e2c38bd0e Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 16 Jan 2026 12:16:51 -0800 Subject: [PATCH 34/52] Include fingerprint and posture info in ping Former-commit-id: f061596e5b12552feaedb4b7079111cd6734bc0e --- olm/olm.go | 8 ++++++++ websocket/client.go | 31 +++++++++++++++++++++++-------- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index e1d9a7f..bc06602 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -315,6 +315,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) { config.Endpoint, 30, // 30 seconds config.PingTimeoutDuration, + websocket.WithPingDataProvider(func() map[string]any { + o.metaMu.Lock() + defer o.metaMu.Unlock() + return map[string]any{ + "fingerprint": o.fingerprint, + "postures": o.postures, + } + }), ) if err != nil { logger.Error("Failed to create olm: %v", err) diff --git a/websocket/client.go b/websocket/client.go index 024d915..d0ac73b 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -96,9 +96,10 @@ type Client struct { exitNodes []ExitNode // Cached exit nodes from token response tokenMux sync.RWMutex // Protects token and exitNodes forceNewToken bool // Flag to force fetching a new token on next connection - 
processingMessage bool // Flag to track if a message is currently being processed - processingMux sync.RWMutex // Protects processingMessage - processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete + processingMessage bool // Flag to track if a message is currently being processed + processingMux sync.RWMutex // Protects processingMessage + processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete + getPingData func() map[string]any // Callback to get additional ping data } type ClientOption func(*Client) @@ -134,6 +135,13 @@ func WithTLSConfig(config TLSConfig) ClientOption { } } +// WithPingDataProvider sets a callback to provide additional data for ping messages +func WithPingDataProvider(fn func() map[string]any) ClientOption { + return func(c *Client) { + c.getPingData = fn + } +} + func (c *Client) OnConnect(callback func() error) { c.onConnect = callback } @@ -670,12 +678,19 @@ func (c *Client) pingMonitor() { configVersion := c.configVersion c.configVersionMux.RUnlock() + pingData := map[string]any{ + "timestamp": time.Now().Unix(), + "userToken": c.config.UserToken, + } + if c.getPingData != nil { + for k, v := range c.getPingData() { + pingData[k] = v + } + } + pingMsg := WSMessage{ - Type: "olm/ping", - Data: map[string]any{ - "timestamp": time.Now().Unix(), - "userToken": c.config.UserToken, - }, + Type: "olm/ping", + Data: pingData, ConfigVersion: configVersion, } From 0b462891368569143b1f2a0f9ac0ee3f134f28d5 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 16 Jan 2026 14:19:02 -0800 Subject: [PATCH 35/52] Add error can be sent from cloud to display in api Former-commit-id: 2167f22713ba384fe7068eb52bea52ed3abbaeea --- api/api.go | 33 ++++++++++++++++++++++++++++++++- olm/connect.go | 34 ++++++++++++++++++++++++++++++++++ olm/olm.go | 29 ++++++++++++++++------------- olm/types.go | 1 + 4 files changed, 83 insertions(+), 14 deletions(-) diff --git a/api/api.go b/api/api.go index 442162e..b85b041 100644 
--- a/api/api.go +++ b/api/api.go @@ -49,11 +49,18 @@ type PeerStatus struct { HolepunchConnected bool `json:"holepunchConnected"` } +// OlmError holds error information from registration failures +type OlmError struct { + Code string `json:"code"` + Message string `json:"message"` +} + // StatusResponse is returned by the status endpoint type StatusResponse struct { Connected bool `json:"connected"` Registered bool `json:"registered"` Terminated bool `json:"terminated"` + OlmError *OlmError `json:"error,omitempty"` Version string `json:"version,omitempty"` Agent string `json:"agent,omitempty"` OrgID string `json:"orgId,omitempty"` @@ -85,6 +92,7 @@ type API struct { isConnected bool isRegistered bool isTerminated bool + olmError *OlmError version string agent string @@ -138,7 +146,7 @@ func (s *API) Start() error { if s.socketPath == "" && s.addr == "" { return fmt.Errorf("either socketPath or addr must be provided to start the API server") } - + mux := http.NewServeMux() mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) @@ -260,6 +268,27 @@ func (s *API) SetRegistered(registered bool) { s.statusMu.Lock() defer s.statusMu.Unlock() s.isRegistered = registered + // Clear any registration error when successfully registered + if registered { + s.olmError = nil + } +} + +// SetOlmError sets the registration error +func (s *API) SetOlmError(code string, message string) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.olmError = &OlmError{ + Code: code, + Message: message, + } +} + +// ClearOlmError clears any registration error +func (s *API) ClearOlmError() { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.olmError = nil } func (s *API) SetTerminated(terminated bool) { @@ -387,6 +416,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { Connected: s.isConnected, Registered: s.isRegistered, Terminated: s.isTerminated, + OlmError: s.olmError, Version: s.version, Agent: s.agent, OrgID: s.orgID, @@ -553,6 
+583,7 @@ func (s *API) GetStatus() StatusResponse { Connected: s.isConnected, Registered: s.isRegistered, Terminated: s.isTerminated, + OlmError: s.olmError, Version: s.version, Agent: s.agent, OrgID: s.orgID, diff --git a/olm/connect.go b/olm/connect.go index a610ea4..ebe7009 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -19,6 +19,12 @@ import ( "golang.zx2c4.com/wireguard/tun" ) +// OlmErrorData represents the error data sent from the server +type OlmErrorData struct { + Code string `json:"code"` + Message string `json:"message"` +} + func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) @@ -206,11 +212,39 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Info("WireGuard device created.") } +func (o *Olm) handleOlmError(msg websocket.WSMessage) { + logger.Debug("Received olm error message: %v", msg.Data) + + var errorData OlmErrorData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling olm error data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &errorData); err != nil { + logger.Error("Error unmarshaling olm error data: %v", err) + return + } + + logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message) + + // Set the olm error in the API server so it can be exposed via status + o.apiServer.SetOlmError(errorData.Code, errorData.Message) + + // Invoke onOlmError callback if configured + if o.olmConfig.OnOlmError != nil { + go o.olmConfig.OnOlmError(errorData.Code, errorData.Message) + } +} + func (o *Olm) handleTerminate(msg websocket.WSMessage) { logger.Info("Received terminate message") o.apiServer.SetTerminated(true) o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() diff --git a/olm/olm.go b/olm/olm.go index bc06602..df6cad0 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -337,6 +337,7 @@ func (o *Olm) 
StartTunnel(config TunnelConfig) { // Handlers for managing connection status o.websocket.RegisterHandler("olm/wg/connect", o.handleConnect) + o.websocket.RegisterHandler("olm/error", o.handleOlmError) o.websocket.RegisterHandler("olm/terminate", o.handleTerminate) // Handlers for managing peers @@ -427,6 +428,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetTerminated(true) o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() @@ -471,20 +473,20 @@ func (o *Olm) StartTunnel(config TunnelConfig) { for { select { case <-sigChan: - - logger.Info("SIGHUP received, toggling power mode") - if powerMode == "normal" { - powerMode = "low" - if err := o.SetPowerMode("low"); err != nil { - logger.Error("Failed to set low power mode: %v", err) + + logger.Info("SIGHUP received, toggling power mode") + if powerMode == "normal" { + powerMode = "low" + if err := o.SetPowerMode("low"); err != nil { + logger.Error("Failed to set low power mode: %v", err) + } + } else { + powerMode = "normal" + if err := o.SetPowerMode("normal"); err != nil { + logger.Error("Failed to set normal power mode: %v", err) + } } - } else { - powerMode = "normal" - if err := o.SetPowerMode("normal"); err != nil { - logger.Error("Failed to set normal power mode: %v", err) - } - } - + case <-tunnelCtx.Done(): return } @@ -597,6 +599,7 @@ func (o *Olm) StopTunnel() error { // Update API server status o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() network.ClearNetworkSettings() o.apiServer.ClearPeerStatuses() diff --git a/olm/types.go b/olm/types.go index 2e56ad7..198b222 100644 --- a/olm/types.go +++ b/olm/types.go @@ -46,6 +46,7 @@ type OlmConfig struct { OnConnected func() OnTerminated func() OnAuthError func(statusCode int, message string) // Called when auth fails (401/403) + OnOlmError func(code string, message string) // 
Called when registration fails OnExit func() // Called when exit is requested via API } From 2ea12ce2589d3a92a94ac4dc0caa56c89c16489b Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 16 Jan 2026 14:59:13 -0800 Subject: [PATCH 36/52] Set the error on terminate as well Former-commit-id: 8ff58e6efcd239523c308e1604b184ba6f01bd32 --- olm/connect.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/olm/connect.go b/olm/connect.go index ebe7009..394e7e2 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -241,10 +241,25 @@ func (o *Olm) handleOlmError(msg websocket.WSMessage) { func (o *Olm) handleTerminate(msg websocket.WSMessage) { logger.Info("Received terminate message") + + var errorData OlmErrorData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling terminate error data: %v", err) + } else { + if err := json.Unmarshal(jsonData, &errorData); err != nil { + logger.Error("Error unmarshaling terminate error data: %v", err) + } else { + logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message) + // Set the olm error in the API server so it can be exposed via status + o.apiServer.SetOlmError(errorData.Code, errorData.Message) + } + } + o.apiServer.SetTerminated(true) o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) - o.apiServer.ClearOlmError() o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() From 5ecba61718b9bae0fb7830f5ad5c658507796014 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 16 Jan 2026 15:17:20 -0800 Subject: [PATCH 37/52] Use the right duration Former-commit-id: 352b122166be02f642fc9b1f0a6f806bb1e5c86c --- olm/olm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olm/olm.go b/olm/olm.go index bc06602..f6e1980 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -313,7 +313,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { userToken, config.OrgID, config.Endpoint, - 30, // 30 seconds + 30 * time.Second, // 30 
seconds config.PingTimeoutDuration, websocket.WithPingDataProvider(func() map[string]any { o.metaMu.Lock() From cfac3cdd533ac48128f27a08ae579a3260caea4b Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 16 Jan 2026 15:17:20 -0800 Subject: [PATCH 38/52] Use the right duration Former-commit-id: c921f08bd522d7730925ed3aac1fabae0ba97606 --- olm/olm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olm/olm.go b/olm/olm.go index df6cad0..2fa9a6f 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -313,7 +313,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { userToken, config.OrgID, config.Endpoint, - 30, // 30 seconds + 30 * time.Second, // 30 seconds config.PingTimeoutDuration, websocket.WithPingDataProvider(func() map[string]any { o.metaMu.Lock() From a13010c4afc3d8c283fb1f4e2ab3d18b3c1520a1 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 16 Jan 2026 17:33:40 -0800 Subject: [PATCH 39/52] Update docs for metadata Former-commit-id: 9d77a1daf7451e74a2337f0467497220d76cb627 --- API.md | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/API.md b/API.md index 4e20f50..f8d8878 100644 --- a/API.md +++ b/API.md @@ -46,7 +46,18 @@ Initiates a new connection request to a Pangolin server. "tlsClientCert": "string", "pingInterval": "3s", "pingTimeout": "5s", - "orgId": "string" + "orgId": "string", + "fingerprint": { + "username": "string", + "hostname": "string", + "platform": "string", + "osVersion": "string", + "kernelVersion": "string", + "arch": "string", + "deviceModel": "string", + "serialNumber": "string" + }, + "postures": {} } ``` @@ -67,6 +78,16 @@ Initiates a new connection request to a Pangolin server. 
- `pingInterval`: Interval for pinging the server (default: 3s) - `pingTimeout`: Timeout for each ping (default: 5s) - `orgId`: Organization ID to connect to +- `fingerprint`: Device fingerprinting information (should be set before connecting) + - `username`: Current username on the device + - `hostname`: Device hostname + - `platform`: Operating system platform (macos, windows, linux, ios, android, unknown) + - `osVersion`: Operating system version + - `kernelVersion`: Kernel version + - `arch`: System architecture (e.g., amd64, arm64) + - `deviceModel`: Device model identifier + - `serialNumber`: Device serial number +- `postures`: Device posture/security information **Response:** - **Status Code:** `202 Accepted` @@ -205,6 +226,56 @@ Switches to a different organization while maintaining the connection. --- +### PUT /metadata +Updates device fingerprinting and posture information. This endpoint can be called at any time to update metadata, but it's recommended to provide this information in the initial `/connect` request or immediately before connecting. 
+ +**Request Body:** +```json +{ + "fingerprint": { + "username": "string", + "hostname": "string", + "platform": "string", + "osVersion": "string", + "kernelVersion": "string", + "arch": "string", + "deviceModel": "string", + "serialNumber": "string" + }, + "postures": {} +} +``` + +**Optional Fields:** +- `fingerprint`: Device fingerprinting information + - `username`: Current username on the device + - `hostname`: Device hostname + - `platform`: Operating system platform (macos, windows, linux, ios, android, unknown) + - `osVersion`: Operating system version + - `kernelVersion`: Kernel version + - `arch`: System architecture (e.g., amd64, arm64) + - `deviceModel`: Device model identifier + - `serialNumber`: Device serial number +- `postures`: Device posture/security information (object with arbitrary key-value pairs) + +**Response:** +- **Status Code:** `200 OK` +- **Content-Type:** `application/json` + +```json +{ + "status": "metadata updated" +} +``` + +**Error Responses:** +- `405 Method Not Allowed` - Non-PUT requests +- `400 Bad Request` - Invalid JSON + +**Note:** It's recommended to call this endpoint BEFORE `/connect` to ensure fingerprinting information is available during the initial connection handshake. + +--- + ### POST /exit Initiates a graceful shutdown of the Olm process. @@ -247,6 +318,22 @@ Simple health check endpoint to verify the API server is running. 
## Usage Examples +### Update metadata before connecting (recommended) +```bash +curl -X PUT http://localhost:9452/metadata \ + -H "Content-Type: application/json" \ + -d '{ + "fingerprint": { + "username": "john", + "hostname": "johns-laptop", + "platform": "macos", + "osVersion": "14.2.1", + "arch": "arm64", + "deviceModel": "MacBookPro18,3" + } + }' +``` + ### Connect to a peer ```bash curl -X POST http://localhost:9452/connect \ From a06436eeab5c26be568a4fe9de13a74331d753b6 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 17 Jan 2026 17:05:29 -0800 Subject: [PATCH 40/52] Add rebind endpoints for the shared socket Former-commit-id: 6fd0984b13954402b4598bf396710a34e2337128 --- api/api.go | 34 ++++++++++++++++++++++++++ olm/olm.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/api/api.go b/api/api.go index 442162e..e18bee7 100644 --- a/api/api.go +++ b/api/api.go @@ -78,6 +78,7 @@ type API struct { onMetadataChange func(MetadataChangeRequest) error onDisconnect func() error onExit func() error + onRebind func() error statusMu sync.RWMutex peerStatuses map[int]*PeerStatus @@ -126,11 +127,13 @@ func (s *API) SetHandlers( onMetadataChange func(MetadataChangeRequest) error, onDisconnect func() error, onExit func() error, + onRebind func() error, ) { s.onConnect = onConnect s.onSwitchOrg = onSwitchOrg s.onDisconnect = onDisconnect s.onExit = onExit + s.onRebind = onRebind } // Start starts the HTTP server @@ -147,6 +150,7 @@ func (s *API) Start() error { mux.HandleFunc("/disconnect", s.handleDisconnect) mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/health", s.handleHealth) + mux.HandleFunc("/rebind", s.handleRebind) s.server = &http.Server{ Handler: mux, @@ -560,3 +564,33 @@ func (s *API) GetStatus() StatusResponse { NetworkSettings: network.GetSettings(), } } + +// handleRebind handles the /rebind endpoint +// This triggers a socket rebind, which is necessary when network connectivity changes +// 
(e.g., WiFi to cellular transition on macOS/iOS) and the old socket becomes stale. +func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + logger.Info("Received rebind request via API") + + // Call the rebind handler if set + if s.onRebind != nil { + if err := s.onRebind(); err != nil { + http.Error(w, fmt.Sprintf("Rebind failed: %v", err), http.StatusInternalServerError) + return + } + } else { + http.Error(w, "Rebind handler not configured", http.StatusNotImplemented) + return + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "socket rebound successfully", + }) +} diff --git a/olm/olm.go b/olm/olm.go index f6e1980..26fc0e4 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -273,6 +273,11 @@ func (o *Olm) registerAPICallbacks() { } return nil }, + // onRebind + func() error { + logger.Info("Processing rebind request via API") + return o.RebindSocket() + }, ) } @@ -783,6 +788,72 @@ func (o *Olm) SetPowerMode(mode string) error { return nil } +// RebindSocket recreates the UDP socket when network connectivity changes. +// This is necessary on macOS/iOS when transitioning between WiFi and cellular, +// as the old socket becomes stale and can no longer route packets. +// Call this method when detecting a network path change. 
+func (o *Olm) RebindSocket() error { + if o.sharedBind == nil { + return fmt.Errorf("shared bind is not initialized") + } + + // Get the current port so we can try to reuse it + currentPort := o.sharedBind.GetPort() + + logger.Info("Rebinding UDP socket (current port: %d)", currentPort) + + // Create a new UDP socket + var newConn *net.UDPConn + var newPort uint16 + var err error + + // First try to bind to the same port + localAddr := &net.UDPAddr{ + Port: int(currentPort), + IP: net.IPv4zero, + } + + newConn, err = net.ListenUDP("udp4", localAddr) + if err != nil { + // If we can't reuse the port, find a new one + logger.Warn("Could not rebind to port %d, finding new port: %v", currentPort, err) + newPort, err = util.FindAvailableUDPPort(49152, 65535) + if err != nil { + return fmt.Errorf("failed to find available UDP port: %w", err) + } + + localAddr = &net.UDPAddr{ + Port: int(newPort), + IP: net.IPv4zero, + } + + // Use udp4 explicitly to avoid IPv6 dual-stack issues + newConn, err = net.ListenUDP("udp4", localAddr) + if err != nil { + return fmt.Errorf("failed to create new UDP socket: %w", err) + } + } else { + newPort = currentPort + } + + // Rebind the shared bind with the new connection + if err := o.sharedBind.Rebind(newConn); err != nil { + newConn.Close() + return fmt.Errorf("failed to rebind shared bind: %w", err) + } + + logger.Info("Successfully rebound UDP socket on port %d", newPort) + + // Trigger a hole punch to re-establish NAT mappings with the new socket + if o.holePunchManager != nil { + o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetServerHolepunchInterval() + logger.Info("Triggered hole punch after socket rebind") + } + + return nil +} + func (o *Olm) AddDevice(fd uint32) error { if o.middleDev == nil { return fmt.Errorf("middle device is not initialized") From 17dc1b0be19e4a0b0efaff87d7db296056e43d18 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 17 Jan 2026 17:32:01 -0800 Subject: [PATCH 41/52] Dont start the ping until 
we are connected Former-commit-id: 43c8a14fda9d8f09cb8e8b31ff973a5f124979d1 --- olm/connect.go | 3 +++ olm/olm.go | 2 ++ websocket/client.go | 26 ++++++++++++++++++++++++-- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/olm/connect.go b/olm/connect.go index a610ea4..7f3785e 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -198,6 +198,9 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { o.connected = true + // Start ping monitor now that we are registered and connected + o.websocket.StartPingMonitor() + // Invoke onConnected callback if configured if o.olmConfig.OnConnected != nil { go o.olmConfig.OnConnected() diff --git a/olm/olm.go b/olm/olm.go index f6e1980..b2df734 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -362,6 +362,8 @@ func (o *Olm) StartTunnel(config TunnelConfig) { if o.connected { logger.Debug("Already connected, skipping registration") + // Restart ping monitor on reconnect since the old one would have exited + o.websocket.StartPingMonitor() return nil } diff --git a/websocket/client.go b/websocket/client.go index d0ac73b..844bde3 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -100,6 +100,8 @@ type Client struct { processingMux sync.RWMutex // Protects processingMessage processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete getPingData func() map[string]any // Callback to get additional ping data + pingStarted bool // Flag to track if ping monitor has been started + pingStartedMux sync.Mutex // Protects pingStarted } type ClientOption func(*Client) @@ -575,8 +577,14 @@ func (c *Client) establishConnection() error { c.conn = conn c.setConnected(true) - // Start the ping monitor - go c.pingMonitor() + // Reset ping started flag on new connection + c.pingStartedMux.Lock() + c.pingStarted = false + c.pingStartedMux.Unlock() + + // Note: ping monitor is NOT started here - it will be started when + // StartPingMonitor() is called after registration completes + // Start the read pump with 
disconnect detection go c.readPumpWithDisconnectDetection() @@ -715,6 +723,20 @@ func (c *Client) pingMonitor() { } } +// StartPingMonitor starts the ping monitor goroutine. +// This should be called after the client is registered and connected. +// It is safe to call multiple times - only the first call will start the monitor. +func (c *Client) StartPingMonitor() { + c.pingStartedMux.Lock() + defer c.pingStartedMux.Unlock() + + if c.pingStarted { + return + } + c.pingStarted = true + go c.pingMonitor() +} + // GetConfigVersion returns the current config version func (c *Client) GetConfigVersion() int { c.configVersionMux.RLock() From 4e4d1a39f6b980c7785cb8ed190461299484348c Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 17 Jan 2026 17:35:00 -0800 Subject: [PATCH 42/52] Try to close the socket first Former-commit-id: ed4775bd263085442907fbc3ff97db2a79c9769f --- olm/olm.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index 26fc0e4..286db25 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -797,17 +797,19 @@ func (o *Olm) RebindSocket() error { return fmt.Errorf("shared bind is not initialized") } - // Get the current port so we can try to reuse it - currentPort := o.sharedBind.GetPort() + // Close the old socket first to release the port, then try to rebind to the same port + currentPort, err := o.sharedBind.CloseSocket() + if err != nil { + return fmt.Errorf("failed to close old socket: %w", err) + } - logger.Info("Rebinding UDP socket (current port: %d)", currentPort) + logger.Info("Rebinding UDP socket (released port: %d)", currentPort) // Create a new UDP socket var newConn *net.UDPConn var newPort uint16 - var err error - // First try to bind to the same port + // First try to bind to the same port (now available since we closed the old socket) localAddr := &net.UDPAddr{ Port: int(currentPort), IP: net.IPv4zero, From 8b9ee6f26ad1181882176c6c9afde7fab3d83894 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 18 Jan 
2026 11:46:03 -0800 Subject: [PATCH 43/52] Move power mode to the api from signal Former-commit-id: 5d8ea92ef0518bf5c4b59642e6d05e0fdcdf3fd0 --- api/api.go | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++++- olm/olm.go | 37 +++++-------------------------------- 2 files changed, 57 insertions(+), 33 deletions(-) diff --git a/api/api.go b/api/api.go index b11cc70..efd3346 100644 --- a/api/api.go +++ b/api/api.go @@ -33,7 +33,12 @@ type ConnectionRequest struct { // SwitchOrgRequest defines the structure for switching organizations type SwitchOrgRequest struct { - OrgID string `json:"orgId"` + OrgID string `json:"org_id"` +} + +// PowerModeRequest represents a request to change power mode +type PowerModeRequest struct { + Mode string `json:"mode"` // "normal" or "low" } // PeerStatus represents the status of a peer connection @@ -86,6 +91,7 @@ type API struct { onDisconnect func() error onExit func() error onRebind func() error + onPowerMode func(PowerModeRequest) error statusMu sync.RWMutex peerStatuses map[int]*PeerStatus @@ -136,12 +142,15 @@ func (s *API) SetHandlers( onDisconnect func() error, onExit func() error, onRebind func() error, + onPowerMode func(PowerModeRequest) error, ) { s.onConnect = onConnect s.onSwitchOrg = onSwitchOrg + s.onMetadataChange = onMetadataChange s.onDisconnect = onDisconnect s.onExit = onExit s.onRebind = onRebind + s.onPowerMode = onPowerMode } // Start starts the HTTP server @@ -159,6 +168,7 @@ func (s *API) Start() error { mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/rebind", s.handleRebind) + mux.HandleFunc("/power-mode", s.handlePowerMode) s.server = &http.Server{ Handler: mux, @@ -625,3 +635,44 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) { "status": "socket rebound successfully", }) } + +// handlePowerMode handles the /power-mode endpoint +// This allows changing the power mode between "normal" and "low" +func (s *API) handlePowerMode(w 
http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req PowerModeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // Validate power mode + if req.Mode != "normal" && req.Mode != "low" { + http.Error(w, "Invalid power mode: must be 'normal' or 'low'", http.StatusBadRequest) + return + } + + logger.Info("Received power mode change request via API: mode=%s", req.Mode) + + // Call the power mode handler if set + if s.onPowerMode != nil { + if err := s.onPowerMode(req); err != nil { + http.Error(w, fmt.Sprintf("Power mode change failed: %v", err), http.StatusInternalServerError) + return + } + } else { + http.Error(w, "Power mode handler not configured", http.StatusNotImplemented) + return + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": fmt.Sprintf("power mode changed to %s successfully", req.Mode), + }) +} diff --git a/olm/olm.go b/olm/olm.go index 6c975d3..691d716 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -7,9 +7,7 @@ import ( "net/http" _ "net/http/pprof" "os" - "os/signal" "sync" - "syscall" "time" "github.com/fosrl/newt/bind" @@ -278,6 +276,11 @@ func (o *Olm) registerAPICallbacks() { logger.Info("Processing rebind request via API") return o.RebindSocket() }, + // onPowerMode + func(req api.PowerModeRequest) error { + logger.Info("Processing power mode change request via API: mode=%s", req.Mode) + return o.SetPowerMode(req.Mode) + }, ) } @@ -470,36 +473,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } defer func() { _ = o.websocket.Close() }() - // Setup SIGHUP signal handler for testing (toggles power state) - // THIS SHOULD ONLY BE USED AND ON IN A DEV MODE - sigChan := make(chan os.Signal, 
1) - signal.Notify(sigChan, syscall.SIGHUP) - - go func() { - powerMode := "normal" - for { - select { - case <-sigChan: - - logger.Info("SIGHUP received, toggling power mode") - if powerMode == "normal" { - powerMode = "low" - if err := o.SetPowerMode("low"); err != nil { - logger.Error("Failed to set low power mode: %v", err) - } - } else { - powerMode = "normal" - if err := o.SetPowerMode("normal"); err != nil { - logger.Error("Failed to set normal power mode: %v", err) - } - } - - case <-tunnelCtx.Done(): - return - } - } - }() - // Wait for context cancellation <-tunnelCtx.Done() logger.Info("Tunnel process context cancelled, cleaning up") From a8e0844758df7fb84a9243d0bfe5f9f9f754c5cc Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 18 Jan 2026 11:55:09 -0800 Subject: [PATCH 44/52] Send disconnecting message when stopping Former-commit-id: 1fb6e2a00d70ea73554dffc7b3e4caa19daa3b8b --- olm/olm.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/olm/olm.go b/olm/olm.go index 691d716..7476561 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -321,7 +321,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { userToken, config.OrgID, config.Endpoint, - 30 * time.Second, // 30 seconds + 30*time.Second, // 30 seconds config.PingTimeoutDuration, websocket.WithPingDataProvider(func() map[string]any { o.metaMu.Lock() @@ -479,6 +479,9 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } func (o *Olm) Close() { + // send a disconnect message to the cloud to show disconnected + o.websocket.SendMessage("olm/disconnecting", map[string]any{}) + // Restore original DNS configuration // we do this first to avoid any DNS issues if something else gets stuck if err := dnsOverride.RestoreDNSOverride(); err != nil { From 25cb50901ede9df6ffecc5764d4290f0cb2f7b80 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 18 Jan 2026 12:18:48 -0800 Subject: [PATCH 45/52] Quiet up logs again Former-commit-id: 112283191c7122d0608859bc2bfaf28f82bcb6cb --- peers/monitor/monitor.go | 2 +- 
peers/monitor/wgtester.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 387b82f..28d92ef 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -580,7 +580,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool { anyStatusChanged := false for siteID, endpoint := range endpoints { - logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint) + // logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint) result := pm.holepunchTester.TestEndpoint(endpoint, timeout) pm.mutex.Lock() diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go index f06759a..e9f6f63 100644 --- a/peers/monitor/wgtester.go +++ b/peers/monitor/wgtester.go @@ -168,7 +168,7 @@ func (c *Client) ensureConnection() error { // TestPeerConnection checks if the connection to the server is working // Returns true if connected, false otherwise func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) { - logger.Debug("wgtester: testing connection to peer %s", c.serverAddr) + // logger.Debug("wgtester: testing connection to peer %s", c.serverAddr) if err := c.ensureConnection(); err != nil { logger.Warn("Failed to ensure connection: %v", err) return false, 0 From a81c683c66a2dbeb1016e91de2d8d8839bccf5bb Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 18 Jan 2026 14:49:42 -0800 Subject: [PATCH 46/52] Reorder websocket disconnect message Former-commit-id: 592a0d60c654e1e24d5a42f0578ca31ada002cab --- olm/olm.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index 7476561..12f804a 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -480,7 +480,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) { func (o *Olm) Close() { // send a disconnect message to the cloud to show disconnected - o.websocket.SendMessage("olm/disconnecting", map[string]any{}) + if o.websocket != nil { + 
o.websocket.SendMessage("olm/disconnecting", map[string]any{}) + // Close the websocket connection after sending disconnect + _ = o.websocket.Close() + o.websocket = nil + } // Restore original DNS configuration // we do this first to avoid any DNS issues if something else gets stuck @@ -567,12 +572,7 @@ func (o *Olm) StopTunnel() error { time.Sleep(200 * time.Millisecond) } - // Close the websocket connection - if o.websocket != nil { - _ = o.websocket.Close() - o.websocket = nil - } - + // Close() will handle sending disconnect message and closing websocket o.Close() // Reset the connected state From 6d10650e70f534574e922054fec8e549463152c0 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 18 Jan 2026 15:14:11 -0800 Subject: [PATCH 47/52] Send an initial ping so we get online faster in the dashboard Former-commit-id: 41e4eb24a2d7707bb5ec7af0e6b8ef6f1a46352b --- websocket/client.go | 110 ++++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 50 deletions(-) diff --git a/websocket/client.go b/websocket/client.go index 844bde3..a3e39a4 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -660,6 +660,59 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) { return loadClientCertificate(c.tlsConfig.PKCS12File) } +// sendPing sends a single ping message +func (c *Client) sendPing() { + if c.isDisconnected || c.conn == nil { + return + } + // Skip ping if a message is currently being processed + c.processingMux.RLock() + isProcessing := c.processingMessage + c.processingMux.RUnlock() + if isProcessing { + logger.Debug("websocket: Skipping ping, message is being processed") + return + } + // Send application-level ping with config version + c.configVersionMux.RLock() + configVersion := c.configVersion + c.configVersionMux.RUnlock() + + pingData := map[string]any{ + "timestamp": time.Now().Unix(), + "userToken": c.config.UserToken, + } + if c.getPingData != nil { + for k, v := range c.getPingData() { + pingData[k] = v + } + } + + pingMsg 
:= WSMessage{ + Type: "olm/ping", + Data: pingData, + ConfigVersion: configVersion, + } + + logger.Debug("websocket: Sending ping: %+v", pingMsg) + + c.writeMux.Lock() + err := c.conn.WriteJSON(pingMsg) + c.writeMux.Unlock() + if err != nil { + // Check if we're shutting down before logging error and reconnecting + select { + case <-c.done: + // Expected during shutdown + return + default: + logger.Error("websocket: Ping failed: %v", err) + c.reconnect() + return + } + } +} + // pingMonitor sends pings at a short interval and triggers reconnect on failure func (c *Client) pingMonitor() { ticker := time.NewTicker(c.pingInterval) @@ -670,55 +723,7 @@ func (c *Client) pingMonitor() { case <-c.done: return case <-ticker.C: - if c.isDisconnected || c.conn == nil { - return - } - // Skip ping if a message is currently being processed - c.processingMux.RLock() - isProcessing := c.processingMessage - c.processingMux.RUnlock() - if isProcessing { - logger.Debug("websocket: Skipping ping, message is being processed") - continue - } - // Send application-level ping with config version - c.configVersionMux.RLock() - configVersion := c.configVersion - c.configVersionMux.RUnlock() - - pingData := map[string]any{ - "timestamp": time.Now().Unix(), - "userToken": c.config.UserToken, - } - if c.getPingData != nil { - for k, v := range c.getPingData() { - pingData[k] = v - } - } - - pingMsg := WSMessage{ - Type: "olm/ping", - Data: pingData, - ConfigVersion: configVersion, - } - - logger.Debug("websocket: Sending ping: %+v", pingMsg) - - c.writeMux.Lock() - err := c.conn.WriteJSON(pingMsg) - c.writeMux.Unlock() - if err != nil { - // Check if we're shutting down before logging error and reconnecting - select { - case <-c.done: - // Expected during shutdown - return - default: - logger.Error("websocket: Ping failed: %v", err) - c.reconnect() - return - } - } + c.sendPing() } } } @@ -734,7 +739,12 @@ func (c *Client) StartPingMonitor() { return } c.pingStarted = true - go 
c.pingMonitor() + + // Send an initial ping immediately + go func() { + c.sendPing() + c.pingMonitor() + }() } // GetConfigVersion returns the current config version From f2e81c024aa6dd115516a448e5636f2e7a8f2d6e Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 19 Jan 2026 15:05:29 -0800 Subject: [PATCH 48/52] Set fingerprint earlier Former-commit-id: ef36f7ca821b4b02c2aa95492a99ac6b197ef9ed --- olm/olm.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index 12f804a..ec0b6dc 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -289,15 +289,28 @@ func (o *Olm) StartTunnel(config TunnelConfig) { logger.Info("Tunnel already running") return } + + // debug print out the whole config + logger.Debug("Starting tunnel with config: %+v", config) o.tunnelRunning = true // Also set it here in case it is called externally o.tunnelConfig = config // Reset terminated status when tunnel starts o.apiServer.SetTerminated(false) + + fingerprint := config.InitialFingerprint + if fingerprint == nil { + fingerprint = make(map[string]any) + } - // debug print out the whole config - logger.Debug("Starting tunnel with config: %+v", config) + postures := config.InitialPostures + if postures == nil { + postures = make(map[string]any) + } + + o.SetFingerprint(fingerprint) + o.SetPostures(postures) // Create a cancellable context for this tunnel process tunnelCtx, cancel := context.WithCancel(o.olmCtx) @@ -453,19 +466,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } }) - fingerprint := config.InitialFingerprint - if fingerprint == nil { - fingerprint = make(map[string]any) - } - - postures := config.InitialPostures - if postures == nil { - postures = make(map[string]any) - } - - o.SetFingerprint(fingerprint) - o.SetPostures(postures) - // Connect to the WebSocket server if err := o.websocket.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) From 79e8a4a8bb8a3b06cb5fe8654a95d6abe62905c3 Mon Sep 17 
00:00:00 2001 From: Owen Date: Mon, 19 Jan 2026 15:57:20 -0800 Subject: [PATCH 49/52] Dont start holepunching if we rebind while in low power mode Former-commit-id: 4a5ebd41f343ccf9a668bdc8ccff0bbc2a3905f0 --- olm/olm.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index ec0b6dc..fb528f9 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -827,11 +827,18 @@ func (o *Olm) RebindSocket() error { logger.Info("Successfully rebound UDP socket on port %d", newPort) - // Trigger a hole punch to re-establish NAT mappings with the new socket - if o.holePunchManager != nil { + // Check if we're in low power mode before triggering hole punch + o.powerModeMu.Lock() + isLowPower := o.currentPowerMode == "low" + o.powerModeMu.Unlock() + + // Only trigger hole punch if not in low power mode + if !isLowPower && o.holePunchManager != nil { o.holePunchManager.TriggerHolePunch() o.holePunchManager.ResetServerHolepunchInterval() logger.Info("Triggered hole punch after socket rebind") + } else if isLowPower { + logger.Info("Skipping hole punch trigger due to low power mode") } return nil From c4e297cc9628f3b62a66a306dc252cb0b45fb4e9 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 20 Jan 2026 11:30:06 -0800 Subject: [PATCH 50/52] Handle properly stopping and starting the ping Former-commit-id: 34c7717767d42b880ac8697d03fd898a5f4b042d --- olm/olm.go | 12 ++++++++++-- websocket/client.go | 29 ++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index fb528f9..cd8a844 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -383,9 +383,9 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetConnectionStatus(true) if o.connected { - logger.Debug("Already connected, skipping registration") - // Restart ping monitor on reconnect since the old one would have exited o.websocket.StartPingMonitor() + + logger.Debug("Already connected, skipping registration") return nil } @@ -686,6 
+686,14 @@ func (o *Olm) SetPowerMode(mode string) error { logger.Info("Switching to low power mode") + // Mark as disconnected so we re-register on reconnect + o.connected = false + + // Update API server connection status + if o.apiServer != nil { + o.apiServer.SetConnectionStatus(false) + } + if o.websocket != nil { logger.Info("Disconnecting websocket for low power mode") if err := o.websocket.Disconnect(); err != nil { diff --git a/websocket/client.go b/websocket/client.go index a3e39a4..c4e67b0 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -102,6 +102,7 @@ type Client struct { getPingData func() map[string]any // Callback to get additional ping data pingStarted bool // Flag to track if ping monitor has been started pingStartedMux sync.Mutex // Protects pingStarted + pingDone chan struct{} // Channel to stop the ping monitor independently } type ClientOption func(*Client) @@ -176,6 +177,7 @@ func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time. 
pingInterval: pingInterval, pingTimeout: pingTimeout, clientType: "olm", + pingDone: make(chan struct{}), } // Apply options before loading config @@ -235,6 +237,9 @@ func (c *Client) Disconnect() error { c.isDisconnected = true c.setConnected(false) + // Stop the ping monitor + c.stopPingMonitor() + // Wait for any message currently being processed to complete c.processingWg.Wait() @@ -577,11 +582,6 @@ func (c *Client) establishConnection() error { c.conn = conn c.setConnected(true) - // Reset ping started flag on new connection - c.pingStartedMux.Lock() - c.pingStarted = false - c.pingStartedMux.Unlock() - // Note: ping monitor is NOT started here - it will be started when // StartPingMonitor() is called after registration completes @@ -722,6 +722,8 @@ func (c *Client) pingMonitor() { select { case <-c.done: return + case <-c.pingDone: + return case <-ticker.C: c.sendPing() } @@ -740,6 +742,9 @@ func (c *Client) StartPingMonitor() { } c.pingStarted = true + // Create a new pingDone channel for this ping monitor instance + c.pingDone = make(chan struct{}) + // Send an initial ping immediately go func() { c.sendPing() @@ -747,6 +752,20 @@ func (c *Client) StartPingMonitor() { }() } +// stopPingMonitor stops the ping monitor goroutine if it's running. 
+func (c *Client) stopPingMonitor() { + c.pingStartedMux.Lock() + defer c.pingStartedMux.Unlock() + + if !c.pingStarted { + return + } + + // Close the pingDone channel to stop the monitor + close(c.pingDone) + c.pingStarted = false +} + // GetConfigVersion returns the current config version func (c *Client) GetConfigVersion() int { c.configVersionMux.RLock() From 4ef6089053216c6b337e399bc36c6a75ddb376d5 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 23 Jan 2026 10:19:38 -0800 Subject: [PATCH 51/52] Comment out local newt Former-commit-id: c4ef1e724e404c5d9093f757819e0a706d39a172 --- go.mod | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 0d6bbcb..aa631ef 100644 --- a/go.mod +++ b/go.mod @@ -31,4 +31,5 @@ require ( golang.zx2c4.com/wireguard/windows v0.5.3 // indirect ) -replace github.com/fosrl/newt => ../newt +# To be used ONLY for local development +# replace github.com/fosrl/newt => ../newt From 51eee9dcf539de0dc662aeaa070cdf311af025e4 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 23 Jan 2026 10:23:42 -0800 Subject: [PATCH 52/52] Bump newt Former-commit-id: f4885e9c4db4bd9e081a82caebde58588acdbb16 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 6ff6989..09a5bc4 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.25 require ( github.com/Microsoft/go-winio v0.6.2 - github.com/fosrl/newt v1.8.1 + github.com/fosrl/newt v1.9.0 github.com/godbus/dbus/v5 v5.2.2 github.com/gorilla/websocket v1.5.3 github.com/miekg/dns v1.1.70 diff --git a/go.sum b/go.sum index c0a2bf7..be51e01 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/fosrl/newt v1.8.1 h1:oP3xBEISoO/TENsHccqqs6LXpoOWCt6aiP75CfIWpvk= -github.com/fosrl/newt v1.8.1/go.mod 
h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= +github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE= +github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=