From 03051a37fe378159987fcfecb43c0fd400f3c71e Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 22 Dec 2025 16:19:45 -0500 Subject: [PATCH 01/27] Update mod Former-commit-id: ca5105b6b2e6a25167e4fb6d269b065dbbf8e5cd --- go.mod | 47 +++++++++++++++++++++++++++- go.sum | 96 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 4844592..11cb67a 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.25 require ( github.com/Microsoft/go-winio v0.6.2 - github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 + github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06 github.com/godbus/dbus/v5 v5.2.0 github.com/gorilla/websocket v1.5.3 github.com/miekg/dns v1.1.68 @@ -16,17 +16,62 @@ require ( ) require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/containerd/errdefs v0.3.0 // indirect + github.com/containerd/errdefs/pkg v0.3.0 // indirect + github.com/distribution/reference v0.6.0 // indirect + github.com/docker/docker v28.5.2+incompatible // indirect + github.com/docker/go-connections v0.6.0 // indirect + github.com/docker/go-units v0.4.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.7.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect + github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.1.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + 
github.com/prometheus/client_golang v1.23.2 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/otlptranslator v0.0.2 // indirect + github.com/prometheus/procfs v0.17.0 // indirect github.com/vishvananda/netlink v1.3.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect + go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 // indirect + go.opentelemetry.io/otel v1.38.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 // indirect + go.opentelemetry.io/otel/exporters/prometheus v0.60.0 // indirect + go.opentelemetry.io/otel/metric v1.38.0 // indirect + go.opentelemetry.io/otel/sdk v1.38.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.38.0 // indirect + go.opentelemetry.io/otel/trace v1.38.0 // indirect + go.opentelemetry.io/proto/otlp v1.7.1 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/mod v0.30.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.18.0 // indirect + golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect + google.golang.org/grpc v1.76.0 // indirect + google.golang.org/protobuf v1.36.8 // indirect + gopkg.in/yaml.v3 
v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9bf88e2..66084df 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,103 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/containerd/errdefs v0.3.0 h1:FSZgGOeK4yuT/+DnF07/Olde/q4KBoMsaamhXxIMDp4= +github.com/containerd/errdefs v0.3.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= +github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= +github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= +github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= +github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM= +github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= +github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= +github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= +github.com/docker/go-units v0.4.0/go.mod 
h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 h1:VpuI42l4enih//6IFFQDln/B7WukfMePxIRIpXsNe/0= github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= +github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06 h1:xWuCn+gzX0W7bHs/cV/ykNBliisNzNomPR76E4M0dtI= +github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc 
h1:GN2Lv3MGO7AS6PrRoT6yV5+wkrOpcszoIsO4+4ds248= +github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= +github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= +github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= +github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common 
v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/otlptranslator v0.0.2 h1:+1CdeLVrRQ6Psmhnobldo0kTp96Rj80DRXRd5OSnMEQ= +github.com/prometheus/otlptranslator v0.0.2/go.mod h1:P8AwMgdD7XEr6QRUJ2QWLpiAZTgTE2UYgjlu3svompI= +github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= +github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= +go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 h1:PeBoRj6af6xMI7qCupwFvTbbnd49V7n5YpG6pg8iDYQ= +go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0/go.mod h1:ingqBCtMCe8I4vpz/UVzCW6sxoqgZB37nao91mLQ3Bw= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 h1:vl9obrcoWVKp/lwl8tRE33853I8Xru9HFbw/skNeLs8= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0/go.mod h1:GAXRxmLJcVM3u22IjTg74zWBrRCKq8BnOqUVLodpcpw= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24= 
+go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk= +go.opentelemetry.io/otel/exporters/prometheus v0.60.0 h1:cGtQxGvZbnrWdC2GyjZi0PDKVSLWP/Jocix3QWfXtbo= +go.opentelemetry.io/otel/exporters/prometheus v0.60.0/go.mod h1:hkd1EekxNo69PTV4OWFGZcKQiIqg0RfuWExcPKFvepk= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4= +go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= @@ -30,6 +112,8 @@ golang.org/x/sys v0.2.0/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= @@ -42,6 +126,18 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdI golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= +google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 h1:vlzZttNJGVqTsRFU9AmdnrcO1Znh8Ew9kCD//yjigk0= +google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY= +google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc= +google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= +google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf 
v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU= From effc1a31ace47aa74f68cb40f232a96799aa151e Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 22 Dec 2025 17:24:51 -0500 Subject: [PATCH 02/27] Update readme Former-commit-id: 44282226b4124dbe3d16b308fc44fc3079231229 --- README.md | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/README.md b/README.md index 97d0f66..0d7847e 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,7 @@ When Olm receives WireGuard control messages, it will use the information encode ## Hole Punching -In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to newt. If you want to disable hole punching, use the `--disable-holepunch` flag. Hole punching attempts to orchestrate a NAT hole punch between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil. - -Right now, basic NAT hole punching is supported. We plan to add: - -- [ ] Birthday paradox -- [ ] UPnP -- [ ] LAN detection +In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to Newt. Hole punching attempts to orchestrate a NAT traversal between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil. 
## Build From 4e3e8242761bfe0c4b15f5ff68ee629b21b1b28c Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 22 Dec 2025 21:32:59 -0500 Subject: [PATCH 03/27] Fix latest Former-commit-id: 6fcd8ac6cb03ef791bbf2a979c93595e38cd0054 --- .github/workflows/cicd.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 989e68c..193c1ba 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -273,7 +273,7 @@ jobs: tags: | type=semver,pattern={{version}},value=${{ env.TAG }} type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }} - type=raw,value=latest,enable=${{ env.PUBLISH_LATEST == 'true' && env.IS_RC != 'true' }} + type=raw,value=latest,enable=${{ env.IS_RC != 'true' }} flavor: | latest=false labels: | From 385c64c364d5e67bcf1a59afaec5d4ef7f58c494 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 23 Dec 2025 17:54:04 -0500 Subject: [PATCH 04/27] Dont run on v tags Former-commit-id: 69a00b6231c0947997d41b66b7df5ac17b350c72 --- .github/workflows/cicd.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 193c1ba..c44a2d7 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -11,7 +11,9 @@ permissions: on: push: tags: - - "*" + - "[0-9]+.[0-9]+.[0-9]+" + - "[0-9]+.[0-9]+.[0-9]+-rc.[0-9]+" + workflow_dispatch: inputs: version: From 88cc57bcefad2b5e617fe51cf7e8305449b21db7 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 23 Dec 2025 18:00:15 -0500 Subject: [PATCH 05/27] Update mod Former-commit-id: 1b474ebc1cb12156ddb1f0c57d553b129895a0a3 --- go.mod | 47 +-------------------------- go.sum | 100 ++------------------------------------------------------- 2 files changed, 3 insertions(+), 144 deletions(-) diff --git a/go.mod b/go.mod index 11cb67a..4f42df6 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.25 require ( 
github.com/Microsoft/go-winio v0.6.2 - github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06 + github.com/fosrl/newt v1.8.0 github.com/godbus/dbus/v5 v5.2.0 github.com/gorilla/websocket v1.5.3 github.com/miekg/dns v1.1.68 @@ -16,62 +16,17 @@ require ( ) require ( - github.com/beorn7/perks v1.0.1 // indirect - github.com/cenkalti/backoff/v5 v5.0.3 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/containerd/errdefs v0.3.0 // indirect - github.com/containerd/errdefs/pkg v0.3.0 // indirect - github.com/distribution/reference v0.6.0 // indirect - github.com/docker/docker v28.5.2+incompatible // indirect - github.com/docker/go-connections v0.6.0 // indirect - github.com/docker/go-units v0.4.0 // indirect - github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-logr/logr v1.4.3 // indirect - github.com/go-logr/stdr v1.2.2 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.7.0 // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc // indirect - github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect - github.com/moby/docker-image-spec v1.3.1 // indirect - github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/opencontainers/go-digest v1.0.0 // indirect - github.com/opencontainers/image-spec v1.1.0 // indirect - github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/client_golang v1.23.2 // indirect - github.com/prometheus/client_model v0.6.2 // indirect - github.com/prometheus/common v0.66.1 // indirect - github.com/prometheus/otlptranslator v0.0.2 // indirect - github.com/prometheus/procfs v0.17.0 // indirect github.com/vishvananda/netlink v1.3.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect - go.opentelemetry.io/contrib/instrumentation/runtime 
v0.63.0 // indirect - go.opentelemetry.io/otel v1.38.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 // indirect - go.opentelemetry.io/otel/exporters/prometheus v0.60.0 // indirect - go.opentelemetry.io/otel/metric v1.38.0 // indirect - go.opentelemetry.io/otel/sdk v1.38.0 // indirect - go.opentelemetry.io/otel/sdk/metric v1.38.0 // indirect - go.opentelemetry.io/otel/trace v1.38.0 // indirect - go.opentelemetry.io/proto/otlp v1.7.1 // indirect - go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/mod v0.30.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.18.0 // indirect - golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect - google.golang.org/grpc v1.76.0 // indirect - google.golang.org/protobuf v1.36.8 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 66084df..a543b5a 100644 --- a/go.sum +++ b/go.sum @@ -1,103 +1,21 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/cenkalti/backoff v2.2.1+incompatible 
h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= -github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= -github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/containerd/errdefs v0.3.0 h1:FSZgGOeK4yuT/+DnF07/Olde/q4KBoMsaamhXxIMDp4= -github.com/containerd/errdefs v0.3.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= -github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= -github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= -github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= -github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM= -github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= -github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= -github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= -github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= -github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= -github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 h1:VpuI42l4enih//6IFFQDln/B7WukfMePxIRIpXsNe/0= -github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= -github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06 
h1:xWuCn+gzX0W7bHs/cV/ykNBliisNzNomPR76E4M0dtI= -github.com/fosrl/newt v0.0.0-20251222211541-80ae03997a06/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= -github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= -github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/fosrl/newt v1.8.0 h1:wIRCO2shhCpkFzsbNbb4g2LC7mPzIpp2ialNveBMJy4= +github.com/fosrl/newt v1.8.0/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc h1:GN2Lv3MGO7AS6PrRoT6yV5+wkrOpcszoIsO4+4ds248= -github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod 
h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= -github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= -github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= -github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= -github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= -github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= -github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= -github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= -github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= -github.com/prometheus/otlptranslator v0.0.2 h1:+1CdeLVrRQ6Psmhnobldo0kTp96Rj80DRXRd5OSnMEQ= -github.com/prometheus/otlptranslator v0.0.2/go.mod h1:P8AwMgdD7XEr6QRUJ2QWLpiAZTgTE2UYgjlu3svompI= -github.com/prometheus/procfs v0.17.0 
h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= -github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= -go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0 h1:PeBoRj6af6xMI7qCupwFvTbbnd49V7n5YpG6pg8iDYQ= -go.opentelemetry.io/contrib/instrumentation/runtime v0.63.0/go.mod h1:ingqBCtMCe8I4vpz/UVzCW6sxoqgZB37nao91mLQ3Bw= -go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= -go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= -go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 h1:vl9obrcoWVKp/lwl8tRE33853I8Xru9HFbw/skNeLs8= -go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0/go.mod h1:GAXRxmLJcVM3u22IjTg74zWBrRCKq8BnOqUVLodpcpw= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod 
h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk= -go.opentelemetry.io/otel/exporters/prometheus v0.60.0 h1:cGtQxGvZbnrWdC2GyjZi0PDKVSLWP/Jocix3QWfXtbo= -go.opentelemetry.io/otel/exporters/prometheus v0.60.0/go.mod h1:hkd1EekxNo69PTV4OWFGZcKQiIqg0RfuWExcPKFvepk= -go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= -go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= -go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= -go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= -go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= -go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= -go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= -go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= -go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4= -go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE= -go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= -go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= @@ -112,8 +30,6 @@ golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.31.0 
h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= @@ -126,18 +42,6 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdI golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= -google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 h1:vlzZttNJGVqTsRFU9AmdnrcO1Znh8Ew9kCD//yjigk0= -google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY= -google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc= -google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= -google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= -google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= -google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU= From 148f5fde23ee2f3aff8cfbb452f99451bdf16305 Mon Sep 17 00:00:00 2001 From: Varun Narravula Date: Tue, 23 Dec 2025 15:33:04 -0800 Subject: [PATCH 06/27] fix(ci): add back missing docker build local image rule Former-commit-id: 6d2afb4c72f7956ccb9509e8aed018636070d1d7 --- .github/workflows/test.yml | 2 +- Makefile | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2349f3a..6fe7514 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,4 +22,4 @@ jobs: run: make go-build-release - name: Build Docker image - run: make docker-build-release + run: make docker-build diff --git a/Makefile b/Makefile index 8eed5c2..55ebf81 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,9 @@ all: local local: CGO_ENABLED=0 go build -o ./bin/olm +docker-build: + docker build -t fosrl/olm:latest . + docker-build-release: @if [ -z "$(tag)" ]; then \ echo "Error: tag is required. 
Usage: make docker-build-release tag="; \ From f8dc1342103a74fda006c05104eb03b5373acf9e Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Mon, 29 Dec 2025 17:28:12 -0500 Subject: [PATCH 07/27] add content-length header to status payload Former-commit-id: 8152d4133f1ae85b2632c48983aeb3ea68f0fd2a --- api/api.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/api/api.go b/api/api.go index 787f958..a7e2f24 100644 --- a/api/api.go +++ b/api/api.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "strconv" "sync" "time" @@ -358,7 +359,6 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { } s.statusMu.RLock() - defer s.statusMu.RUnlock() resp := StatusResponse{ Connected: s.isConnected, @@ -371,8 +371,18 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { NetworkSettings: network.GetSettings(), } + s.statusMu.RUnlock() + + data, err := json.Marshal(resp) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Length", strconv.Itoa(len(data))) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data) } // handleHealth handles the /health endpoint From 28910ce1880dc97a18be9ab5535cfb6e89db28a5 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 29 Dec 2025 17:49:58 -0500 Subject: [PATCH 08/27] Add stub Former-commit-id: ece4239aaa70318a0246c0b2e17b4e3e8d306e7d --- dns/override/dns_override_android.go | 18 ++++++++++++++++++ dns/override/dns_override_ios.go | 17 +++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 dns/override/dns_override_android.go create mode 100644 dns/override/dns_override_ios.go diff --git a/dns/override/dns_override_android.go b/dns/override/dns_override_android.go new file mode 100644 index 0000000..af1d946 --- /dev/null +++ b/dns/override/dns_override_android.go @@ -0,0 +1,18 @@ +//go:build android + +package olm + 
+import ( + "github.com/fosrl/olm/dns" +) + +// SetupDNSOverride is a no-op on Android +// Android handles DNS through the VpnService API at the Java/Kotlin layer +func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { + return nil +} + +// RestoreDNSOverride is a no-op on Android +func RestoreDNSOverride() error { + return nil +} \ No newline at end of file diff --git a/dns/override/dns_override_ios.go b/dns/override/dns_override_ios.go new file mode 100644 index 0000000..109d471 --- /dev/null +++ b/dns/override/dns_override_ios.go @@ -0,0 +1,17 @@ +//go:build ios + +package olm + +import ( + "github.com/fosrl/olm/dns" +) + +// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system +func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { + return nil +} + +// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system +func RestoreDNSOverride() error { + return nil +} \ No newline at end of file From c56696bab1ce19dc67563f4915521a5e82b60c89 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 30 Dec 2025 16:59:36 -0500 Subject: [PATCH 09/27] Use a different method on android Former-commit-id: adf4c21f7b280f50e5356325b202be2e554d9333 --- api/api.go | 12 ++++++++++++ olm/olm.go | 12 ++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/api/api.go b/api/api.go index a7e2f24..91d9f37 100644 --- a/api/api.go +++ b/api/api.go @@ -102,6 +102,14 @@ func NewAPISocket(socketPath string) *API { return s } +func NewAPIStub() *API { + s := &API{ + peerStatuses: make(map[int]*PeerStatus), + } + + return s +} + // SetHandlers sets the callback functions for handling API requests func (s *API) SetHandlers( onConnect func(ConnectionRequest) error, @@ -117,6 +125,10 @@ func (s *API) SetHandlers( // Start starts the HTTP server func (s *API) Start() error { + if s.socketPath == "" && s.addr == "" { + return fmt.Errorf("either socketPath or addr must be provided to start the API 
server") + } + mux := http.NewServeMux() mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) diff --git a/olm/olm.go b/olm/olm.go index f84ee4f..9cc1f51 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -111,6 +111,9 @@ func Init(ctx context.Context, config GlobalConfig) { apiServer = api.NewAPI(config.HTTPAddr) } else if config.SocketPath != "" { apiServer = api.NewAPISocket(config.SocketPath) + } else { + // this is so is not null but it cant be started without either the socket path or http addr + apiServer = api.NewAPIStub() } apiServer.SetVersion(config.Version) @@ -304,7 +307,12 @@ func StartTunnel(config TunnelConfig) { tdev, err = func() (tun.Device, error) { if config.FileDescriptorTun != 0 { - return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) + if runtime.GOOS == "android" { // otherwise we get a permission denied + theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(config.FileDescriptorTun)) + return theTun, err + } else { + return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) + } } var ifName = interfaceName if runtime.GOOS == "darwin" { // this is if we dont pass a fd @@ -811,7 +819,7 @@ func StartTunnel(config TunnelConfig) { Endpoint: handshakeData.ExitNode.Endpoint, RelayPort: relayPort, PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, + SiteIds: []int{siteId}, } added := holePunchManager.AddExitNode(exitNode) From cce87424906e6919355a7c1e7de0bd7f57afd53c Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 30 Dec 2025 21:38:07 -0500 Subject: [PATCH 10/27] Try to make the tun replacable Former-commit-id: 6be095888755c22f8cc6621688e75f7c040eaf57 --- device/middle_device.go | 640 +++++++++++++++++++++++++++++----------- device/tun_unix.go | 6 + dns/dns_proxy.go | 11 +- olm/olm.go | 53 +++- 4 files changed, 517 insertions(+), 193 deletions(-) diff --git a/device/middle_device.go b/device/middle_device.go index b031871..2a5d9b9 100644 --- a/device/middle_device.go +++ 
b/device/middle_device.go @@ -1,9 +1,12 @@ package device import ( + "io" "net/netip" "os" "sync" + "sync/atomic" + "time" "github.com/fosrl/newt/logger" "golang.zx2c4.com/wireguard/tun" @@ -18,14 +21,68 @@ type FilterRule struct { Handler PacketHandler } -// MiddleDevice wraps a TUN device with packet filtering capabilities -type MiddleDevice struct { +// closeAwareDevice wraps a tun.Device along with a flag +// indicating whether its Close method was called. +type closeAwareDevice struct { + isClosed atomic.Bool tun.Device - rules []FilterRule - mutex sync.RWMutex - readCh chan readResult - injectCh chan []byte - closed chan struct{} + closeEventCh chan struct{} + wg sync.WaitGroup + closeOnce sync.Once +} + +func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice { + return &closeAwareDevice{ + Device: tunDevice, + isClosed: atomic.Bool{}, + closeEventCh: make(chan struct{}), + } +} + +// redirectEvents redirects the Events() method of the underlying tun.Device +// to the given channel. +func (c *closeAwareDevice) redirectEvents(out chan tun.Event) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + select { + case ev, ok := <-c.Device.Events(): + if !ok { + return + } + + if ev == tun.EventDown { + continue + } + + select { + case out <- ev: + case <-c.closeEventCh: + return + } + case <-c.closeEventCh: + return + } + } + }() +} + +// Close calls the underlying Device's Close method +// after setting isClosed to true. +func (c *closeAwareDevice) Close() (err error) { + c.closeOnce.Do(func() { + c.isClosed.Store(true) + close(c.closeEventCh) + err = c.Device.Close() + c.wg.Wait() + }) + + return err +} + +func (c *closeAwareDevice) IsClosed() bool { + return c.isClosed.Load() } type readResult struct { @@ -36,58 +93,124 @@ type readResult struct { err error } +// MiddleDevice wraps a TUN device with packet filtering capabilities +// and supports swapping the underlying device. 
+type MiddleDevice struct { + devices []*closeAwareDevice + mu sync.Mutex + cond *sync.Cond + rules []FilterRule + rulesMutex sync.RWMutex + readCh chan readResult + injectCh chan []byte + closed atomic.Bool + events chan tun.Event +} + // NewMiddleDevice creates a new filtered TUN device wrapper func NewMiddleDevice(device tun.Device) *MiddleDevice { d := &MiddleDevice{ - Device: device, + devices: make([]*closeAwareDevice, 0), rules: make([]FilterRule, 0), - readCh: make(chan readResult), + readCh: make(chan readResult, 16), injectCh: make(chan []byte, 100), - closed: make(chan struct{}), + events: make(chan tun.Event, 16), } - go d.pump() + d.cond = sync.NewCond(&d.mu) + + if device != nil { + d.AddDevice(device) + } + return d } -func (d *MiddleDevice) pump() { +// AddDevice adds a new underlying TUN device, closing any previous one +func (d *MiddleDevice) AddDevice(device tun.Device) { + d.mu.Lock() + if d.closed.Load() { + d.mu.Unlock() + _ = device.Close() + return + } + + var toClose *closeAwareDevice + if len(d.devices) > 0 { + toClose = d.devices[len(d.devices)-1] + } + + cad := newCloseAwareDevice(device) + cad.redirectEvents(d.events) + + d.devices = []*closeAwareDevice{cad} + + // Start pump for the new device + go d.pump(cad) + + d.cond.Broadcast() + d.mu.Unlock() + + if toClose != nil { + logger.Debug("MiddleDevice: Closing previous device") + if err := toClose.Close(); err != nil { + logger.Debug("MiddleDevice: Error closing previous device: %v", err) + } + } +} + +func (d *MiddleDevice) pump(dev *closeAwareDevice) { const defaultOffset = 16 - batchSize := d.Device.BatchSize() - logger.Debug("MiddleDevice: pump started") + batchSize := dev.BatchSize() + logger.Debug("MiddleDevice: pump started for device") for { - // Check closed first with priority - select { - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel") + // Check if this device is closed + if dev.IsClosed() { + logger.Debug("MiddleDevice: pump exiting, 
device is closed") + return + } + + // Check if MiddleDevice itself is closed + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed") return - default: } // Allocate buffers for reading - // We allocate new buffers for each read to avoid race conditions - // since we pass them to the channel bufs := make([][]byte, batchSize) sizes := make([]int, batchSize) for i := range bufs { bufs[i] = make([]byte, 2048) // Standard MTU + headroom } - n, err := d.Device.Read(bufs, sizes, defaultOffset) + n, err := dev.Read(bufs, sizes, defaultOffset) - // Check closed again after read returns - select { - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)") + // Check if device was closed during read + if dev.IsClosed() { + logger.Debug("MiddleDevice: pump exiting, device closed during read") return - default: } - // Now try to send the result + // Check if MiddleDevice was closed during read + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read") + return + } + + // Try to send the result select { case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)") - return + default: + // Channel full, check if we should exit + if dev.IsClosed() || d.closed.Load() { + return + } + // Try again with blocking + select { + case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: + case <-dev.closeEventCh: + return + } } if err != nil { @@ -99,16 +222,21 @@ func (d *MiddleDevice) pump() { // InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN) func (d *MiddleDevice) InjectOutbound(packet []byte) { + if d.closed.Load() { + return + } select { case d.injectCh <- packet: - case <-d.closed: + default: + // Channel full, drop packet + logger.Debug("MiddleDevice: InjectOutbound 
dropping packet, channel full") } } // AddRule adds a packet filtering rule func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { - d.mutex.Lock() - defer d.mutex.Unlock() + d.rulesMutex.Lock() + defer d.rulesMutex.Unlock() d.rules = append(d.rules, FilterRule{ DestIP: destIP, Handler: handler, @@ -117,8 +245,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { // RemoveRule removes all rules for a given destination IP func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { - d.mutex.Lock() - defer d.mutex.Unlock() + d.rulesMutex.Lock() + defer d.rulesMutex.Unlock() newRules := make([]FilterRule, 0, len(d.rules)) for _, rule := range d.rules { if rule.DestIP != destIP { @@ -130,18 +258,113 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { // Close stops the device func (d *MiddleDevice) Close() error { - select { - case <-d.closed: - // Already closed - return nil - default: - logger.Debug("MiddleDevice: Closing, signaling closed channel") - close(d.closed) + if !d.closed.CompareAndSwap(false, true) { + return nil // already closed } - logger.Debug("MiddleDevice: Closing underlying TUN device") - err := d.Device.Close() - logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err) - return err + + d.mu.Lock() + devices := d.devices + d.devices = nil + d.cond.Broadcast() + d.mu.Unlock() + + var lastErr error + logger.Debug("MiddleDevice: Closing %d devices", len(devices)) + for _, device := range devices { + if err := device.Close(); err != nil { + logger.Debug("MiddleDevice: Error closing device: %v", err) + lastErr = err + } + } + + close(d.events) + return lastErr +} + +// Events returns the events channel +func (d *MiddleDevice) Events() <-chan tun.Event { + return d.events +} + +// File returns the underlying file descriptor +func (d *MiddleDevice) File() *os.File { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return nil + } + continue + } + + file := dev.File() + 
+ if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return file + } +} + +// MTU returns the MTU of the underlying device +func (d *MiddleDevice) MTU() (int, error) { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } + continue + } + + mtu, err := dev.MTU() + if err == nil { + return mtu, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return 0, err + } +} + +// Name returns the name of the underlying device +func (d *MiddleDevice) Name() (string, error) { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return "", io.EOF + } + continue + } + + name, err := dev.Name() + if err == nil { + return name, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return "", err + } +} + +// BatchSize returns the batch size +func (d *MiddleDevice) BatchSize() int { + dev := d.peekLast() + if dev == nil { + return 1 + } + return dev.BatchSize() } // extractDestIP extracts destination IP from packet (fast path) @@ -176,156 +399,231 @@ func extractDestIP(packet []byte) (netip.Addr, bool) { // Read intercepts packets going UP from the TUN device (towards WireGuard) func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { - // Check if already closed first (non-blocking) - select { - case <-d.closed: - logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)") - return 0, os.ErrClosed - default: - } - - // Now block waiting for data - select { - case res := <-d.readCh: - if res.err != nil { - logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err) - return 0, res.err + for { + if d.closed.Load() { + logger.Debug("MiddleDevice: Read returning io.EOF, device closed") + return 0, io.EOF } - // Copy packets from result to provided buffers - count := 0 - for i := 0; i < res.n && i < len(bufs); i++ { - // Handle offset mismatch if necessary - // We assume the 
pump used defaultOffset (16) - // If caller asks for different offset, we need to shift - src := res.bufs[i] - srcOffset := res.offset - srcSize := res.sizes[i] - - // Calculate where the packet data starts and ends in src - pktData := src[srcOffset : srcOffset+srcSize] - - // Ensure dest buffer is large enough - if len(bufs[i]) < offset+len(pktData) { - continue // Skip if buffer too small + // Wait for a device to be available + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF } - - copy(bufs[i][offset:], pktData) - sizes[i] = len(pktData) - count++ - } - n = count - - case pkt := <-d.injectCh: - if len(bufs) == 0 { - return 0, nil - } - if len(bufs[0]) < offset+len(pkt) { - return 0, nil // Buffer too small - } - copy(bufs[0][offset:], pkt) - sizes[0] = len(pkt) - n = 1 - - case <-d.closed: - logger.Debug("MiddleDevice: Read returning os.ErrClosed") - return 0, os.ErrClosed // Signal that device is closed - } - - d.mutex.RLock() - rules := d.rules - d.mutex.RUnlock() - - if len(rules) == 0 { - return n, nil - } - - // Process packets and filter out handled ones - writeIdx := 0 - for readIdx := 0; readIdx < n; readIdx++ { - packet := bufs[readIdx][offset : offset+sizes[readIdx]] - - destIP, ok := extractDestIP(packet) - if !ok { - // Can't parse, keep packet - if writeIdx != readIdx { - bufs[writeIdx] = bufs[readIdx] - sizes[writeIdx] = sizes[readIdx] - } - writeIdx++ continue } - // Check if packet matches any rule - handled := false - for _, rule := range rules { - if rule.DestIP == destIP { - if rule.Handler(packet) { - // Packet was handled and should be dropped - handled = true - break + // Now block waiting for data from readCh or injectCh + select { + case res := <-d.readCh: + if res.err != nil { + // Check if device was swapped + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err) + return 0, res.err + } + + // Copy packets from 
result to provided buffers + count := 0 + for i := 0; i < res.n && i < len(bufs); i++ { + src := res.bufs[i] + srcOffset := res.offset + srcSize := res.sizes[i] + + pktData := src[srcOffset : srcOffset+srcSize] + + if len(bufs[i]) < offset+len(pktData) { + continue + } + + copy(bufs[i][offset:], pktData) + sizes[i] = len(pktData) + count++ + } + n = count + + case pkt := <-d.injectCh: + if len(bufs) == 0 { + return 0, nil + } + if len(bufs[0]) < offset+len(pkt) { + return 0, nil + } + copy(bufs[0][offset:], pkt) + sizes[0] = len(pkt) + n = 1 + } + + // Apply filtering rules + d.rulesMutex.RLock() + rules := d.rules + d.rulesMutex.RUnlock() + + if len(rules) == 0 { + return n, nil + } + + // Process packets and filter out handled ones + writeIdx := 0 + for readIdx := 0; readIdx < n; readIdx++ { + packet := bufs[readIdx][offset : offset+sizes[readIdx]] + + destIP, ok := extractDestIP(packet) + if !ok { + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ + continue + } + + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + handled = true + break + } } } - } - if !handled { - // Keep packet - if writeIdx != readIdx { - bufs[writeIdx] = bufs[readIdx] - sizes[writeIdx] = sizes[readIdx] + if !handled { + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ } - writeIdx++ } - } - return writeIdx, err + return writeIdx, nil + } } // Write intercepts packets going DOWN to the TUN device (from WireGuard) func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) { - d.mutex.RLock() - rules := d.rules - d.mutex.RUnlock() + for { + if d.closed.Load() { + return 0, io.EOF + } - if len(rules) == 0 { - return d.Device.Write(bufs, offset) - } - - // Filter packets going down - filteredBufs := make([][]byte, 0, len(bufs)) - for _, buf := range bufs { - if len(buf) <= offset { + dev := d.peekLast() + if 
dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } continue } - packet := buf[offset:] - destIP, ok := extractDestIP(packet) - if !ok { - // Can't parse, keep packet - filteredBufs = append(filteredBufs, buf) - continue - } + d.rulesMutex.RLock() + rules := d.rules + d.rulesMutex.RUnlock() - // Check if packet matches any rule - handled := false - for _, rule := range rules { - if rule.DestIP == destIP { - if rule.Handler(packet) { - // Packet was handled and should be dropped - handled = true - break + var filteredBufs [][]byte + if len(rules) == 0 { + filteredBufs = bufs + } else { + filteredBufs = make([][]byte, 0, len(bufs)) + for _, buf := range bufs { + if len(buf) <= offset { + continue + } + + packet := buf[offset:] + destIP, ok := extractDestIP(packet) + if !ok { + filteredBufs = append(filteredBufs, buf) + continue + } + + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + handled = true + break + } + } + } + + if !handled { + filteredBufs = append(filteredBufs, buf) } } } - if !handled { - filteredBufs = append(filteredBufs, buf) + if len(filteredBufs) == 0 { + return len(bufs), nil } - } - if len(filteredBufs) == 0 { - return len(bufs), nil // All packets were handled - } + n, err := dev.Write(filteredBufs, offset) + if err == nil { + return n, nil + } - return d.Device.Write(filteredBufs, offset) + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } } + +func (d *MiddleDevice) waitForDevice() bool { + d.mu.Lock() + defer d.mu.Unlock() + + for len(d.devices) == 0 && !d.closed.Load() { + d.cond.Wait() + } + return !d.closed.Load() +} + +func (d *MiddleDevice) peekLast() *closeAwareDevice { + d.mu.Lock() + defer d.mu.Unlock() + + if len(d.devices) == 0 { + return nil + } + + return d.devices[len(d.devices)-1] +} + +// WriteToTun writes packets directly to the underlying TUN device, +// bypassing WireGuard. 
This is useful for sending packets that should +// appear to come from the TUN interface (e.g., DNS responses from a proxy). +// Unlike Write(), this does not go through packet filtering rules. +func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) { + for { + if d.closed.Load() { + return 0, io.EOF + } + + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } + continue + } + + n, err := dev.Write(bufs, offset) + if err == nil { + return n, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } +} \ No newline at end of file diff --git a/device/tun_unix.go b/device/tun_unix.go index c9bab60..22cec13 100644 --- a/device/tun_unix.go +++ b/device/tun_unix.go @@ -5,6 +5,7 @@ package device import ( "net" "os" + "runtime" "github.com/fosrl/newt/logger" "golang.org/x/sys/unix" @@ -13,6 +14,11 @@ import ( ) func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { + if runtime.GOOS == "android" { // otherwise we get a permission denied + theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd)) + return theTun, err + } + dupTunFd, err := unix.Dup(int(tunFd)) if err != nil { logger.Error("Unable to dup tun fd: %v", err) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 6d56379..748a5a9 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -12,7 +12,6 @@ import ( "github.com/fosrl/newt/util" "github.com/fosrl/olm/device" "github.com/miekg/dns" - "golang.zx2c4.com/wireguard/tun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -36,8 +35,7 @@ type DNSProxy struct { upstreamDNS []string tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally mtu int - tunDevice tun.Device // Direct reference to underlying TUN device for responses - middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering + middleDevice *device.MiddleDevice // Reference 
to MiddleDevice for packet filtering and TUN writes recordStore *DNSRecordStore // Local DNS records // Tunnel DNS fields - for sending queries over WireGuard @@ -53,7 +51,7 @@ type DNSProxy struct { } // NewDNSProxy creates a new DNS proxy -func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { +func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { proxyIP, err := PickIPFromSubnet(utilitySubnet) if err != nil { return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err) @@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in proxy := &DNSProxy{ proxyIP: proxyIP, mtu: mtu, - tunDevice: tunDevice, middleDevice: middleDevice, upstreamDNS: upstreamDns, tunnelDNS: tunnelDns, @@ -694,9 +691,9 @@ func (p *DNSProxy) runPacketSender() { pos += len(slice) } - // Write packet to TUN device + // Write packet to TUN device via MiddleDevice // offset=16 indicates packet data starts at position 16 in the buffer - _, err := p.tunDevice.Write([][]byte{buf}, offset) + _, err := p.middleDevice.WriteToTun([][]byte{buf}, offset) if err != nil { logger.Error("Failed to write DNS response to TUN: %v", err) } diff --git a/olm/olm.go b/olm/olm.go index 9cc1f51..a3bb694 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -35,6 +35,7 @@ var ( uapiListener net.Listener tdev tun.Device middleDev *olmDevice.MiddleDevice + interfaceName string dnsProxy *dns.DNSProxy apiServer *api.API olmClient *websocket.Client @@ -237,11 +238,11 @@ func StartTunnel(config TunnelConfig) { stopPing = make(chan struct{}) var ( - interfaceName = config.InterfaceName - id = config.ID - secret = config.Secret - userToken = config.UserToken + id = config.ID + secret = config.Secret + userToken = config.UserToken ) + interfaceName = 
config.InterfaceName apiServer.SetOrgID(config.OrgID) @@ -307,12 +308,7 @@ func StartTunnel(config TunnelConfig) { tdev, err = func() (tun.Device, error) { if config.FileDescriptorTun != 0 { - if runtime.GOOS == "android" { // otherwise we get a permission denied - theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(config.FileDescriptorTun)) - return theTun, err - } else { - return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) - } + return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) } var ifName = interfaceName if runtime.GOOS == "darwin" { // this is if we dont pass a fd @@ -329,11 +325,11 @@ func StartTunnel(config TunnelConfig) { return } - if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { - interfaceName = realInterfaceName - } + // if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? + interfaceName = realInterfaceName } + // } // Wrap TUN device with packet filter for DNS proxy middleDev = olmDevice.NewMiddleDevice(tdev) @@ -389,7 +385,7 @@ func StartTunnel(config TunnelConfig) { } // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) + dnsProxy, err = dns.NewDNSProxy(middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) if err != nil { logger.Error("Failed to create DNS proxy: %v", err) } @@ -956,6 +952,33 @@ func StartTunnel(config TunnelConfig) { logger.Info("Tunnel process context cancelled, cleaning up") } +func AddDevice(fd uint32) { + if middleDev == nil { + logger.Error("MiddleDevice is nil, cannot add device") + return + } + + if tunnelConfig.MTU == 0 { + logger.Error("No MTU configured, cannot create device") + return + } + + tdev, err := olmDevice.CreateTUNFromFD(fd, tunnelConfig.MTU) + + if err != nil { 
+ logger.Error("Failed to create TUN device: %v", err) + return + } + + // if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? + interfaceName = realInterfaceName + } + + // Here we replace the existing TUN device in the middle device with the new one + middleDev.AddDevice(tdev) +} + func Close() { // Restore original DNS configuration // we do this first to avoid any DNS issues if something else gets stuck From f08b17c7bd2043742f097742c5c801cd5e7a643c Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 31 Dec 2025 11:22:09 -0500 Subject: [PATCH 11/27] Middle device working but not closing Former-commit-id: c85fcc434ba4059a2952ecd0a3d54916f8bebc29 --- olm/olm.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index a3bb694..38d3324 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -952,22 +952,20 @@ func StartTunnel(config TunnelConfig) { logger.Info("Tunnel process context cancelled, cleaning up") } -func AddDevice(fd uint32) { +func AddDevice(fd uint32) error { if middleDev == nil { - logger.Error("MiddleDevice is nil, cannot add device") - return + return fmt.Errorf("middle device is not initialized") } if tunnelConfig.MTU == 0 { - logger.Error("No MTU configured, cannot create device") - return + // error + return fmt.Errorf("tunnel MTU is not set") } tdev, err := olmDevice.CreateTUNFromFD(fd, tunnelConfig.MTU) if err != nil { - logger.Error("Failed to create TUN device: %v", err) - return + return fmt.Errorf("failed to create TUN device from fd: %v", err) } // if config.FileDescriptorTun == 0 { @@ -977,6 +975,8 @@ func AddDevice(fd uint32) { // Here we replace the existing TUN device in the middle device with the new one middleDev.AddDevice(tdev) + + return nil } func Close() { From aeb908b68cb52c35a50925261ad6b8ad0836a093 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 31 Dec 2025 11:33:00 -0500 
Subject: [PATCH 12/27] Exiting the middle device works now? Former-commit-id: d76b3c366f4f97865d993d5c579dce6a79d6891a --- device/middle_device.go | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/device/middle_device.go b/device/middle_device.go index 2a5d9b9..7dfbec8 100644 --- a/device/middle_device.go +++ b/device/middle_device.go @@ -163,6 +163,13 @@ func (d *MiddleDevice) pump(dev *closeAwareDevice) { batchSize := dev.BatchSize() logger.Debug("MiddleDevice: pump started for device") + // Recover from panic if readCh is closed while we're trying to send + defer func() { + if r := recover(); r != nil { + logger.Debug("MiddleDevice: pump recovered from panic (channel closed)") + } + }() + for { // Check if this device is closed if dev.IsClosed() { @@ -197,7 +204,12 @@ func (d *MiddleDevice) pump(dev *closeAwareDevice) { return } - // Try to send the result + // Try to send the result - check closed state first to avoid sending on closed channel + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, device closed before send") + return + } + select { case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: default: @@ -225,6 +237,13 @@ func (d *MiddleDevice) InjectOutbound(packet []byte) { if d.closed.Load() { return } + // Use defer/recover to handle panic from sending on closed channel + // This can happen during shutdown race conditions + defer func() { + if r := recover(); r != nil { + logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)") + } + }() select { case d.injectCh <- packet: default: @@ -268,6 +287,8 @@ func (d *MiddleDevice) Close() error { d.cond.Broadcast() d.mu.Unlock() + // Close underlying devices first - this causes the pump goroutines to exit + // when their read operations return errors var lastErr error logger.Debug("MiddleDevice: Closing %d devices", len(devices)) for _, device := range devices { @@ -277,7 
+298,12 @@ func (d *MiddleDevice) Close() error { } } + // Now close channels to unblock any remaining readers + // The pump should have exited by now, but close channels to be safe + close(d.readCh) + close(d.injectCh) close(d.events) + return lastErr } @@ -416,7 +442,11 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err // Now block waiting for data from readCh or injectCh select { - case res := <-d.readCh: + case res, ok := <-d.readCh: + if !ok { + // Channel closed, device is shutting down + return 0, io.EOF + } if res.err != nil { // Check if device was swapped if dev.IsClosed() { @@ -446,7 +476,11 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err } n = count - case pkt := <-d.injectCh: + case pkt, ok := <-d.injectCh: + if !ok { + // Channel closed, device is shutting down + return 0, io.EOF + } if len(bufs) == 0 { return 0, nil } From 1b43f029a94fc3284880d28b8d28668fdca775a2 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 31 Dec 2025 15:42:51 -0500 Subject: [PATCH 13/27] Dont pass in dns proxy to override Former-commit-id: 51dd927f9b1a0d74bedaf1b0f34b046f4a937dbd --- dns/override/dns_override_android.go | 6 ++---- dns/override/dns_override_darwin.go | 9 ++------- dns/override/dns_override_ios.go | 6 ++---- dns/override/dns_override_unix.go | 19 +++++++------------ dns/override/dns_override_windows.go | 9 ++------- olm/olm.go | 6 ++++-- 6 files changed, 19 insertions(+), 36 deletions(-) diff --git a/dns/override/dns_override_android.go b/dns/override/dns_override_android.go index af1d946..d3fd78e 100644 --- a/dns/override/dns_override_android.go +++ b/dns/override/dns_override_android.go @@ -2,13 +2,11 @@ package olm -import ( - "github.com/fosrl/olm/dns" -) +import "net/netip" // SetupDNSOverride is a no-op on Android // Android handles DNS through the VpnService API at the Java/Kotlin layer -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { +func 
SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { return nil } diff --git a/dns/override/dns_override_darwin.go b/dns/override/dns_override_darwin.go index 6ccc3fb..c1c3789 100644 --- a/dns/override/dns_override_darwin.go +++ b/dns/override/dns_override_darwin.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on macOS // Uses scutil for DNS configuration -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error configurator, err = platform.NewDarwinDNSConfigurator() if err != nil { @@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/override/dns_override_ios.go b/dns/override/dns_override_ios.go index 109d471..6c95c71 100644 --- a/dns/override/dns_override_ios.go +++ b/dns/override/dns_override_ios.go @@ -2,12 +2,10 @@ package olm -import ( - "github.com/fosrl/olm/dns" -) +import "net/netip" // SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { return nil } diff --git a/dns/override/dns_override_unix.go b/dns/override/dns_override_unix.go index c3b31e8..12cb692 100644 --- a/dns/override/dns_override_unix.go +++ b/dns/override/dns_override_unix.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform 
"github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD // Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error // Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability @@ -32,7 +27,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName) if err == nil { logger.Info("Using systemd-resolved DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err) @@ -40,7 +35,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName) if err == nil { logger.Info("Using NetworkManager DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err) @@ -48,7 +43,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName) if err == nil { logger.Info("Using resolvconf DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create resolvconf configurator: %v, falling back", err) } @@ -60,11 +55,11 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { } logger.Info("Using file-based DNS configurator") - return setDNS(dnsProxy, configurator) + 
return setDNS(proxyIp, configurator) } // setDNS is a helper function to set DNS and log the results -func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { +func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error { // Get current DNS servers before changing currentDNS, err := conf.GetCurrentDNS() if err != nil { @@ -75,7 +70,7 @@ func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/override/dns_override_windows.go b/dns/override/dns_override_windows.go index a564079..16bbca1 100644 --- a/dns/override/dns_override_windows.go +++ b/dns/override/dns_override_windows.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on Windows // Uses registry-based configuration (automatically extracts interface GUID) -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error configurator, err = platform.NewWindowsDNSConfigurator(interfaceName) if err != nil { @@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/olm/olm.go b/olm/olm.go index 38d3324..4d12952 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -439,10 +439,12 @@ func StartTunnel(config TunnelConfig) { if config.OverrideDNS { // Set up DNS override to use our DNS proxy - if err := 
dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil { + if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy.GetProxyIP()); err != nil { logger.Error("Failed to setup DNS override: %v", err) return } + + network.SetDNSServers([]string{dnsProxy.GetProxyIP().String()}) } apiServer.SetRegistered(true) @@ -975,7 +977,7 @@ func AddDevice(fd uint32) error { // Here we replace the existing TUN device in the middle device with the new one middleDev.AddDevice(tdev) - + return nil } From 83edde34494264e0917f134086cb49ae05d72d82 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 31 Dec 2025 18:01:25 -0500 Subject: [PATCH 14/27] Fix build on darwin Former-commit-id: fbeb5be88d9a2c126c232e0df78efef4fa51ead8 --- device/tun_darwin.go | 44 ++++++++++++++++++++++++++++ device/{tun_unix.go => tun_linux.go} | 2 +- 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 device/tun_darwin.go rename device/{tun_unix.go => tun_linux.go} (98%) diff --git a/device/tun_darwin.go b/device/tun_darwin.go new file mode 100644 index 0000000..f763f74 --- /dev/null +++ b/device/tun_darwin.go @@ -0,0 +1,44 @@ +//go:build darwin + +package device + +import ( + "net" + "os" + + "github.com/fosrl/newt/logger" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" +) + +func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { + dupTunFd, err := unix.Dup(int(tunFd)) + if err != nil { + logger.Error("Unable to dup tun fd: %v", err) + return nil, err + } + + err = unix.SetNonblock(dupTunFd, true) + if err != nil { + unix.Close(dupTunFd) + return nil, err + } + + file := os.NewFile(uintptr(dupTunFd), "/dev/tun") + device, err := tun.CreateTUNFromFile(file, mtuInt) + if err != nil { + file.Close() + return nil, err + } + + return device, nil +} + +func UapiOpen(interfaceName string) (*os.File, error) { + return ipc.UAPIOpen(interfaceName) +} + +func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { 
+ return ipc.UAPIListen(interfaceName, fileUAPI) +} diff --git a/device/tun_unix.go b/device/tun_linux.go similarity index 98% rename from device/tun_unix.go rename to device/tun_linux.go index 22cec13..902f269 100644 --- a/device/tun_unix.go +++ b/device/tun_linux.go @@ -1,4 +1,4 @@ -//go:build !windows +//go:build linux package device From 1ed27fec1a09beaf1d4190d129878903e7547f0b Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Thu, 1 Jan 2026 17:38:01 -0500 Subject: [PATCH 15/27] set mtu to 0 on darwin Former-commit-id: fbe686961ed233f1e50cc4ccb46336e7e5938c8d --- device/tun_darwin.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/device/tun_darwin.go b/device/tun_darwin.go index f763f74..df87d53 100644 --- a/device/tun_darwin.go +++ b/device/tun_darwin.go @@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { } file := os.NewFile(uintptr(dupTunFd), "/dev/tun") - device, err := tun.CreateTUNFromFile(file, mtuInt) + device, err := tun.CreateTUNFromFile(file, 0) if err != nil { file.Close() return nil, err From 7b7eae617a2ac1b297e04a8d6c97597c48571e65 Mon Sep 17 00:00:00 2001 From: Varun Narravula Date: Wed, 31 Dec 2025 15:07:06 -0800 Subject: [PATCH 16/27] chore: format files using gofmt Former-commit-id: 5cfa0dfb9781d57bd369fa3f4631f9721f64a5a9 --- dns/dns_proxy.go | 10 +++--- dns/dns_records.go | 2 +- dns/dns_records_test.go | 68 ++++++++++++++++++++--------------------- dns/platform/darwin.go | 2 +- olm/olm.go | 2 +- 5 files changed, 42 insertions(+), 42 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 6d56379..6c9891a 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -34,18 +34,18 @@ type DNSProxy struct { ep *channel.Endpoint proxyIP netip.Addr upstreamDNS []string - tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally + tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally mtu int tunDevice tun.Device // 
Direct reference to underlying TUN device for responses middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering recordStore *DNSRecordStore // Local DNS records // Tunnel DNS fields - for sending queries over WireGuard - tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries) - tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries - tunnelEp *channel.Endpoint + tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries) + tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries + tunnelEp *channel.Endpoint tunnelActivePorts map[uint16]bool - tunnelPortsLock sync.Mutex + tunnelPortsLock sync.Mutex ctx context.Context cancel context.CancelFunc diff --git a/dns/dns_records.go b/dns/dns_records.go index ed57b77..5308b0e 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -322,4 +322,4 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool { } return matchWildcardInternal(pattern, domain, pi+1, di+1) -} \ No newline at end of file +} diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index 0bb18a1..f922afb 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -37,7 +37,7 @@ func TestWildcardMatching(t *testing.T) { domain: "autoco.internal.", expected: false, }, - + // Question mark wildcard tests { name: "host-0?.autoco.internal matches host-01.autoco.internal", @@ -63,7 +63,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-012.autoco.internal.", expected: false, }, - + // Combined wildcard tests { name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal", @@ -83,7 +83,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-01.autoco.internal.", expected: false, }, - + // Multiple asterisks { name: "*.*. 
autoco.internal matches any.thing.autoco.internal", @@ -97,7 +97,7 @@ func TestWildcardMatching(t *testing.T) { domain: "single.autoco.internal.", expected: false, }, - + // Asterisk in middle { name: "host-*.autoco.internal matches host-anything.autoco.internal", @@ -111,7 +111,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-.autoco.internal.", expected: true, }, - + // Multiple question marks { name: "host-??.autoco.internal matches host-01.autoco.internal", @@ -125,7 +125,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-1.autoco.internal.", expected: false, }, - + // Exact match (no wildcards) { name: "exact.autoco.internal matches exact.autoco.internal", @@ -139,7 +139,7 @@ func TestWildcardMatching(t *testing.T) { domain: "other.autoco.internal.", expected: false, }, - + // Edge cases { name: "* matches anything", @@ -154,7 +154,7 @@ func TestWildcardMatching(t *testing.T) { expected: true, }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := matchWildcard(tt.pattern, tt.domain) @@ -167,21 +167,21 @@ func TestWildcardMatching(t *testing.T) { func TestDNSRecordStoreWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard records wildcardIP := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", wildcardIP) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Add exact record exactIP := net.ParseIP("10.0.0.2") err = store.AddRecord("exact.autoco.internal", exactIP) if err != nil { t.Fatalf("Failed to add exact record: %v", err) } - + // Test exact match takes precedence ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -190,7 +190,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) { if !ips[0].Equal(exactIP) { t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) } - + // Test wildcard match ips = store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -199,7 +199,7 @@ func 
TestDNSRecordStoreWildcard(t *testing.T) { if !ips[0].Equal(wildcardIP) { t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) } - + // Test non-match (base domain) ips = store.GetRecords("autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -209,14 +209,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) { func TestDNSRecordStoreComplexWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add complex wildcard pattern ip1 := net.ParseIP("10.0.0.1") err := store.AddRecord("*.host-0?.autoco.internal", ip1) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Test matching domain ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -225,13 +225,13 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { if len(ips) > 0 && !ips[0].Equal(ip1) { t.Errorf("Expected IP %v, got %v", ip1, ips[0]) } - + // Test non-matching domain (missing prefix) ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) } - + // Test non-matching domain (wrong ? 
position) ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -241,23 +241,23 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { func TestDNSRecordStoreRemoveWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard record ip := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Verify it exists ips := store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP before removal, got %d", len(ips)) } - + // Remove wildcard record store.RemoveRecord("*.autoco.internal", nil) - + // Verify it's gone ips = store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -267,40 +267,40 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { func TestDNSRecordStoreMultipleWildcards(t *testing.T) { store := NewDNSRecordStore() - + // Add multiple wildcard patterns that don't overlap ip1 := net.ParseIP("10.0.0.1") ip2 := net.ParseIP("10.0.0.2") ip3 := net.ParseIP("10.0.0.3") - + err := store.AddRecord("*.prod.autoco.internal", ip1) if err != nil { t.Fatalf("Failed to add first wildcard: %v", err) } - + err = store.AddRecord("*.dev.autoco.internal", ip2) if err != nil { t.Fatalf("Failed to add second wildcard: %v", err) } - + // Add a broader wildcard that matches both err = store.AddRecord("*.autoco.internal", ip3) if err != nil { t.Fatalf("Failed to add third wildcard: %v", err) } - + // Test domain matching only the prod pattern and the broad pattern ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) } - + // Test domain matching only the dev pattern and the broad pattern ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) } - + // Test domain matching only the broad pattern ips = 
store.GetRecords("host.test.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -310,14 +310,14 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add IPv6 wildcard record ip := net.ParseIP("2001:db8::1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add IPv6 wildcard record: %v", err) } - + // Test wildcard match for IPv6 ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) if len(ips) != 1 { @@ -330,21 +330,21 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { func TestHasRecordWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard record ip := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Test HasRecord with wildcard match if !store.HasRecord("host.autoco.internal.", RecordTypeA) { t.Error("Expected HasRecord to return true for wildcard match") } - + // Test HasRecord with non-match if store.HasRecord("autoco.internal.", RecordTypeA) { t.Error("Expected HasRecord to return false for base domain") } -} \ No newline at end of file +} diff --git a/dns/platform/darwin.go b/dns/platform/darwin.go index 61cc81b..8054c57 100644 --- a/dns/platform/darwin.go +++ b/dns/platform/darwin.go @@ -416,4 +416,4 @@ func (d *DarwinDNSConfigurator) clearState() error { logger.Debug("Cleared DNS state file") return nil -} \ No newline at end of file +} diff --git a/olm/olm.go b/olm/olm.go index f84ee4f..02257b8 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -811,7 +811,7 @@ func StartTunnel(config TunnelConfig) { Endpoint: handshakeData.ExitNode.Endpoint, RelayPort: relayPort, PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, + SiteIds: []int{siteId}, } added := holePunchManager.AddExitNode(exitNode) From c565a46a6f7fea8715ce4d3f050a12fcafc596ec Mon Sep 17 00:00:00 2001 From: Varun Narravula 
Date: Wed, 31 Dec 2025 15:44:19 -0800 Subject: [PATCH 17/27] feat(logger): configure log file path through global options Former-commit-id: 577d89f4fb84d67a170b678bbe0fd844505686d9 --- olm/olm.go | 15 ++++++++++++--- olm/types.go | 3 ++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index 02257b8..98ae6fb 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -100,6 +100,17 @@ func Init(ctx context.Context, config GlobalConfig) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) + if config.LogFilePath != "" { + logFile, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + logger.Fatal("Failed to open log file: %v", err) + } + + // TODO: figure out how to close file, if set + logger.SetOutput(logFile) + return + } + logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() if err != nil { @@ -306,7 +317,7 @@ func StartTunnel(config TunnelConfig) { if config.FileDescriptorTun != 0 { return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) } - var ifName = interfaceName + ifName := interfaceName if runtime.GOOS == "darwin" { // this is if we dont pass a fd ifName, err = network.FindUnusedUTUN() if err != nil { @@ -315,7 +326,6 @@ func StartTunnel(config TunnelConfig) { } return tun.CreateTUN(ifName, config.MTU) }() - if err != nil { logger.Error("Failed to create TUN device: %v", err) return @@ -361,7 +371,6 @@ func StartTunnel(config TunnelConfig) { for { conn, err := uapiListener.Accept() if err != nil { - return } go dev.IpcHandle(conn) diff --git a/olm/types.go b/olm/types.go index b7153af..14cc044 100644 --- a/olm/types.go +++ b/olm/types.go @@ -14,7 +14,8 @@ type WgData struct { type GlobalConfig struct { // Logging - LogLevel string + LogLevel string + LogFilePath string // HTTP server EnableAPI bool From 5b637bb4caac627959868705b80907521295f24b Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 12 Jan
2026 12:20:59 -0800 Subject: [PATCH 18/27] Add expo backoff Former-commit-id: faae551aca7447b717e50137a237c78542cb14cf --- main.go | 1 + olm/olm.go | 12 +++++++ olm/types.go | 3 ++ peers/monitor/monitor.go | 61 +++++++++++++++++++++++++++------ peers/monitor/wgtester.go | 71 +++++++++++++++++++++++++++++++-------- 5 files changed, 123 insertions(+), 25 deletions(-) diff --git a/main.go b/main.go index f6c6973..5b6c15e 100644 --- a/main.go +++ b/main.go @@ -219,6 +219,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt Agent: "Olm CLI", OnExit: cancel, // Pass cancel function directly to trigger shutdown OnTerminated: cancel, + PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE } olm.Init(ctx, olmConfig) diff --git a/olm/olm.go b/olm/olm.go index 4d12952..03cf02b 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" "net" + "net/http" + _ "net/http/pprof" "os" "runtime" "strconv" @@ -101,6 +103,16 @@ func Init(ctx context.Context, config GlobalConfig) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) + // Start pprof server if enabled + if config.PprofAddr != "" { + go func() { + logger.Info("Starting pprof server on %s", config.PprofAddr) + if err := http.ListenAndServe(config.PprofAddr, nil); err != nil { + logger.Error("Failed to start pprof server: %v", err) + } + }() + } + logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() if err != nil { diff --git a/olm/types.go b/olm/types.go index b7153af..a43121f 100644 --- a/olm/types.go +++ b/olm/types.go @@ -23,6 +23,9 @@ type GlobalConfig struct { Version string Agent string + // Debugging + PprofAddr string // Address to serve pprof on (e.g., "localhost:6060") + // Callbacks OnRegistered func() OnConnected func() diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 27bc408..1ec267e 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ 
-61,6 +61,13 @@ type PeerMonitor struct { holepunchMaxAttempts int // max consecutive failures before triggering relay holepunchFailures map[int]int // siteID -> consecutive failure count + // Exponential backoff fields for holepunch monitor + holepunchMinInterval time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied + // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts rapidTestTimeout time.Duration // timeout for each rapid test attempt @@ -101,6 +108,12 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total apiServer: apiServer, wgConnectionStatus: make(map[int]bool), + // Exponential backoff settings for holepunch monitor + holepunchMinInterval: 2 * time.Second, + holepunchMaxInterval: 30 * time.Second, + holepunchBackoffMultiplier: 1.5, + holepunchStableCount: make(map[int]int), + holepunchCurrentInterval: 2 * time.Second, } if err := pm.initNetstack(); err != nil { @@ -172,6 +185,7 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st client.SetPacketInterval(pm.interval) client.SetTimeout(pm.timeout) client.SetMaxAttempts(pm.maxAttempts) + client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable pm.monitors[siteID] = client @@ -470,31 +484,50 @@ func (pm *PeerMonitor) stopHolepunchMonitor() { logger.Info("Stopped holepunch connection monitor") } -// runHolepunchMonitor runs the holepunch monitoring loop +// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff func (pm *PeerMonitor) runHolepunchMonitor() { - ticker := 
time.NewTicker(pm.holepunchInterval) - defer ticker.Stop() + pm.mutex.Lock() + pm.holepunchCurrentInterval = pm.holepunchMinInterval + pm.mutex.Unlock() - // Do initial check immediately - pm.checkHolepunchEndpoints() + timer := time.NewTimer(0) // Fire immediately for initial check + defer timer.Stop() for { select { case <-pm.holepunchStopChan: return - case <-ticker.C: - pm.checkHolepunchEndpoints() + case <-timer.C: + anyStatusChanged := pm.checkHolepunchEndpoints() + + pm.mutex.Lock() + if anyStatusChanged { + // Reset to minimum interval on any status change + pm.holepunchCurrentInterval = pm.holepunchMinInterval + } else { + // Apply exponential backoff when stable + newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier) + if newInterval > pm.holepunchMaxInterval { + newInterval = pm.holepunchMaxInterval + } + pm.holepunchCurrentInterval = newInterval + } + currentInterval := pm.holepunchCurrentInterval + pm.mutex.Unlock() + + timer.Reset(currentInterval) } } } // checkHolepunchEndpoints tests all holepunch endpoints -func (pm *PeerMonitor) checkHolepunchEndpoints() { +// Returns true if any endpoint's status changed +func (pm *PeerMonitor) checkHolepunchEndpoints() bool { pm.mutex.Lock() // Check if we're still running before doing any work if !pm.running { pm.mutex.Unlock() - return + return false } endpoints := make(map[int]string, len(pm.holepunchEndpoints)) for siteID, endpoint := range pm.holepunchEndpoints { @@ -504,6 +537,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { maxAttempts := pm.holepunchMaxAttempts pm.mutex.Unlock() + anyStatusChanged := false + for siteID, endpoint := range endpoints { // logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) result := pm.holepunchTester.TestEndpoint(endpoint, timeout) @@ -529,7 +564,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { pm.mutex.Unlock() // Log status changes - if !exists || previousStatus != result.Success { + 
statusChanged := !exists || previousStatus != result.Success + if statusChanged { + anyStatusChanged = true if result.Success { logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT) } else { @@ -562,7 +599,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { pm.mutex.Unlock() if !stillRunning { - return // Stop processing if shutdown is in progress + return anyStatusChanged // Stop processing if shutdown is in progress } if !result.Success && !isRelayed && failureCount >= maxAttempts { @@ -579,6 +616,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { } } } + + return anyStatusChanged } // GetHolepunchStatus returns the current holepunch status for all endpoints diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go index dac2008..21f788a 100644 --- a/peers/monitor/wgtester.go +++ b/peers/monitor/wgtester.go @@ -36,6 +36,12 @@ type Client struct { timeout time.Duration maxAttempts int dialer Dialer + + // Exponential backoff fields + minInterval time.Duration // Minimum interval (initial) + maxInterval time.Duration // Maximum interval (cap for backoff) + backoffMultiplier float64 // Multiplier for each stable check + stableCountToBackoff int // Number of stable checks before backing off } // Dialer is a function that creates a connection @@ -50,18 +56,23 @@ type ConnectionStatus struct { // NewClient creates a new connection test client func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ - serverAddr: serverAddr, - shutdownCh: make(chan struct{}), - packetInterval: 2 * time.Second, - timeout: 500 * time.Millisecond, // Timeout for individual packets - maxAttempts: 3, // Default max attempts - dialer: dialer, + serverAddr: serverAddr, + shutdownCh: make(chan struct{}), + packetInterval: 2 * time.Second, + minInterval: 2 * time.Second, + maxInterval: 30 * time.Second, + backoffMultiplier: 1.5, + stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing 
off + timeout: 500 * time.Millisecond, // Timeout for individual packets + maxAttempts: 3, // Default max attempts + dialer: dialer, }, nil } // SetPacketInterval changes how frequently packets are sent in monitor mode func (c *Client) SetPacketInterval(interval time.Duration) { c.packetInterval = interval + c.minInterval = interval } // SetTimeout changes the timeout for waiting for responses @@ -74,6 +85,16 @@ func (c *Client) SetMaxAttempts(attempts int) { c.maxAttempts = attempts } +// SetMaxInterval sets the maximum backoff interval +func (c *Client) SetMaxInterval(interval time.Duration) { + c.maxInterval = interval +} + +// SetBackoffMultiplier sets the multiplier for exponential backoff +func (c *Client) SetBackoffMultiplier(multiplier float64) { + c.backoffMultiplier = multiplier +} + // UpdateServerAddr updates the server address and resets the connection func (c *Client) UpdateServerAddr(serverAddr string) { c.connLock.Lock() @@ -138,6 +159,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { binary.BigEndian.PutUint32(packet[0:4], magicHeader) packet[4] = packetTypeRequest + // Reusable response buffer + responseBuffer := make([]byte, packetSize) + // Send multiple attempts as specified for attempt := 0; attempt < c.maxAttempts; attempt++ { select { @@ -157,20 +181,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { return false, 0 } - // logger.Debug("Attempting to send monitor packet to %s", c.serverAddr) _, err := c.conn.Write(packet) if err != nil { c.connLock.Unlock() logger.Info("Error sending packet: %v", err) continue } - // logger.Debug("Successfully sent monitor packet") // Set read deadline c.conn.SetReadDeadline(time.Now().Add(c.timeout)) // Wait for response - responseBuffer := make([]byte, packetSize) n, err := c.conn.Read(responseBuffer) c.connLock.Unlock() @@ -238,28 +259,50 @@ func (c *Client) StartMonitor(callback MonitorCallback) error { go func() { var lastConnected bool 
firstRun := true + stableCount := 0 + currentInterval := c.minInterval - ticker := time.NewTicker(c.packetInterval) - defer ticker.Stop() + timer := time.NewTimer(currentInterval) + defer timer.Stop() for { select { case <-c.shutdownCh: return - case <-ticker.C: + case <-timer.C: ctx, cancel := context.WithTimeout(context.Background(), c.timeout) connected, rtt := c.TestConnection(ctx) cancel() + statusChanged := connected != lastConnected + // Callback if status changed or it's the first check - if connected != lastConnected || firstRun { + if statusChanged || firstRun { callback(ConnectionStatus{ Connected: connected, RTT: rtt, }) lastConnected = connected firstRun = false + // Reset backoff on status change + stableCount = 0 + currentInterval = c.minInterval + } else { + // Status is stable, increment counter + stableCount++ + + // Apply exponential backoff after stable threshold + if stableCount >= c.stableCountToBackoff { + newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier) + if newInterval > c.maxInterval { + newInterval = c.maxInterval + } + currentInterval = newInterval + } } + + // Reset timer with current interval + timer.Reset(currentInterval) } } }() @@ -278,4 +321,4 @@ func (c *Client) StopMonitor() { close(c.shutdownCh) c.monitorRunning = false -} +} \ No newline at end of file From 20e0c18845e8062053ecb503de22dc13f5556f99 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 12 Jan 2026 12:29:42 -0800 Subject: [PATCH 19/27] Try to reduce cpu when idle Former-commit-id: ba91478b89832990cb366e2362a21c93e5ce698e --- dns/dns_proxy.go | 76 ++++++++++++++++------------------------ peers/monitor/monitor.go | 67 +++++++++++++++-------------------- 2 files changed, 60 insertions(+), 83 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 748a5a9..d010bc6 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -599,12 +599,12 @@ func (p *DNSProxy) runTunnelPacketSender() { defer p.wg.Done() logger.Debug("DNS tunnel packet 
sender goroutine started") - ticker := time.NewTicker(1 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-p.ctx.Done(): + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := p.tunnelEp.ReadContext(p.ctx) + if pkt == nil { + // Context was cancelled or endpoint closed logger.Debug("DNS tunnel packet sender exiting") // Drain any remaining packets for { @@ -615,36 +615,28 @@ func (p *DNSProxy) runTunnelPacketSender() { pkt.DecRef() } return - case <-ticker.C: - // Try to read packets - for i := 0; i < 10; i++ { - pkt := p.tunnelEp.Read() - if pkt == nil { - break - } - - // Extract packet data - slices := pkt.AsSlices() - if len(slices) > 0 { - var totalSize int - for _, slice := range slices { - totalSize += len(slice) - } - - buf := make([]byte, totalSize) - pos := 0 - for _, slice := range slices { - copy(buf[pos:], slice) - pos += len(slice) - } - - // Inject into MiddleDevice (outbound to WG) - p.middleDevice.InjectOutbound(buf) - } - - pkt.DecRef() - } } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + p.middleDevice.InjectOutbound(buf) + } + + pkt.DecRef() } } @@ -657,18 +649,12 @@ func (p *DNSProxy) runPacketSender() { const offset = 16 for { - select { - case <-p.ctx.Done(): - return - default: - } - - // Read packets from netstack endpoint - pkt := p.ep.Read() + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := p.ep.ReadContext(p.ctx) if pkt == nil { - // No packet available, small sleep to avoid busy loop - time.Sleep(1 * time.Millisecond) - 
continue + // Context was cancelled or endpoint closed + return } // Extract packet data as slices diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 1ec267e..45dd090 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -42,7 +42,7 @@ type PeerMonitor struct { stack *stack.Stack ep *channel.Endpoint activePorts map[uint16]bool - portsLock sync.Mutex + portsLock sync.RWMutex nsCtx context.Context nsCancel context.CancelFunc nsWg sync.WaitGroup @@ -809,9 +809,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool { } // Check if we are listening on this port - pm.portsLock.Lock() + pm.portsLock.RLock() active := pm.activePorts[uint16(port)] - pm.portsLock.Unlock() + pm.portsLock.RUnlock() if !active { return false @@ -842,13 +842,12 @@ func (pm *PeerMonitor) runPacketSender() { defer pm.nsWg.Done() logger.Debug("PeerMonitor: Packet sender goroutine started") - // Use a ticker to periodically check for packets without blocking indefinitely - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-pm.nsCtx.Done(): + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := pm.ep.ReadContext(pm.nsCtx) + if pkt == nil { + // Context was cancelled or endpoint closed logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets") // Drain any remaining packets before exiting for { @@ -860,36 +859,28 @@ func (pm *PeerMonitor) runPacketSender() { } logger.Debug("PeerMonitor: Packet sender goroutine exiting") return - case <-ticker.C: - // Try to read packets in batches - for i := 0; i < 10; i++ { - pkt := pm.ep.Read() - if pkt == nil { - break - } - - // Extract packet data - slices := pkt.AsSlices() - if len(slices) > 0 { - var totalSize int - for _, slice := range slices { - totalSize += len(slice) - } - - buf := make([]byte, totalSize) - pos := 0 - for _, slice := range slices 
{ - copy(buf[pos:], slice) - pos += len(slice) - } - - // Inject into MiddleDevice (outbound to WG) - pm.middleDev.InjectOutbound(buf) - } - - pkt.DecRef() - } } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + pm.middleDev.InjectOutbound(buf) + } + + pkt.DecRef() } } From dada0cc1242a4d0921987b079576d14ac8e21366 Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Tue, 13 Jan 2026 14:30:02 -0800 Subject: [PATCH 20/27] add low power state for testing Former-commit-id: 996fe59999c64e63ec74a33c6b2f792d7d9130d4 --- olm/olm.go | 188 ++++++++++++++++++++++++++++++++++++++- peers/manager.go | 7 ++ peers/monitor/monitor.go | 29 ++++-- 3 files changed, 218 insertions(+), 6 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index 774a3cb..21295be 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -53,6 +53,11 @@ var ( updateRegister func(newData interface{}) stopPing chan struct{} peerManager *peers.PeerManager + // Power mode management + currentPowerMode string + originalPeerInterval time.Duration + originalHolepunchMinInterval time.Duration + originalHolepunchMaxInterval time.Duration ) // initTunnelInfo creates the shared UDP socket and holepunch manager. 
@@ -112,7 +117,7 @@ func Init(ctx context.Context, config GlobalConfig) { } }() } - + if config.LogFilePath != "" { logFile, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) if err != nil { @@ -432,6 +437,18 @@ func StartTunnel(config TunnelConfig) { APIServer: apiServer, }) + // Capture original intervals for power mode management + if peerManager != nil { + peerMonitor := peerManager.GetPeerMonitor() + if peerMonitor != nil { + originalPeerInterval = 2 * time.Second // Default peer interval + originalHolepunchMinInterval, originalHolepunchMaxInterval = peerMonitor.GetHolepunchIntervals() + } + } + + // Initialize power mode to normal + currentPowerMode = "normal" + for i := range wgData.Sites { site := wgData.Sites[i] var siteEndpoint string @@ -1156,3 +1173,172 @@ func SwitchOrg(orgID string) error { return nil } + +// SetPowerMode switches between normal and low power modes +// In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes +// In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored +func SetPowerMode(mode string) error { + // Validate mode + if mode != "normal" && mode != "low" { + return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode) + } + + // If already in the requested mode, return early + if currentPowerMode == mode { + logger.Debug("Already in %s power mode", mode) + return nil + } + + logger.Info("Switching to %s power mode", mode) + + if mode == "low" { + // Low Power Mode: Close websocket and reduce monitoring frequency + + // Close websocket connection - this stops: + // - WebSocket ping monitor (via pingMonitor() goroutine) + // - Application ping messages (via keepSendingPing() goroutine) + if olmClient != nil { + logger.Info("Closing websocket connection for low power mode") + if err := olmClient.Close(); err != nil { + logger.Error("Error closing websocket: %v", err) + } + } + + // Stop 
application ping goroutine + if stopPing != nil { + select { + case <-stopPing: + // Channel already closed + default: + close(stopPing) + } + } + + // Stop peer monitoring + if peerManager != nil { + peerManager.Stop() + } + + // Store original intervals if not already stored + if originalPeerInterval == 0 && peerManager != nil { + peerMonitor := peerManager.GetPeerMonitor() + if peerMonitor != nil { + originalPeerInterval = 2 * time.Second // Default peer interval + originalHolepunchMinInterval, originalHolepunchMaxInterval = peerMonitor.GetHolepunchIntervals() + } + } + + // Set monitoring intervals to 10 minutes + if peerManager != nil { + peerMonitor := peerManager.GetPeerMonitor() + if peerMonitor != nil { + lowPowerInterval := 10 * time.Minute + peerMonitor.SetInterval(lowPowerInterval) + peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval) + logger.Info("Set monitoring intervals to 10 minutes for low power mode") + } + } + + // Restart peer monitoring with new intervals (but websocket remains closed) + if peerManager != nil { + peerManager.Start() + } + + currentPowerMode = "low" + logger.Info("Switched to low power mode") + + } else { + // Normal Power Mode: Restore intervals and reconnect websocket + + // Restore monitoring intervals to original values + if peerManager != nil { + peerMonitor := peerManager.GetPeerMonitor() + if peerMonitor != nil { + // Restore peer interval + if originalPeerInterval == 0 { + originalPeerInterval = 2 * time.Second // Default if not captured + } + peerMonitor.SetInterval(originalPeerInterval) + + // Restore holepunch intervals + if originalHolepunchMinInterval == 0 { + originalHolepunchMinInterval = 2 * time.Second // Default if not captured + } + if originalHolepunchMaxInterval == 0 { + originalHolepunchMaxInterval = 30 * time.Second // Default if not captured + } + peerMonitor.SetHolepunchInterval(originalHolepunchMinInterval, originalHolepunchMaxInterval) + logger.Info("Restored monitoring intervals to 
normal (peer: %v, holepunch: %v-%v)", + originalPeerInterval, originalHolepunchMinInterval, originalHolepunchMaxInterval) + } + } + + // Restart peer monitoring with restored intervals + if peerManager != nil { + peerManager.Start() + } + + // Reconnect websocket - this restarts: + // - WebSocket ping monitor + // - Application ping messages (via OnConnect callback) + // Note: Since websocket client's Close() permanently closes the done channel, + // we need to create a new client instance and re-register handlers + if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" { + logger.Info("Reconnecting websocket for normal power mode") + + // Close old client if it exists + if olmClient != nil { + olmClient.Close() + } + + // Recreate stopPing channel for application pings + stopPing = make(chan struct{}) + + // Create a new websocket client + var ( + id = tunnelConfig.ID + secret = tunnelConfig.Secret + userToken = tunnelConfig.UserToken + ) + + olm, err := websocket.NewClient( + id, + secret, + userToken, + tunnelConfig.OrgID, + tunnelConfig.Endpoint, + tunnelConfig.PingIntervalDuration, + tunnelConfig.PingTimeoutDuration, + ) + if err != nil { + logger.Error("Failed to create new websocket client: %v", err) + return fmt.Errorf("failed to create new websocket client: %w", err) + } + + // Store the new client + olmClient = olm + + // Re-register essential handlers (simplified - only the most critical ones) + // The full handler registration happens in StartTunnel, so this is just for reconnection + olm.OnConnect(func() error { + logger.Info("Websocket Reconnected") + apiServer.SetConnectionStatus(true) + go keepSendingPing(olm) + return nil + }) + + // Connect to the WebSocket server + if err := olm.Connect(); err != nil { + logger.Error("Failed to reconnect websocket: %v", err) + return fmt.Errorf("failed to reconnect websocket: %w", err) + } + } else { + logger.Warn("Cannot reconnect websocket: tunnel config not available") + } + + 
currentPowerMode = "normal" + logger.Info("Switched to normal power mode") + } + + return nil +} diff --git a/peers/manager.go b/peers/manager.go index af781e5..56f3707 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -84,6 +84,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) { return peer, ok } +// GetPeerMonitor returns the internal peer monitor instance +func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor { + pm.mu.RLock() + defer pm.mu.RUnlock() + return pm.peerMonitor +} + func (pm *PeerManager) GetAllPeers() []SiteConfig { pm.mu.RLock() defer pm.mu.RUnlock() diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 45dd090..2bb0c80 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -62,11 +62,11 @@ type PeerMonitor struct { holepunchFailures map[int]int // siteID -> consecutive failure count // Exponential backoff fields for holepunch monitor - holepunchMinInterval time.Duration // Minimum interval (initial) - holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) - holepunchBackoffMultiplier float64 // Multiplier for each stable check - holepunchStableCount map[int]int // siteID -> consecutive stable status count - holepunchCurrentInterval time.Duration // Current interval with backoff applied + holepunchMinInterval time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts @@ -167,6 +167,25 @@ func (pm *PeerMonitor) SetMaxAttempts(attempts int) { } } +// SetHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) 
SetHolepunchInterval(minInterval, maxInterval time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.holepunchMinInterval = minInterval + pm.holepunchMaxInterval = maxInterval + // Reset current interval to the new minimum + pm.holepunchCurrentInterval = minInterval +} + +// GetHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + return pm.holepunchMinInterval, pm.holepunchMaxInterval +} + // AddPeer adds a new peer to monitor func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error { pm.mutex.Lock() From 15e96a779cc3ff109e4c4cf6a46ae0cdbd359ec9 Mon Sep 17 00:00:00 2001 From: Varun Narravula Date: Mon, 5 Jan 2026 01:41:54 -0800 Subject: [PATCH 21/27] refactor(olm): convert global state into an olm instance Former-commit-id: b755f77d95ecc9d645806fa33a11d91261cfd059 --- api/api.go | 29 +- main.go | 12 +- olm/connect.go | 223 +++++++++ olm/data.go | 197 ++++++++ olm/olm.go | 976 ++++++++------------------------------- olm/peer.go | 195 ++++++++ olm/{util.go => ping.go} | 6 +- olm/types.go | 2 +- 8 files changed, 841 insertions(+), 799 deletions(-) create mode 100644 olm/connect.go create mode 100644 olm/data.go create mode 100644 olm/peer.go rename olm/{util.go => ping.go} (89%) diff --git a/api/api.go b/api/api.go index 91d9f37..a6ac9cd 100644 --- a/api/api.go +++ b/api/api.go @@ -63,23 +63,26 @@ type StatusResponse struct { // API represents the HTTP server and its state type API struct { - addr string - socketPath string - listener net.Listener - server *http.Server + addr string + socketPath string + listener net.Listener + server *http.Server + onConnect func(ConnectionRequest) error onSwitchOrg func(SwitchOrgRequest) error onDisconnect func() error onExit func() error + statusMu sync.RWMutex peerStatuses map[int]*PeerStatus connectedAt 
time.Time isConnected bool isRegistered bool isTerminated bool - version string - agent string - orgID string + + version string + agent string + orgID string } // NewAPI creates a new HTTP server that listens on a TCP address @@ -173,7 +176,7 @@ func (s *API) Stop() error { // Close the server first, which will also close the listener gracefully if s.server != nil { - s.server.Close() + _ = s.server.Close() } // Clean up socket file if using Unix socket @@ -358,7 +361,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusAccepted) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "connection request accepted", }) } @@ -406,7 +409,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "ok", }) } @@ -423,7 +426,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { // Return a success response first w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "shutdown initiated", }) @@ -472,7 +475,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "org switch request accepted", }) } @@ -506,7 +509,7 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - 
json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "disconnect initiated", }) } diff --git a/main.go b/main.go index 5b6c15e..2bf8dcd 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" - "github.com/fosrl/olm/olm" + olmpkg "github.com/fosrl/olm/olm" ) func main() { @@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt } // Create a new olm.Config struct and copy values from the main config - olmConfig := olm.GlobalConfig{ + olmConfig := olmpkg.OlmConfig{ LogLevel: config.LogLevel, EnableAPI: config.EnableAPI, HTTPAddr: config.HTTPAddr, @@ -222,13 +222,17 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE } - olm.Init(ctx, olmConfig) + olm, err := olmpkg.Init(ctx, olmConfig) + if err != nil { + logger.Fatal("Failed to initialize olm: %v", err) + } + if err := olm.StartApi(); err != nil { logger.Fatal("Failed to start API server: %v", err) } if config.ID != "" && config.Secret != "" && config.Endpoint != "" { - tunnelConfig := olm.TunnelConfig{ + tunnelConfig := olmpkg.TunnelConfig{ Endpoint: config.Endpoint, ID: config.ID, Secret: config.Secret, diff --git a/olm/connect.go b/olm/connect.go new file mode 100644 index 0000000..568c731 --- /dev/null +++ b/olm/connect.go @@ -0,0 +1,223 @@ +package olm + +import ( + "encoding/json" + "fmt" + "os" + "runtime" + "strconv" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + olmDevice "github.com/fosrl/olm/device" + "github.com/fosrl/olm/dns" + dnsOverride "github.com/fosrl/olm/dns/override" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" +) + +func (o *Olm) handleConnect(msg websocket.WSMessage) { + logger.Debug("Received message: %v", 
msg.Data) + + var wgData WgData + + if o.connected { + logger.Info("Already connected. Ignoring new connection request.") + return + } + + if o.stopRegister != nil { + o.stopRegister() + o.stopRegister = nil + } + + if o.updateRegister != nil { + o.updateRegister = nil + } + + // if there is an existing tunnel then close it + if o.dev != nil { + logger.Info("Got new message. Closing existing tunnel!") + o.dev.Close() + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &wgData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + o.tdev, err = func() (tun.Device, error) { + if o.tunnelConfig.FileDescriptorTun != 0 { + return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU) + } + ifName := o.tunnelConfig.InterfaceName + if runtime.GOOS == "darwin" { // this is if we dont pass a fd + ifName, err = network.FindUnusedUTUN() + if err != nil { + return nil, err + } + } + return tun.CreateTUN(ifName, o.tunnelConfig.MTU) + }() + if err != nil { + logger.Error("Failed to create TUN device: %v", err) + return + } + + // if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? 
+ o.tunnelConfig.InterfaceName = realInterfaceName + } + // } + + // Wrap TUN device with packet filter for DNS proxy + o.middleDev = olmDevice.NewMiddleDevice(o.tdev) + + wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") + // Use filtered device instead of raw TUN device + o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger)) + + if o.tunnelConfig.EnableUAPI { + fileUAPI, err := func() (*os.File, error) { + if o.tunnelConfig.FileDescriptorUAPI != 0 { + fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) + } + return os.NewFile(uintptr(fd), ""), nil + } + return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName) + }() + if err != nil { + logger.Error("UAPI listen error: %v", err) + os.Exit(1) + return + } + + o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI) + if err != nil { + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) + } + + go func() { + for { + conn, err := o.uapiListener.Accept() + if err != nil { + return + } + go o.dev.IpcHandle(conn) + } + }() + logger.Info("UAPI listener started") + } + + if err = o.dev.Up(); err != nil { + logger.Error("Failed to bring up WireGuard device: %v", err) + } + + // Extract interface IP (strip CIDR notation if present) + interfaceIP := wgData.TunnelIP + if strings.Contains(interfaceIP, "/") { + interfaceIP = strings.Split(interfaceIP, "/")[0] + } + + // Create and start DNS proxy + o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP) + if err != nil { + logger.Error("Failed to create DNS proxy: %v", err) + } + + if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil { + logger.Error("Failed to o.tunnelConfigure interface: %v", err) + } + + if 
network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet + logger.Error("Failed to add route for utility subnet: %v", err) + } + + // Create peer manager with integrated peer monitoring + o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ + Device: o.dev, + DNSProxy: o.dnsProxy, + InterfaceName: o.tunnelConfig.InterfaceName, + PrivateKey: o.privateKey, + MiddleDev: o.middleDev, + LocalIP: interfaceIP, + SharedBind: o.sharedBind, + WSClient: o.olmClient, + APIServer: o.apiServer, + }) + + for i := range wgData.Sites { + site := wgData.Sites[i] + var siteEndpoint string + // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer + if site.RelayEndpoint != "" { + siteEndpoint = site.RelayEndpoint + } else { + siteEndpoint = site.Endpoint + } + + o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) + + if err := o.peerManager.AddPeer(site); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Configured peer %s", site.PublicKey) + } + + o.peerManager.Start() + + if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime + logger.Error("Failed to start DNS proxy: %v", err) + } + + if o.tunnelConfig.OverrideDNS { + // Set up DNS override to use our DNS proxy + if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { + logger.Error("Failed to setup DNS override: %v", err) + return + } + + network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()}) + } + + o.apiServer.SetRegistered(true) + + o.connected = true + + // Invoke onConnected callback if configured + if o.olmConfig.OnConnected != nil { + go o.olmConfig.OnConnected() + } + + logger.Info("WireGuard device created.") +} + +func (o *Olm) handleTerminate(msg websocket.WSMessage) { + logger.Info("Received terminate message") + 
o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearPeerStatuses() + + network.ClearNetworkSettings() + + o.Close() + + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() + } +} diff --git a/olm/data.go b/olm/data.go new file mode 100644 index 0000000..9c8d33f --- /dev/null +++ b/olm/data.go @@ -0,0 +1,197 @@ +package olm + +import ( + "encoding/json" + "time" + + "github.com/fosrl/newt/holepunch" + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" +) + +func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) { + logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var addSubnetsData peers.PeerAdd + if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { + logger.Error("Error unmarshaling add-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId) + return + } + + // Add new subnets + for _, subnet := range addSubnetsData.RemoteSubnets { + if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases + for _, alias := range addSubnetsData.Aliases { + if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } +} + +func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) { + logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeSubnetsData 
peers.RemovePeerData + if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { + logger.Error("Error unmarshaling remove-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) + return + } + + // Remove subnets + for _, subnet := range removeSubnetsData.RemoteSubnets { + if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Remove aliases + for _, alias := range removeSubnetsData.Aliases { + if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } +} + +func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { + logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateSubnetsData peers.UpdatePeerData + if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { + logger.Error("Error unmarshaling update-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId) + return + } + + // Add new subnets BEFORE removing old ones to preserve shared subnets + // This ensures that if an old and new subnet are the same on different peers, + // the route won't be temporarily removed + for _, subnet := range updateSubnetsData.NewRemoteSubnets { + if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Remove old subnets after new ones are 
added + for _, subnet := range updateSubnetsData.OldRemoteSubnets { + if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases BEFORE removing old ones to preserve shared IP addresses + // This ensures that if an old and new alias share the same IP, the IP won't be + // temporarily removed from the allowed IPs list + for _, alias := range updateSubnetsData.NewAliases { + if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } + + // Remove old aliases after new ones are added + for _, alias := range updateSubnetsData.OldAliases { + if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } + + logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) +} + +func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { + logger.Debug("Received peer-handshake message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling handshake data: %v", err) + return + } + + var handshakeData struct { + SiteId int `json:"siteId"` + ExitNode struct { + PublicKey string `json:"publicKey"` + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + } `json:"exitNode"` + } + + if err := json.Unmarshal(jsonData, &handshakeData); err != nil { + logger.Error("Error unmarshaling handshake data: %v", err) + return + } + + // Get existing peer from PeerManager + _, exists := o.peerManager.GetPeer(handshakeData.SiteId) + if exists { + logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) + return + } + + relayPort := handshakeData.ExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + siteId 
:= handshakeData.SiteId + exitNode := holepunch.ExitNode{ + Endpoint: handshakeData.ExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: handshakeData.ExitNode.PublicKey, + SiteIds: []int{siteId}, + } + + added := o.holePunchManager.AddExitNode(exitNode) + if added { + logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) + } else { + logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) + } + + o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt + o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud + + // Send handshake acknowledgment back to server with retry + o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + }, 1*time.Second) + + logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) +} diff --git a/olm/olm.go b/olm/olm.go index 774a3cb..6d8f7a5 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -2,15 +2,11 @@ package olm import ( "context" - "encoding/json" "fmt" "net" "net/http" _ "net/http/pprof" "os" - "runtime" - "strconv" - "strings" "time" "github.com/fosrl/newt/bind" @@ -30,41 +26,49 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -var ( - privateKey wgtypes.Key - connected bool - dev *device.Device - uapiListener net.Listener - tdev tun.Device - middleDev *olmDevice.MiddleDevice - interfaceName string +type Olm struct { + privateKey wgtypes.Key + logFile *os.File + + connected bool + tunnelRunning bool + + uapiListener net.Listener + dev *device.Device + tdev tun.Device + middleDev *olmDevice.MiddleDevice + sharedBind *bind.SharedBind + dnsProxy *dns.DNSProxy apiServer *api.API olmClient *websocket.Client - tunnelCancel context.CancelFunc - tunnelRunning bool - sharedBind *bind.SharedBind holePunchManager *holepunch.Manager - globalConfig GlobalConfig - 
tunnelConfig TunnelConfig - globalCtx context.Context - stopRegister func() - stopPeerSend func() - updateRegister func(newData interface{}) - stopPing chan struct{} peerManager *peers.PeerManager -) + + olmCtx context.Context + tunnelCancel context.CancelFunc + + olmConfig OlmConfig + tunnelConfig TunnelConfig + + stopRegister func() + stopPeerSend func() + updateRegister func(newData any) + + stopPing chan struct{} +} // initTunnelInfo creates the shared UDP socket and holepunch manager. // This is used during initial tunnel setup and when switching organizations. -func initTunnelInfo(clientID string) error { - var err error - privateKey, err = wgtypes.GeneratePrivateKey() +func (o *Olm) initTunnelInfo(clientID string) error { + privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { logger.Error("Failed to generate private key: %v", err) return err } + o.privateKey = privateKey + sourcePort, err := util.FindAvailableUDPPort(49152, 65535) if err != nil { return fmt.Errorf("failed to find available UDP port: %w", err) @@ -80,27 +84,26 @@ func initTunnelInfo(clientID string) error { return fmt.Errorf("failed to create UDP socket: %w", err) } - sharedBind, err = bind.New(udpConn) + sharedBind, err := bind.New(udpConn) if err != nil { - udpConn.Close() + _ = udpConn.Close() return fmt.Errorf("failed to create shared bind: %w", err) } + o.sharedBind = sharedBind + // Add a reference for the hole punch senders (creator already has one reference for WireGuard) sharedBind.AddRef() logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) // Create the holepunch manager - holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) + o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) return nil } -func Init(ctx context.Context, config GlobalConfig) { - globalConfig = config - globalCtx = ctx - +func Init(ctx context.Context, 
config OlmConfig) (*Olm, error) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) // Start pprof server if enabled @@ -112,25 +115,27 @@ func Init(ctx context.Context, config GlobalConfig) { } }() } - + + var logFile *os.File if config.LogFilePath != "" { - logFile, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + file, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) if err != nil { logger.Fatal("Failed to open log file: %v", err) + return nil, err } - // TODO: figure out how to close file, if set - logger.SetOutput(logFile) - return + logger.SetOutput(file) + logFile = file } logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() if err != nil { logger.Fatal("Insufficient permissions to create native TUN interface: %v", err) - return + return nil, err } + var apiServer *api.API if config.HTTPAddr != "" { apiServer = api.NewAPI(config.HTTPAddr) } else if config.SocketPath != "" { @@ -143,18 +148,24 @@ func Init(ctx context.Context, config GlobalConfig) { apiServer.SetVersion(config.Version) apiServer.SetAgent(config.Agent) - // Set up API handlers - apiServer.SetHandlers( + newOlm := &Olm{ + logFile: logFile, + olmCtx: ctx, + apiServer: apiServer, + olmConfig: config, + } + + newOlm.registerAPICallbacks() + + return newOlm, nil +} + +func (o *Olm) registerAPICallbacks() { + o.apiServer.SetHandlers( // onConnect func(req api.ConnectionRequest) error { logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) - // Stop any existing tunnel before starting a new one - if olmClient != nil { - logger.Info("Stopping existing tunnel before starting new connection") - StopTunnel() - } - tunnelConfig := TunnelConfig{ Endpoint: req.Endpoint, ID: req.ID, @@ -208,7 +219,7 @@ func Init(ctx context.Context, config GlobalConfig) { // Start the tunnel process with the new credentials if tunnelConfig.ID != 
"" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" { logger.Info("Starting tunnel with new credentials") - go StartTunnel(tunnelConfig) + go o.StartTunnel(tunnelConfig) } return nil @@ -216,66 +227,64 @@ func Init(ctx context.Context, config GlobalConfig) { // onSwitchOrg func(req api.SwitchOrgRequest) error { logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) - return SwitchOrg(req.OrgID) + return o.SwitchOrg(req.OrgID) }, // onDisconnect func() error { logger.Info("Processing disconnect request via API") - return StopTunnel() + return o.StopTunnel() }, // onExit func() error { logger.Info("Processing shutdown request via API") - Close() - if globalConfig.OnExit != nil { - globalConfig.OnExit() + o.Close() + if o.olmConfig.OnExit != nil { + o.olmConfig.OnExit() } return nil }, ) } -func StartTunnel(config TunnelConfig) { - if tunnelRunning { +func (o *Olm) StartTunnel(config TunnelConfig) { + if o.tunnelRunning { logger.Info("Tunnel already running") return } - tunnelRunning = true // Also set it here in case it is called externally - tunnelConfig = config + o.tunnelRunning = true // Also set it here in case it is called externally + o.tunnelConfig = config // Reset terminated status when tunnel starts - apiServer.SetTerminated(false) + o.apiServer.SetTerminated(false) // debug print out the whole config logger.Debug("Starting tunnel with config: %+v", config) // Create a cancellable context for this tunnel process - tunnelCtx, cancel := context.WithCancel(globalCtx) - tunnelCancel = cancel - defer func() { - tunnelCancel = nil - }() + tunnelCtx, cancel := context.WithCancel(o.olmCtx) + o.tunnelCancel = cancel // Recreate channels for this tunnel session - stopPing = make(chan struct{}) + o.stopPing = make(chan struct{}) var ( id = config.ID secret = config.Secret userToken = config.UserToken ) - interfaceName = config.InterfaceName - apiServer.SetOrgID(config.OrgID) + o.tunnelConfig.InterfaceName = config.InterfaceName - 
// Create a new olm client using the provided credentials - olm, err := websocket.NewClient( - id, // Use provided ID - secret, // Use provided secret - userToken, // Use provided user token OPTIONAL + o.apiServer.SetOrgID(config.OrgID) + + // Create a new olmClient client using the provided credentials + olmClient, err := websocket.NewClient( + id, + secret, + userToken, config.OrgID, - config.Endpoint, // Use provided endpoint + config.Endpoint, config.PingIntervalDuration, config.PingTimeoutDuration, ) @@ -284,638 +293,70 @@ func StartTunnel(config TunnelConfig) { return } - // Store the client reference globally - olmClient = olm - // Create shared UDP socket and holepunch manager - if err := initTunnelInfo(id); err != nil { + if err := o.initTunnelInfo(id); err != nil { logger.Error("%v", err) return } - olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - - var wgData WgData - - if connected { - logger.Info("Already connected. Ignoring new connection request.") - return - } - - if stopRegister != nil { - stopRegister() - stopRegister = nil - } - - if updateRegister != nil { - updateRegister = nil - } - - // if there is an existing tunnel then close it - if dev != nil { - logger.Info("Got new message. 
Closing existing tunnel!") - dev.Close() - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &wgData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - tdev, err = func() (tun.Device, error) { - if config.FileDescriptorTun != 0 { - return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) - } - ifName := interfaceName - if runtime.GOOS == "darwin" { // this is if we dont pass a fd - ifName, err = network.FindUnusedUTUN() - if err != nil { - return nil, err - } - } - return tun.CreateTUN(ifName, config.MTU) - }() - if err != nil { - logger.Error("Failed to create TUN device: %v", err) - return - } - - // if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? - interfaceName = realInterfaceName - } - // } - - // Wrap TUN device with packet filter for DNS proxy - middleDev = olmDevice.NewMiddleDevice(tdev) - - wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") - // Use filtered device instead of raw TUN device - dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger)) - - if config.EnableUAPI { - fileUAPI, err := func() (*os.File, error) { - if config.FileDescriptorUAPI != 0 { - fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32) - if err != nil { - return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) - } - return os.NewFile(uintptr(fd), ""), nil - } - return olmDevice.UapiOpen(interfaceName) - }() - if err != nil { - logger.Error("UAPI listen error: %v", err) - os.Exit(1) - return - } - - uapiListener, err = olmDevice.UapiListen(interfaceName, fileUAPI) - if err != nil { - logger.Error("Failed to listen on uapi socket: %v", err) - os.Exit(1) - } - - go func() { - for { - conn, err := uapiListener.Accept() - if err != nil { - 
return - } - go dev.IpcHandle(conn) - } - }() - logger.Info("UAPI listener started") - } - - if err = dev.Up(); err != nil { - logger.Error("Failed to bring up WireGuard device: %v", err) - } - - // Extract interface IP (strip CIDR notation if present) - interfaceIP := wgData.TunnelIP - if strings.Contains(interfaceIP, "/") { - interfaceIP = strings.Split(interfaceIP, "/")[0] - } - - // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) - if err != nil { - logger.Error("Failed to create DNS proxy: %v", err) - } - - if err = network.ConfigureInterface(interfaceName, wgData.TunnelIP, config.MTU); err != nil { - logger.Error("Failed to configure interface: %v", err) - } - - if network.AddRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet - logger.Error("Failed to add route for utility subnet: %v", err) - } - - // Create peer manager with integrated peer monitoring - peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ - Device: dev, - DNSProxy: dnsProxy, - InterfaceName: interfaceName, - PrivateKey: privateKey, - MiddleDev: middleDev, - LocalIP: interfaceIP, - SharedBind: sharedBind, - WSClient: olm, - APIServer: apiServer, - }) - - for i := range wgData.Sites { - site := wgData.Sites[i] - var siteEndpoint string - // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer - if site.RelayEndpoint != "" { - siteEndpoint = site.RelayEndpoint - } else { - siteEndpoint = site.Endpoint - } - - apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) - - if err := peerManager.AddPeer(site); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - logger.Info("Configured peer %s", site.PublicKey) - } - - peerManager.Start() - - if err := dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime - 
logger.Error("Failed to start DNS proxy: %v", err) - } - - if config.OverrideDNS { - // Set up DNS override to use our DNS proxy - if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy.GetProxyIP()); err != nil { - logger.Error("Failed to setup DNS override: %v", err) - return - } - - network.SetDNSServers([]string{dnsProxy.GetProxyIP().String()}) - } - - apiServer.SetRegistered(true) - - connected = true - - // Invoke onConnected callback if configured - if globalConfig.OnConnected != nil { - go globalConfig.OnConnected() - } - - logger.Info("WireGuard device created.") - }) - - olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateData peers.SiteConfig - if err := json.Unmarshal(jsonData, &updateData); err != nil { - logger.Error("Error unmarshaling update data: %v", err) - return - } - - // Get existing peer from PeerManager - existingPeer, exists := peerManager.GetPeer(updateData.SiteId) - if !exists { - logger.Warn("Peer with site ID %d not found", updateData.SiteId) - return - } - - // Create updated site config by merging with existing data - siteConfig := existingPeer - - if updateData.Endpoint != "" { - siteConfig.Endpoint = updateData.Endpoint - } - if updateData.RelayEndpoint != "" { - siteConfig.RelayEndpoint = updateData.RelayEndpoint - } - if updateData.PublicKey != "" { - siteConfig.PublicKey = updateData.PublicKey - } - if updateData.ServerIP != "" { - siteConfig.ServerIP = updateData.ServerIP - } - if updateData.ServerPort != 0 { - siteConfig.ServerPort = updateData.ServerPort - } - if updateData.RemoteSubnets != nil { - siteConfig.RemoteSubnets = updateData.RemoteSubnets - } - - if err := peerManager.UpdatePeer(siteConfig); err != nil { - logger.Error("Failed to update peer: %v", err) - return - } - - // If the endpoint 
changed, trigger holepunch to refresh NAT mappings - if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { - logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) - holePunchManager.TriggerHolePunch() - holePunchManager.ResetInterval() - } - - // Update successful - logger.Info("Successfully updated peer for site %d", updateData.SiteId) - }) - - // Handler for adding a new peer - olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-peer message: %v", msg.Data) - - if stopPeerSend != nil { - stopPeerSend() - stopPeerSend = nil - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var siteConfig peers.SiteConfig - if err := json.Unmarshal(jsonData, &siteConfig); err != nil { - logger.Error("Error unmarshaling add data: %v", err) - return - } - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it - - if err := peerManager.AddPeer(siteConfig); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - // Add successful - logger.Info("Successfully added peer for site %d", siteConfig.SiteId) - }) - - // Handler for removing a peer - olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeData peers.PeerRemove - if err := json.Unmarshal(jsonData, &removeData); err != nil { - logger.Error("Error unmarshaling remove data: %v", err) - return - } - - if err := peerManager.RemovePeer(removeData.SiteId); err != nil { - logger.Error("Failed to remove peer: %v", err) - return - } - - // Remove any exit nodes associated with this peer from 
hole punching - if holePunchManager != nil { - removed := holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) - if removed > 0 { - logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) - } - } - - // Remove successful - logger.Info("Successfully removed peer for site %d", removeData.SiteId) - }) - - // Handler for adding remote subnets to a peer - olm.RegisterHandler("olm/wg/peer/data/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var addSubnetsData peers.PeerAdd - if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { - logger.Error("Error unmarshaling add-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(addSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId) - return - } - - // Add new subnets - for _, subnet := range addSubnetsData.RemoteSubnets { - if err := peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases - for _, alias := range addSubnetsData.Aliases { - if err := peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for removing remote subnets from a peer - olm.RegisterHandler("olm/wg/peer/data/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeSubnetsData peers.RemovePeerData - if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { - 
logger.Error("Error unmarshaling remove-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(removeSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) - return - } - - // Remove subnets - for _, subnet := range removeSubnetsData.RemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Remove aliases - for _, alias := range removeSubnetsData.Aliases { - if err := peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for updating remote subnets of a peer (remove old, add new in one operation) - olm.RegisterHandler("olm/wg/peer/data/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateSubnetsData peers.UpdatePeerData - if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { - logger.Error("Error unmarshaling update-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(updateSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", updateSubnetsData.SiteId) - return - } - - // Add new subnets BEFORE removing old ones to preserve shared subnets - // This ensures that if an old and new subnet are the same on different peers, - // the route won't be temporarily removed - for _, subnet := range updateSubnetsData.NewRemoteSubnets { - if err := peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Remove old subnets after new ones 
are added - for _, subnet := range updateSubnetsData.OldRemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases BEFORE removing old ones to preserve shared IP addresses - // This ensures that if an old and new alias share the same IP, the IP won't be - // temporarily removed from the allowed IPs list - for _, alias := range updateSubnetsData.NewAliases { - if err := peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - - // Remove old aliases after new ones are added - for _, alias := range updateSubnetsData.OldAliases { - if err := peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - - logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) - }) - - olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { - logger.Debug("Received relay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.RelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) - - 
peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort) - }) - - olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) { - logger.Debug("Received unrelay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.UnRelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.Endpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false) - - peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) - }) + // Handlers for managing connection status + olmClient.RegisterHandler("olm/wg/connect", o.handleConnect) + olmClient.RegisterHandler("olm/terminate", o.handleTerminate) + + // Handlers for managing peers + olmClient.RegisterHandler("olm/wg/peer/add", o.handleWgPeerAdd) + olmClient.RegisterHandler("olm/wg/peer/remove", o.handleWgPeerRemove) + olmClient.RegisterHandler("olm/wg/peer/update", o.handleWgPeerUpdate) + olmClient.RegisterHandler("olm/wg/peer/relay", o.handleWgPeerRelay) + olmClient.RegisterHandler("olm/wg/peer/unrelay", o.handleWgPeerUnrelay) + + // Handlers for managing remote subnets to a peer + olmClient.RegisterHandler("olm/wg/peer/data/add", o.handleWgPeerAddData) + olmClient.RegisterHandler("olm/wg/peer/data/remove", o.handleWgPeerRemoveData) + olmClient.RegisterHandler("olm/wg/peer/data/update", o.handleWgPeerUpdateData) // Handler for peer handshake - adds exit node to holepunch rotation and 
notifies server - olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) { - logger.Debug("Received peer-handshake message: %v", msg.Data) + olmClient.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling handshake data: %v", err) - return - } - - var handshakeData struct { - SiteId int `json:"siteId"` - ExitNode struct { - PublicKey string `json:"publicKey"` - Endpoint string `json:"endpoint"` - RelayPort uint16 `json:"relayPort"` - } `json:"exitNode"` - } - - if err := json.Unmarshal(jsonData, &handshakeData); err != nil { - logger.Error("Error unmarshaling handshake data: %v", err) - return - } - - // Get existing peer from PeerManager - _, exists := peerManager.GetPeer(handshakeData.SiteId) - if exists { - logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) - return - } - - relayPort := handshakeData.ExitNode.RelayPort - if relayPort == 0 { - relayPort = 21820 // default relay port - } - - siteId := handshakeData.SiteId - exitNode := holepunch.ExitNode{ - Endpoint: handshakeData.ExitNode.Endpoint, - RelayPort: relayPort, - PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, - } - - added := holePunchManager.AddExitNode(exitNode) - if added { - logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) - } else { - logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) - } - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud - - // Send handshake acknowledgment back to server with retry - stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second) - - logger.Info("Initiated handshake for site %d with exit 
node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) - }) - - olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() - network.ClearNetworkSettings() - Close() - - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() - } - }) - - olm.RegisterHandler("pong", func(msg websocket.WSMessage) { - logger.Debug("Received pong message") - }) - - olm.OnConnect(func() error { + olmClient.OnConnect(func() error { logger.Info("Websocket Connected") - apiServer.SetConnectionStatus(true) + o.apiServer.SetConnectionStatus(true) - if connected { + if o.connected { logger.Debug("Already connected, skipping registration") return nil } - publicKey := privateKey.PublicKey() + publicKey := o.privateKey.PublicKey() // delay for 500ms to allow for time for the hp to get processed time.Sleep(500 * time.Millisecond) - if stopRegister == nil { + if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) - stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + o.stopRegister, o.updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]any{ "publicKey": publicKey.String(), "relay": !config.Holepunch, - "olmVersion": globalConfig.Version, - "olmAgent": globalConfig.Agent, + "olmVersion": o.olmConfig.Version, + "olmAgent": o.olmConfig.Agent, "orgId": config.OrgID, "userToken": userToken, }, 1*time.Second) // Invoke onRegistered callback if configured - if globalConfig.OnRegistered != nil { - go globalConfig.OnRegistered() + if o.olmConfig.OnRegistered != nil { + go o.olmConfig.OnRegistered() } } - go keepSendingPing(olm) + go o.keepSendingPing(olmClient) return nil }) - olm.OnTokenUpdate(func(token string, exitNodes 
[]websocket.ExitNode) { - holePunchManager.SetToken(token) + olmClient.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { + o.holePunchManager.SetToken(token) logger.Debug("Got exit nodes for hole punching: %v", exitNodes) @@ -939,141 +380,113 @@ func StartTunnel(config TunnelConfig) { // Start hole punching using the manager logger.Info("Starting hole punch for %d exit nodes", len(exitNodes)) - if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { + if err := o.holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { logger.Warn("Failed to start hole punch: %v", err) } }) - olm.OnAuthError(func(statusCode int, message string) { + olmClient.OnAuthError(func(statusCode int, message string) { logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message) - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() + o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() - Close() + o.Close() - if globalConfig.OnAuthError != nil { - go globalConfig.OnAuthError(statusCode, message) + if o.olmConfig.OnAuthError != nil { + go o.olmConfig.OnAuthError(statusCode, message) } - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() } }) // Connect to the WebSocket server - if err := olm.Connect(); err != nil { + if err := olmClient.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) return } - defer olm.Close() + defer func() { _ = olmClient.Close() }() + + o.olmClient = olmClient // Wait for context cancellation <-tunnelCtx.Done() logger.Info("Tunnel process context cancelled, cleaning up") } -func AddDevice(fd uint32) error { - if middleDev == nil { - return fmt.Errorf("middle device is not 
initialized") - } - - if tunnelConfig.MTU == 0 { - // error - return fmt.Errorf("tunnel MTU is not set") - } - - tdev, err := olmDevice.CreateTUNFromFD(fd, tunnelConfig.MTU) - - if err != nil { - return fmt.Errorf("failed to create TUN device from fd: %v", err) - } - - // if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? - interfaceName = realInterfaceName - } - - // Here we replace the existing TUN device in the middle device with the new one - middleDev.AddDevice(tdev) - - return nil -} - -func Close() { +func (o *Olm) Close() { // Restore original DNS configuration // we do this first to avoid any DNS issues if something else gets stuck if err := dnsOverride.RestoreDNSOverride(); err != nil { logger.Error("Failed to restore DNS: %v", err) } - // Stop hole punch manager - if holePunchManager != nil { - holePunchManager.Stop() - holePunchManager = nil + if o.holePunchManager != nil { + o.holePunchManager.Stop() + o.holePunchManager = nil } - if stopPing != nil { - select { - case <-stopPing: - // Channel already closed - default: - close(stopPing) - } + if o.stopPing != nil { + close(o.stopPing) + o.stopPing = nil } - if stopRegister != nil { - stopRegister() - stopRegister = nil + if o.stopRegister != nil { + o.stopRegister() + o.stopRegister = nil } - if updateRegister != nil { - updateRegister = nil + // Close() also calls Stop() internally + if o.peerManager != nil { + o.peerManager.Close() + o.peerManager = nil } - if peerManager != nil { - peerManager.Close() // Close() also calls Stop() internally - peerManager = nil + if o.uapiListener != nil { + _ = o.uapiListener.Close() + o.uapiListener = nil } - if uapiListener != nil { - uapiListener.Close() - uapiListener = nil + if o.logFile != nil { + _ = o.logFile.Close() + o.logFile = nil } // Stop DNS proxy first - it uses the middleDev for packet filtering - logger.Debug("Stopping DNS proxy") - if 
dnsProxy != nil { - dnsProxy.Stop() - dnsProxy = nil + if o.dnsProxy != nil { + logger.Debug("Stopping DNS proxy") + o.dnsProxy.Stop() + o.dnsProxy = nil } // Close MiddleDevice first - this closes the TUN and signals the closed channel // This unblocks the pump goroutine and allows WireGuard's TUN reader to exit - logger.Debug("Closing MiddleDevice") - if middleDev != nil { - middleDev.Close() - middleDev = nil + // Note: o.tdev is closed by o.middleDev.Close() since middleDev wraps it + if o.middleDev != nil { + logger.Debug("Closing MiddleDevice") + _ = o.middleDev.Close() + o.middleDev = nil } - // Note: tdev is closed by middleDev.Close() since middleDev wraps it - tdev = nil // Now close WireGuard device - its TUN reader should have exited by now - logger.Debug("Closing WireGuard device") - if dev != nil { - dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference - dev = nil + // This will call sharedBind.Close() which releases WireGuard's reference + if o.dev != nil { + logger.Debug("Closing WireGuard device") + o.dev.Close() + o.dev = nil } - // Release the hole punch reference to the shared bind - if sharedBind != nil { - // Release hole punch reference (WireGuard already released its reference via dev.Close()) - logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount()) - sharedBind.Release() - sharedBind = nil + // Release the hole punch reference to the shared bind (WireGuard already + // released its reference via dev.Close()) + if o.sharedBind != nil { + logger.Debug("Releasing shared bind (refcount before release: %d)", o.sharedBind.GetRefCount()) + _ = o.sharedBind.Release() logger.Info("Released shared UDP bind") + o.sharedBind = nil } logger.Info("Olm service stopped") @@ -1081,78 +494,85 @@ func Close() { // StopTunnel stops just the tunnel process and websocket connection // without shutting down the entire application -func StopTunnel() error { +func (o *Olm) StopTunnel() error { 
logger.Info("Stopping tunnel process") + if !o.tunnelRunning { + logger.Debug("Tunnel not running, nothing to stop") + return nil + } + // Cancel the tunnel context if it exists - if tunnelCancel != nil { - tunnelCancel() + if o.tunnelCancel != nil { + o.tunnelCancel() // Give it a moment to clean up time.Sleep(200 * time.Millisecond) } // Close the websocket connection - if olmClient != nil { - olmClient.Close() - olmClient = nil + if o.olmClient != nil { + _ = o.olmClient.Close() + o.olmClient = nil } - Close() + o.Close() // Reset the connected state - connected = false - tunnelRunning = false + o.connected = false + o.tunnelRunning = false // Update API server status - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) network.ClearNetworkSettings() - apiServer.ClearPeerStatuses() + o.apiServer.ClearPeerStatuses() logger.Info("Tunnel process stopped") return nil } -func StopApi() error { - if apiServer != nil { - err := apiServer.Stop() +func (o *Olm) StopApi() error { + if o.apiServer != nil { + err := o.apiServer.Stop() if err != nil { return fmt.Errorf("failed to stop API server: %w", err) } } + return nil } -func StartApi() error { - if apiServer != nil { - err := apiServer.Start() +func (o *Olm) StartApi() error { + if o.apiServer != nil { + err := o.apiServer.Start() if err != nil { return fmt.Errorf("failed to start API server: %w", err) } } + return nil } -func GetStatus() api.StatusResponse { - return apiServer.GetStatus() +func (o *Olm) GetStatus() api.StatusResponse { + return o.apiServer.GetStatus() } -func SwitchOrg(orgID string) error { +func (o *Olm) SwitchOrg(orgID string) error { logger.Info("Processing org switch request to orgId: %s", orgID) // stop the tunnel - if err := StopTunnel(); err != nil { + if err := o.StopTunnel(); err != nil { return fmt.Errorf("failed to stop existing tunnel: %w", err) } // Update the org ID in the API server and global 
config - apiServer.SetOrgID(orgID) + o.apiServer.SetOrgID(orgID) - tunnelConfig.OrgID = orgID + o.tunnelConfig.OrgID = orgID // Restart the tunnel with the same config but new org ID - go StartTunnel(tunnelConfig) + go o.StartTunnel(o.tunnelConfig) return nil } diff --git a/olm/peer.go b/olm/peer.go new file mode 100644 index 0000000..8acec42 --- /dev/null +++ b/olm/peer.go @@ -0,0 +1,195 @@ +package olm + +import ( + "encoding/json" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" +) + +func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { + logger.Debug("Received add-peer message: %v", msg.Data) + + if o.stopPeerSend != nil { + o.stopPeerSend() + o.stopPeerSend = nil + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var siteConfig peers.SiteConfig + if err := json.Unmarshal(jsonData, &siteConfig); err != nil { + logger.Error("Error unmarshaling add data: %v", err) + return + } + + _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it + + if err := o.peerManager.AddPeer(siteConfig); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Successfully added peer for site %d", siteConfig.SiteId) +} + +func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { + logger.Debug("Received remove-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeData peers.PeerRemove + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return + } + + if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil { + logger.Error("Failed to remove peer: %v", err) + return + } + + // Remove 
any exit nodes associated with this peer from hole punching + if o.holePunchManager != nil { + removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) + if removed > 0 { + logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) + } + } + + logger.Info("Successfully removed peer for site %d", removeData.SiteId) +} + +func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) { + logger.Debug("Received update-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateData peers.SiteConfig + if err := json.Unmarshal(jsonData, &updateData); err != nil { + logger.Error("Error unmarshaling update data: %v", err) + return + } + + // Get existing peer from PeerManager + existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId) + if !exists { + logger.Warn("Peer with site ID %d not found", updateData.SiteId) + return + } + + // Create updated site config by merging with existing data + siteConfig := existingPeer + + if updateData.Endpoint != "" { + siteConfig.Endpoint = updateData.Endpoint + } + if updateData.RelayEndpoint != "" { + siteConfig.RelayEndpoint = updateData.RelayEndpoint + } + if updateData.PublicKey != "" { + siteConfig.PublicKey = updateData.PublicKey + } + if updateData.ServerIP != "" { + siteConfig.ServerIP = updateData.ServerIP + } + if updateData.ServerPort != 0 { + siteConfig.ServerPort = updateData.ServerPort + } + if updateData.RemoteSubnets != nil { + siteConfig.RemoteSubnets = updateData.RemoteSubnets + } + + if err := o.peerManager.UpdatePeer(siteConfig); err != nil { + logger.Error("Failed to update peer: %v", err) + return + } + + // If the endpoint changed, trigger holepunch to refresh NAT mappings + if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { + logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT 
mappings", updateData.SiteId) + _ = o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetInterval() + } + + logger.Info("Successfully updated peer for site %d", updateData.SiteId) +} + +func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) { + logger.Debug("Received relay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if o.peerManager == nil { + logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData peers.RelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling relay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) + if err != nil { + logger.Error("Failed to resolve primary relay endpoint: %v", err) + return + } + + // Update HTTP server to mark this peer as using relay + o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) + + o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort) +} + +func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { + logger.Debug("Received unrelay-peer message: %v", msg.Data) + + // Check if peerManager is still valid (may be nil during shutdown) + if o.peerManager == nil { + logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var relayData peers.UnRelayPeerData + if err := json.Unmarshal(jsonData, &relayData); err != nil { + logger.Error("Error unmarshaling relay data: %v", err) + return + } + + primaryRelay, err := util.ResolveDomain(relayData.Endpoint) + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + // Update HTTP 
server to mark this peer as using relay + o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false) + + o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) +} diff --git a/olm/util.go b/olm/ping.go similarity index 89% rename from olm/util.go rename to olm/ping.go index 6bfd171..bbeee9a 100644 --- a/olm/util.go +++ b/olm/ping.go @@ -9,7 +9,7 @@ import ( ) func sendPing(olm *websocket.Client) error { - err := olm.SendMessage("olm/ping", map[string]interface{}{ + err := olm.SendMessage("olm/ping", map[string]any{ "timestamp": time.Now().Unix(), "userToken": olm.GetConfig().UserToken, }) @@ -21,7 +21,7 @@ func sendPing(olm *websocket.Client) error { return nil } -func keepSendingPing(olm *websocket.Client) { +func (o *Olm) keepSendingPing(olm *websocket.Client) { // Send ping immediately on startup if err := sendPing(olm); err != nil { logger.Error("Failed to send initial ping: %v", err) @@ -35,7 +35,7 @@ func keepSendingPing(olm *websocket.Client) { for { select { - case <-stopPing: + case <-o.stopPing: logger.Info("Stopping ping messages") return case <-ticker.C: diff --git a/olm/types.go b/olm/types.go index 9187860..77c0b5f 100644 --- a/olm/types.go +++ b/olm/types.go @@ -12,7 +12,7 @@ type WgData struct { UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses } -type GlobalConfig struct { +type OlmConfig struct { // Logging LogLevel string LogFilePath string From 1ecb97306f3acb0b7c9419bf15dc80dbc8c8323c Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 13 Jan 2026 21:38:37 -0800 Subject: [PATCH 22/27] Add back AddDevice function Former-commit-id: cae0ffa2e151d157f485ae6e52f6069a2f883fc0 --- olm/olm.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/olm/olm.go b/olm/olm.go index 6d8f7a5..2db3630 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -576,3 +576,29 @@ func (o *Olm) SwitchOrg(orgID string) error { return nil } + +func (o *Olm) AddDevice(fd uint32) error { 
+ if o.middleDev == nil { + return fmt.Errorf("middle device is not initialized") + } + + if o.tunnelConfig.MTU == 0 { + return fmt.Errorf("tunnel MTU is not set") + } + + tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU) + if err != nil { + return fmt.Errorf("failed to create TUN device from fd: %v", err) + } + + // Update interface name if available + if realInterfaceName, err2 := tdev.Name(); err2 == nil { + o.tunnelConfig.InterfaceName = realInterfaceName + } + + // Replace the existing TUN device in the middle device with the new one + o.middleDev.AddDevice(tdev) + + logger.Info("Added device from file descriptor %d", fd) + return nil +} From 2ab979058820c3b7da74d6f3c4c9b3edb3622b24 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 11:12:10 -0800 Subject: [PATCH 23/27] Reduce the pings Former-commit-id: 5c6ad1ea75f85558195791e3129e745a70b7fa54 --- olm/ping.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/olm/ping.go b/olm/ping.go index bbeee9a..fd7706a 100644 --- a/olm/ping.go +++ b/olm/ping.go @@ -9,6 +9,7 @@ import ( ) func sendPing(olm *websocket.Client) error { + logger.Debug("Sending ping message") err := olm.SendMessage("olm/ping", map[string]any{ "timestamp": time.Now().Unix(), "userToken": olm.GetConfig().UserToken, @@ -30,7 +31,7 @@ func (o *Olm) keepSendingPing(olm *websocket.Client) { } // Set up ticker for one minute intervals - ticker := time.NewTicker(1 * time.Minute) + ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { From c86df2c0416d593c2d60c26205d023e67e7a073f Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 11:58:12 -0800 Subject: [PATCH 24/27] Refactor operation Former-commit-id: 4f09d122bb53f3a32f73624edb2ca7d1c26b3175 --- olm/connect.go | 2 +- olm/data.go | 4 +- olm/olm.go | 115 ++++++++++++++------------------------------ olm/ping.go | 56 --------------------- websocket/client.go | 58 ++++++++++++++++------ 5 files changed, 82 insertions(+), 153 deletions(-) delete 
mode 100644 olm/ping.go diff --git a/olm/connect.go b/olm/connect.go index 568c731..a610ea4 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -154,7 +154,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { MiddleDev: o.middleDev, LocalIP: interfaceIP, SharedBind: o.sharedBind, - WSClient: o.olmClient, + WSClient: o.websocket, APIServer: o.apiServer, }) diff --git a/olm/data.go b/olm/data.go index 9c8d33f..93e64d0 100644 --- a/olm/data.go +++ b/olm/data.go @@ -189,9 +189,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud // Send handshake acknowledgment back to server with retry - o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ "siteId": handshakeData.SiteId, - }, 1*time.Second) + }, 1*time.Second, 10) logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) } diff --git a/olm/olm.go b/olm/olm.go index 15e3a6a..63b53a7 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -41,7 +41,7 @@ type Olm struct { dnsProxy *dns.DNSProxy apiServer *api.API - olmClient *websocket.Client + websocket *websocket.Client holePunchManager *holepunch.Manager peerManager *peers.PeerManager // Power mode management @@ -57,10 +57,11 @@ type Olm struct { tunnelConfig TunnelConfig stopRegister func() - stopPeerSend func() updateRegister func(newData any) - stopPing chan struct{} + stopServerPing func() + + stopPeerSend func() } // initTunnelInfo creates the shared UDP socket and holepunch manager. 
@@ -270,9 +271,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { tunnelCtx, cancel := context.WithCancel(o.olmCtx) o.tunnelCancel = cancel - // Recreate channels for this tunnel session - o.stopPing = make(chan struct{}) - var ( id = config.ID secret = config.Secret @@ -328,6 +326,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetConnectionStatus(true) + // restart the ping if we need to + if o.stopServerPing == nil { + o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{ + "timestamp": time.Now().Unix(), + "userToken": olmClient.GetConfig().UserToken, + }, 30*time.Second, -1) // -1 means dont time out with the max attempts + } + if o.connected { logger.Debug("Already connected, skipping registration") return nil @@ -347,7 +353,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { "olmAgent": o.olmConfig.Agent, "orgId": config.OrgID, "userToken": userToken, - }, 1*time.Second) + }, 1*time.Second, 10) // Invoke onRegistered callback if configured if o.olmConfig.OnRegistered != nil { @@ -355,8 +361,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } } - go o.keepSendingPing(olmClient) - return nil }) @@ -416,7 +420,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } defer func() { _ = olmClient.Close() }() - o.olmClient = olmClient + o.websocket = olmClient // Wait for context cancellation <-tunnelCtx.Done() @@ -435,9 +439,9 @@ func (o *Olm) Close() { o.holePunchManager = nil } - if o.stopPing != nil { - close(o.stopPing) - o.stopPing = nil + if o.stopServerPing != nil { + o.stopServerPing() + o.stopServerPing = nil } if o.stopRegister != nil { @@ -515,9 +519,9 @@ func (o *Olm) StopTunnel() error { } // Close the websocket connection - if o.olmClient != nil { - _ = o.olmClient.Close() - o.olmClient = nil + if o.websocket != nil { + _ = o.websocket.Close() + o.websocket = nil } o.Close() @@ -602,25 +606,13 @@ func (o *Olm) SetPowerMode(mode string) error { if mode == "low" { // Low Power Mode: Close websocket 
and reduce monitoring frequency - if o.olmClient != nil { + if o.websocket != nil { logger.Info("Closing websocket connection for low power mode") - if err := o.olmClient.Close(); err != nil { + if err := o.websocket.Close(); err != nil { logger.Error("Error closing websocket: %v", err) } } - if o.stopPing != nil { - select { - case <-o.stopPing: - default: - close(o.stopPing) - } - } - - if o.peerManager != nil { - o.peerManager.Stop() - } - if o.originalPeerInterval == 0 && o.peerManager != nil { peerMonitor := o.peerManager.GetPeerMonitor() if peerMonitor != nil { @@ -639,10 +631,6 @@ func (o *Olm) SetPowerMode(mode string) error { } } - if o.peerManager != nil { - o.peerManager.Start() - } - o.currentPowerMode = "low" logger.Info("Switched to low power mode") @@ -669,60 +657,19 @@ func (o *Olm) SetPowerMode(mode string) error { } } - if o.peerManager != nil { - o.peerManager.Start() - } + logger.Info("Reconnecting websocket for normal power mode") - if o.tunnelConfig.ID != "" && o.tunnelConfig.Secret != "" && o.tunnelConfig.Endpoint != "" { - logger.Info("Reconnecting websocket for normal power mode") - - if o.olmClient != nil { - o.olmClient.Close() - } - - o.stopPing = make(chan struct{}) - - var ( - id = o.tunnelConfig.ID - secret = o.tunnelConfig.Secret - userToken = o.tunnelConfig.UserToken - ) - - olm, err := websocket.NewClient( - id, - secret, - userToken, - o.tunnelConfig.OrgID, - o.tunnelConfig.Endpoint, - o.tunnelConfig.PingIntervalDuration, - o.tunnelConfig.PingTimeoutDuration, - ) - if err != nil { - logger.Error("Failed to create new websocket client: %v", err) - return fmt.Errorf("failed to create new websocket client: %w", err) - } - - o.olmClient = olm - - olm.OnConnect(func() error { - logger.Info("Websocket Reconnected") - o.apiServer.SetConnectionStatus(true) - go o.keepSendingPing(olm) - return nil - }) - - if err := olm.Connect(); err != nil { + if o.websocket != nil { + if err := o.websocket.Connect(); err != nil { logger.Error("Failed to 
reconnect websocket: %v", err) return fmt.Errorf("failed to reconnect websocket: %w", err) } - } else { - logger.Warn("Cannot reconnect websocket: tunnel config not available") } o.currentPowerMode = "normal" logger.Info("Switched to normal power mode") } - + return nil } @@ -749,6 +696,14 @@ func (o *Olm) AddDevice(fd uint32) error { o.middleDev.AddDevice(tdev) logger.Info("Added device from file descriptor %d", fd) - + return nil } + +func GetNetworkSettingsJSON() (string, error) { + return network.GetJSON() +} + +func GetNetworkSettingsIncrementor() int { + return network.GetIncrementor() +} diff --git a/olm/ping.go b/olm/ping.go deleted file mode 100644 index fd7706a..0000000 --- a/olm/ping.go +++ /dev/null @@ -1,56 +0,0 @@ -package olm - -import ( - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/network" - "github.com/fosrl/olm/websocket" -) - -func sendPing(olm *websocket.Client) error { - logger.Debug("Sending ping message") - err := olm.SendMessage("olm/ping", map[string]any{ - "timestamp": time.Now().Unix(), - "userToken": olm.GetConfig().UserToken, - }) - if err != nil { - logger.Error("Failed to send ping message: %v", err) - return err - } - logger.Debug("Sent ping message") - return nil -} - -func (o *Olm) keepSendingPing(olm *websocket.Client) { - // Send ping immediately on startup - if err := sendPing(olm); err != nil { - logger.Error("Failed to send initial ping: %v", err) - } else { - logger.Info("Sent initial ping message") - } - - // Set up ticker for one minute intervals - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case <-o.stopPing: - logger.Info("Stopping ping messages") - return - case <-ticker.C: - if err := sendPing(olm); err != nil { - logger.Error("Failed to send periodic ping: %v", err) - } - } - } -} - -func GetNetworkSettingsJSON() (string, error) { - return network.GetJSON() -} - -func GetNetworkSettingsIncrementor() int { - return network.GetIncrementor() -} diff --git 
a/websocket/client.go b/websocket/client.go index 1c5afaf..34eea35 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -77,6 +77,7 @@ type Client struct { handlersMux sync.RWMutex reconnectInterval time.Duration isConnected bool + isDisconnected bool // Flag to track if client is intentionally disconnected reconnectMux sync.RWMutex pingInterval time.Duration pingTimeout time.Duration @@ -173,6 +174,9 @@ func (c *Client) GetConfig() *Config { // Connect establishes the WebSocket connection func (c *Client) Connect() error { + if c.isDisconnected { + c.isDisconnected = false + } go c.connectWithRetry() return nil } @@ -205,9 +209,25 @@ func (c *Client) Close() error { return nil } +// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later. +func (c *Client) Disconnect() error { + c.isDisconnected = true + c.setConnected(false) + + if c.conn != nil { + c.writeMux.Lock() + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + c.writeMux.Unlock() + err := c.conn.Close() + c.conn = nil + return err + } + return nil +} + // SendMessage sends a message through the WebSocket connection func (c *Client) SendMessage(messageType string, data interface{}) error { - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return fmt.Errorf("not connected") } @@ -223,7 +243,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { return c.conn.WriteJSON(msg) } -func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) { +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) { stopChan := make(chan struct{}) updateChan := make(chan interface{}) var dataMux sync.Mutex @@ -231,30 +251,32 @@ func (c *Client) SendMessageInterval(messageType 
string, data interface{}, inter go func() { count := 0 - maxAttempts := 10 - err := c.SendMessage(messageType, currentData) // Send immediately - if err != nil { - logger.Error("Failed to send initial message: %v", err) + send := func() { + if c.isDisconnected || c.conn == nil { + return + } + err := c.SendMessage(messageType, currentData) + if err != nil { + logger.Error("Failed to send message: %v", err) + } + count++ } - count++ + + send() // Send immediately ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: - if count >= maxAttempts { + if maxAttempts != -1 && count >= maxAttempts { logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } dataMux.Lock() - err = c.SendMessage(messageType, currentData) + send() dataMux.Unlock() - if err != nil { - logger.Error("Failed to send message: %v", err) - } - count++ case newData := <-updateChan: dataMux.Lock() // Merge newData into currentData if both are maps @@ -277,6 +299,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter case <-stopChan: return } + // Suspend sending if disconnected + for c.isDisconnected { + select { + case <-stopChan: + return + case <-time.After(500 * time.Millisecond): + } + } } }() return func() { @@ -587,7 +617,7 @@ func (c *Client) pingMonitor() { case <-c.done: return case <-ticker.C: - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return } c.writeMux.Lock() From 3470da76fccb0f8748a8a791be8d03001bc557e0 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 12:32:29 -0800 Subject: [PATCH 25/27] Update resetting intervals Former-commit-id: 303c2dc0b78336f6d9aafad87ff3854ed8461ab7 --- olm/olm.go | 111 ++++++++++++++++++++++++--------------- olm/types.go | 2 + peers/manager.go | 33 ++++++++++-- peers/monitor/monitor.go | 63 +++++++++++++++------- peers/peer.go | 22 +++++++- 5 files changed, 165 insertions(+), 66 deletions(-) diff --git 
a/olm/olm.go b/olm/olm.go index 63b53a7..6a0a26f 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -7,6 +7,7 @@ import ( "net/http" _ "net/http/pprof" "os" + "sync" "time" "github.com/fosrl/newt/bind" @@ -46,9 +47,9 @@ type Olm struct { peerManager *peers.PeerManager // Power mode management currentPowerMode string - originalPeerInterval time.Duration - originalHolepunchMinInterval time.Duration - originalHolepunchMaxInterval time.Duration + powerModeMu sync.Mutex + wakeUpTimer *time.Timer + wakeUpDebounce time.Duration olmCtx context.Context tunnelCancel context.CancelFunc @@ -133,6 +134,10 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.SetOutput(file) logFile = file } + + if config.WakeUpDebounce == 0 { + config.WakeUpDebounce = 3 * time.Second + } logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() @@ -589,22 +594,38 @@ func (o *Olm) SwitchOrg(orgID string) error { // SetPowerMode switches between normal and low power modes // In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes // In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored +// Wake-up has a 3-second debounce to prevent rapid flip-flopping; sleep is immediate func (o *Olm) SetPowerMode(mode string) error { // Validate mode if mode != "normal" && mode != "low" { return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode) } + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + // If already in the requested mode, return early if o.currentPowerMode == mode { + // Cancel any pending wake-up timer if we're already in normal mode + if mode == "normal" && o.wakeUpTimer != nil { + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } logger.Debug("Already in %s power mode", mode) return nil } - logger.Info("Switching to %s power mode", mode) - if mode == "low" { - // Low Power Mode: Close websocket and reduce 
monitoring frequency + // Low Power Mode: Cancel any pending wake-up and immediately go to sleep + + // Cancel pending wake-up timer if any + if o.wakeUpTimer != nil { + logger.Debug("Cancelling pending wake-up timer") + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } + + logger.Info("Switching to low power mode") if o.websocket != nil { logger.Info("Closing websocket connection for low power mode") @@ -613,14 +634,6 @@ func (o *Olm) SetPowerMode(mode string) error { } } - if o.originalPeerInterval == 0 && o.peerManager != nil { - peerMonitor := o.peerManager.GetPeerMonitor() - if peerMonitor != nil { - o.originalPeerInterval = 2 * time.Second - o.originalHolepunchMinInterval, o.originalHolepunchMaxInterval = peerMonitor.GetHolepunchIntervals() - } - } - if o.peerManager != nil { peerMonitor := o.peerManager.GetPeerMonitor() if peerMonitor != nil { @@ -629,45 +642,61 @@ func (o *Olm) SetPowerMode(mode string) error { peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval) logger.Info("Set monitoring intervals to 10 minutes for low power mode") } + o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable } o.currentPowerMode = "low" logger.Info("Switched to low power mode") } else { - // Normal Power Mode: Restore intervals and reconnect websocket + // Normal Power Mode: Start debounce timer before actually waking up - if o.peerManager != nil { - peerMonitor := o.peerManager.GetPeerMonitor() - if peerMonitor != nil { - if o.originalPeerInterval == 0 { - o.originalPeerInterval = 2 * time.Second - } - peerMonitor.SetInterval(o.originalPeerInterval) - - if o.originalHolepunchMinInterval == 0 { - o.originalHolepunchMinInterval = 2 * time.Second - } - if o.originalHolepunchMaxInterval == 0 { - o.originalHolepunchMaxInterval = 30 * time.Second - } - peerMonitor.SetHolepunchInterval(o.originalHolepunchMinInterval, o.originalHolepunchMaxInterval) - logger.Info("Restored monitoring intervals to normal (peer: %v, holepunch: %v-%v)", - o.originalPeerInterval, 
o.originalHolepunchMinInterval, o.originalHolepunchMaxInterval) - } + // If there's already a pending wake-up timer, don't start another + if o.wakeUpTimer != nil { + logger.Debug("Wake-up already pending, ignoring duplicate request") + return nil } - logger.Info("Reconnecting websocket for normal power mode") + logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce) - if o.websocket != nil { - if err := o.websocket.Connect(); err != nil { - logger.Error("Failed to reconnect websocket: %v", err) - return fmt.Errorf("failed to reconnect websocket: %w", err) + o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() { + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + + // Clear the timer reference + o.wakeUpTimer = nil + + // Double-check we're still in low power mode (could have changed) + if o.currentPowerMode == "normal" { + logger.Debug("Already in normal mode after debounce, skipping wake-up") + return } - } - o.currentPowerMode = "normal" - logger.Info("Switched to normal power mode") + logger.Info("Debounce complete, switching to normal power mode") + + // Restore intervals and reconnect websocket + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.ResetHolepunchInterval() + peerMonitor.ResetInterval() + } + + o.peerManager.UpdateAllPeersPersistentKeepalive(5) + } + + logger.Info("Reconnecting websocket for normal power mode") + + if o.websocket != nil { + if err := o.websocket.Connect(); err != nil { + logger.Error("Failed to reconnect websocket: %v", err) + return + } + } + + o.currentPowerMode = "normal" + logger.Info("Switched to normal power mode") + }) } return nil diff --git a/olm/types.go b/olm/types.go index 77c0b5f..397eab9 100644 --- a/olm/types.go +++ b/olm/types.go @@ -23,6 +23,8 @@ type OlmConfig struct { SocketPath string Version string Agent string + + WakeUpDebounce time.Duration // Debugging PprofAddr string // Address to serve pprof on (e.g., 
"localhost:6060") diff --git a/peers/manager.go b/peers/manager.go index 56f3707..0566775 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -50,6 +50,8 @@ type PeerManager struct { // key is the CIDR string, value is a set of siteIds that want this IP allowedIPClaims map[string]map[int]bool APIServer *api.API + + PersistentKeepalive int } // NewPeerManager creates a new PeerManager with an internal PeerMonitor @@ -127,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -166,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { return nil } +// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once +// without recreating them. Returns a map of siteId to error for any peers that failed to update. 
+func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error { + pm.mu.RLock() + defer pm.mu.RUnlock() + + pm.PersistentKeepalive = interval + + errors := make(map[int]error) + + for siteId, peer := range pm.peers { + err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval) + if err != nil { + errors[siteId] = err + } + } + + if len(errors) == 0 { + return nil + } + return errors +} + func (pm *PeerManager) RemovePeer(siteId int) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -245,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error { ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId) wgConfig := promotedPeer wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } @@ -321,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -331,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId) promotedWgConfig := promotedPeer promotedWgConfig.AllowedIps = promotedOwnedIPs - if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to 
update promoted peer %d: %v", promotedPeerId, err) } } diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 2bb0c80..3ac4b54 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -28,13 +28,14 @@ import ( // PeerMonitor handles monitoring the connection status to multiple WireGuard peers type PeerMonitor struct { - monitors map[int]*Client - mutex sync.Mutex - running bool - interval time.Duration - timeout time.Duration - maxAttempts int - wsClient *websocket.Client + monitors map[int]*Client + mutex sync.Mutex + running bool + defaultInterval time.Duration + interval time.Duration + timeout time.Duration + maxAttempts int + wsClient *websocket.Client // Netstack fields middleDev *middleDevice.MiddleDevice @@ -50,7 +51,6 @@ type PeerMonitor struct { // Holepunch testing fields sharedBind *bind.SharedBind holepunchTester *holepunch.HolepunchTester - holepunchInterval time.Duration holepunchTimeout time.Duration holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchStatus map[int]bool // siteID -> connected status @@ -62,11 +62,13 @@ type PeerMonitor struct { holepunchFailures map[int]int // siteID -> consecutive failure count // Exponential backoff fields for holepunch monitor - holepunchMinInterval time.Duration // Minimum interval (initial) - holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) - holepunchBackoffMultiplier float64 // Multiplier for each stable check - holepunchStableCount map[int]int // siteID -> consecutive stable status count - holepunchCurrentInterval time.Duration // Current interval with backoff applied + defaultHolepunchMinInterval time.Duration // Minimum interval (initial) + defaultHolepunchMaxInterval time.Duration + holepunchMinInterval time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount 
map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts @@ -85,6 +87,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), + defaultInterval: 2 * time.Second, interval: 2 * time.Second, // Default check interval (faster) timeout: 3 * time.Second, maxAttempts: 3, @@ -95,7 +98,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe nsCtx: ctx, nsCancel: cancel, sharedBind: sharedBind, - holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds holepunchTimeout: 2 * time.Second, // Faster timeout holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), @@ -109,11 +111,13 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe apiServer: apiServer, wgConnectionStatus: make(map[int]bool), // Exponential backoff settings for holepunch monitor - holepunchMinInterval: 2 * time.Second, - holepunchMaxInterval: 30 * time.Second, - holepunchBackoffMultiplier: 1.5, - holepunchStableCount: make(map[int]int), - holepunchCurrentInterval: 2 * time.Second, + defaultHolepunchMinInterval: 2 * time.Second, + defaultHolepunchMaxInterval: 30 * time.Second, + holepunchMinInterval: 2 * time.Second, + holepunchMaxInterval: 30 * time.Second, + holepunchBackoffMultiplier: 1.5, + holepunchStableCount: make(map[int]int), + holepunchCurrentInterval: 2 * time.Second, } if err := pm.initNetstack(); err != nil { @@ -141,6 +145,18 @@ func (pm *PeerMonitor) SetInterval(interval time.Duration) { } } +func (pm *PeerMonitor) ResetInterval() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.interval = pm.defaultInterval + + // Update interval for all existing monitors + for _, client := 
range pm.monitors { + client.SetPacketInterval(pm.defaultInterval) + } +} + // SetTimeout changes the timeout for waiting for responses func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { pm.mutex.Lock() @@ -186,6 +202,15 @@ func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Du return pm.holepunchMinInterval, pm.holepunchMaxInterval } +func (pm *PeerMonitor) ResetHolepunchInterval() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.holepunchMinInterval = pm.defaultHolepunchMinInterval + pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval + pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval +} + // AddPeer adds a new peer to monitor func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error { pm.mutex.Lock() diff --git a/peers/peer.go b/peers/peer.go index 9370b9d..8211fa4 100644 --- a/peers/peer.go +++ b/peers/peer.go @@ -11,7 +11,7 @@ import ( ) // ConfigurePeer sets up or updates a peer within the WireGuard device -func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error { +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error { var endpoint string if relay && siteConfig.RelayEndpoint != "" { endpoint = formatEndpoint(siteConfig.RelayEndpoint) @@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes } configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) - configBuilder.WriteString("persistent_keepalive_interval=5\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive)) config := configBuilder.String() logger.Debug("Configuring peer with config: %s", config) @@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [ return nil } +// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without 
recreating it +func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error { + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("update_only=true\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval)) + + config := configBuilder.String() + logger.Debug("Updating persistent keepalive for peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err) + } + + return nil +} + func formatEndpoint(endpoint string) string { if strings.Contains(endpoint, ":") { return endpoint From 3ba171452488a9a944687617a509009edce84a83 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 16:38:40 -0800 Subject: [PATCH 26/27] Power state getting set correctly Former-commit-id: 0895156efd764c365b9196e55dcb1199b3ec9b1c --- go.mod | 2 + go.sum | 2 - olm/data.go | 2 +- olm/olm.go | 56 ++++++----- olm/peer.go | 2 +- peers/monitor/monitor.go | 198 +++++++++++++++++++------------------- peers/monitor/wgtester.go | 109 +++++++++++++-------- websocket/client.go | 61 +++++++----- 8 files changed, 239 insertions(+), 193 deletions(-) diff --git a/go.mod b/go.mod index 4f42df6..0d6bbcb 100644 --- a/go.mod +++ b/go.mod @@ -30,3 +30,5 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect ) + +replace github.com/fosrl/newt => ../newt diff --git a/go.sum b/go.sum index a543b5a..f6ca61a 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/fosrl/newt v1.8.0 h1:wIRCO2shhCpkFzsbNbb4g2LC7mPzIpp2ialNveBMJy4= -github.com/fosrl/newt v1.8.0/go.mod 
h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= diff --git a/olm/data.go b/olm/data.go index 93e64d0..fe0b36a 100644 --- a/olm/data.go +++ b/olm/data.go @@ -186,7 +186,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { } o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud // Send handshake acknowledgment back to server with retry o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ diff --git a/olm/olm.go b/olm/olm.go index 6a0a26f..3f197ae 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -46,10 +46,10 @@ type Olm struct { holePunchManager *holepunch.Manager peerManager *peers.PeerManager // Power mode management - currentPowerMode string - powerModeMu sync.Mutex - wakeUpTimer *time.Timer - wakeUpDebounce time.Duration + currentPowerMode string + powerModeMu sync.Mutex + wakeUpTimer *time.Timer + wakeUpDebounce time.Duration olmCtx context.Context tunnelCancel context.CancelFunc @@ -134,7 +134,7 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.SetOutput(file) logFile = file } - + if config.WakeUpDebounce == 0 { config.WakeUpDebounce = 3 * time.Second } @@ -628,23 +628,28 @@ func (o *Olm) SetPowerMode(mode string) error { logger.Info("Switching to low power mode") if o.websocket != nil { - logger.Info("Closing websocket connection for low power mode") - if err := o.websocket.Close(); err != nil { - logger.Error("Error closing websocket: %v", err) + 
logger.Info("Disconnecting websocket for low power mode") + if err := o.websocket.Disconnect(); err != nil { + logger.Error("Error disconnecting websocket: %v", err) } } + lowPowerInterval := 10 * time.Minute + if o.peerManager != nil { peerMonitor := o.peerManager.GetPeerMonitor() if peerMonitor != nil { - lowPowerInterval := 10 * time.Minute - peerMonitor.SetInterval(lowPowerInterval) - peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval) + peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval) + peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval) logger.Info("Set monitoring intervals to 10 minutes for low power mode") } o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable } + if o.holePunchManager != nil { + o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval) + } + o.currentPowerMode = "low" logger.Info("Switched to low power mode") @@ -673,20 +678,8 @@ func (o *Olm) SetPowerMode(mode string) error { } logger.Info("Debounce complete, switching to normal power mode") - - // Restore intervals and reconnect websocket - if o.peerManager != nil { - peerMonitor := o.peerManager.GetPeerMonitor() - if peerMonitor != nil { - peerMonitor.ResetHolepunchInterval() - peerMonitor.ResetInterval() - } - - o.peerManager.UpdateAllPeersPersistentKeepalive(5) - } - + logger.Info("Reconnecting websocket for normal power mode") - if o.websocket != nil { if err := o.websocket.Connect(); err != nil { logger.Error("Failed to reconnect websocket: %v", err) @@ -694,6 +687,21 @@ func (o *Olm) SetPowerMode(mode string) error { } } + // Restore intervals and reconnect websocket + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.ResetPeerHolepunchInterval() + peerMonitor.ResetPeerInterval() + } + + o.peerManager.UpdateAllPeersPersistentKeepalive(5) + } + + if o.holePunchManager != nil { + o.holePunchManager.ResetServerHolepunchInterval() + } + 
o.currentPowerMode = "normal" logger.Info("Switched to normal power mode") }) diff --git a/olm/peer.go b/olm/peer.go index 8acec42..9bc842e 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -123,7 +123,7 @@ func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) { if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) _ = o.holePunchManager.TriggerHolePunch() - o.holePunchManager.ResetInterval() + o.holePunchManager.ResetServerHolepunchInterval() } logger.Info("Successfully updated peer for site %d", updateData.SiteId) diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 3ac4b54..387b82f 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -28,14 +28,12 @@ import ( // PeerMonitor handles monitoring the connection status to multiple WireGuard peers type PeerMonitor struct { - monitors map[int]*Client - mutex sync.Mutex - running bool - defaultInterval time.Duration - interval time.Duration + monitors map[int]*Client + mutex sync.Mutex + running bool timeout time.Duration - maxAttempts int - wsClient *websocket.Client + maxAttempts int + wsClient *websocket.Client // Netstack fields middleDev *middleDevice.MiddleDevice @@ -54,7 +52,8 @@ type PeerMonitor struct { holepunchTimeout time.Duration holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchStatus map[int]bool // siteID -> connected status - holepunchStopChan chan struct{} + holepunchStopChan chan struct{} + holepunchUpdateChan chan struct{} // Relay tracking fields relayedPeers map[int]bool // siteID -> whether the peer is currently relayed @@ -87,8 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), - defaultInterval: 2 * time.Second, - interval: 2 * time.Second, // 
Default check interval (faster) timeout: 3 * time.Second, maxAttempts: 3, wsClient: wsClient, @@ -118,6 +115,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe holepunchBackoffMultiplier: 1.5, holepunchStableCount: make(map[int]int), holepunchCurrentInterval: 2 * time.Second, + holepunchUpdateChan: make(chan struct{}, 1), } if err := pm.initNetstack(); err != nil { @@ -133,82 +131,76 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe } // SetInterval changes how frequently peers are checked -func (pm *PeerMonitor) SetInterval(interval time.Duration) { +func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.interval = interval - // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetPacketInterval(interval) + client.SetPacketInterval(minInterval, maxInterval) } + + logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval) } -func (pm *PeerMonitor) ResetInterval() { +func (pm *PeerMonitor) ResetPeerInterval() { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.interval = pm.defaultInterval - // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetPacketInterval(pm.defaultInterval) + client.ResetPacketInterval() } } -// SetTimeout changes the timeout for waiting for responses -func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { +// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) { pm.mutex.Lock() - defer pm.mutex.Unlock() - - pm.timeout = timeout - - // Update timeout for all existing monitors - for _, client := range pm.monitors { - client.SetTimeout(timeout) - } -} - -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (pm *PeerMonitor) SetMaxAttempts(attempts 
int) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - pm.maxAttempts = attempts - - // Update max attempts for all existing monitors - for _, client := range pm.monitors { - client.SetMaxAttempts(attempts) - } -} - -// SetHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring -func (pm *PeerMonitor) SetHolepunchInterval(minInterval, maxInterval time.Duration) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - pm.holepunchMinInterval = minInterval pm.holepunchMaxInterval = maxInterval // Reset current interval to the new minimum pm.holepunchCurrentInterval = minInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval) + + // Signal the goroutine to apply the new interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } } -// GetHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring -func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Duration) { +// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() return pm.holepunchMinInterval, pm.holepunchMaxInterval } -func (pm *PeerMonitor) ResetHolepunchInterval() { +func (pm *PeerMonitor) ResetPeerHolepunchInterval() { pm.mutex.Lock() - defer pm.mutex.Unlock() - pm.holepunchMinInterval = pm.defaultHolepunchMinInterval pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval) + + // Signal the goroutine to apply the new 
interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } } // AddPeer adds a new peer to monitor @@ -226,11 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st return err } - client.SetPacketInterval(pm.interval) - client.SetTimeout(pm.timeout) - client.SetMaxAttempts(pm.maxAttempts) - client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable - pm.monitors[siteID] = client pm.holepunchEndpoints[siteID] = holepunchEndpoint @@ -541,6 +528,15 @@ func (pm *PeerMonitor) runHolepunchMonitor() { select { case <-pm.holepunchStopChan: return + case <-pm.holepunchUpdateChan: + // Interval settings changed, reset to minimum + pm.mutex.Lock() + pm.holepunchCurrentInterval = pm.holepunchMinInterval + currentInterval := pm.holepunchCurrentInterval + pm.mutex.Unlock() + + timer.Reset(currentInterval) + logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) case <-timer.C: anyStatusChanged := pm.checkHolepunchEndpoints() @@ -584,7 +580,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool { anyStatusChanged := false for siteID, endpoint := range endpoints { - // logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) + logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint) result := pm.holepunchTester.TestEndpoint(endpoint, timeout) pm.mutex.Lock() @@ -733,55 +729,55 @@ func (pm *PeerMonitor) Close() { logger.Debug("PeerMonitor: Cleanup complete") } -// TestPeer tests connectivity to a specific peer -func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { - pm.mutex.Lock() - client, exists := pm.monitors[siteID] - pm.mutex.Unlock() +// // TestPeer tests connectivity to a specific peer +// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { +// pm.mutex.Lock() +// client, exists := pm.monitors[siteID] +// 
pm.mutex.Unlock() - if !exists { - return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) - } +// if !exists { +// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) +// } - ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - defer cancel() +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// defer cancel() - connected, rtt := client.TestConnection(ctx) - return connected, rtt, nil -} +// connected, rtt := client.TestPeerConnection(ctx) +// return connected, rtt, nil +// } -// TestAllPeers tests connectivity to all peers -func (pm *PeerMonitor) TestAllPeers() map[int]struct { - Connected bool - RTT time.Duration -} { - pm.mutex.Lock() - peers := make(map[int]*Client, len(pm.monitors)) - for siteID, client := range pm.monitors { - peers[siteID] = client - } - pm.mutex.Unlock() +// // TestAllPeers tests connectivity to all peers +// func (pm *PeerMonitor) TestAllPeers() map[int]struct { +// Connected bool +// RTT time.Duration +// } { +// pm.mutex.Lock() +// peers := make(map[int]*Client, len(pm.monitors)) +// for siteID, client := range pm.monitors { +// peers[siteID] = client +// } +// pm.mutex.Unlock() - results := make(map[int]struct { - Connected bool - RTT time.Duration - }) - for siteID, client := range peers { - ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - connected, rtt := client.TestConnection(ctx) - cancel() +// results := make(map[int]struct { +// Connected bool +// RTT time.Duration +// }) +// for siteID, client := range peers { +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// connected, rtt := client.TestPeerConnection(ctx) +// cancel() - results[siteID] = struct { - Connected bool - RTT time.Duration - }{ - Connected: connected, - RTT: rtt, - } - } +// results[siteID] = struct { +// Connected bool +// RTT 
time.Duration +// }{ +// Connected: connected, +// RTT: rtt, +// } +// } - return results -} +// return results +// } // initNetstack initializes the gvisor netstack func (pm *PeerMonitor) initNetstack() error { diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go index 21f788a..f06759a 100644 --- a/peers/monitor/wgtester.go +++ b/peers/monitor/wgtester.go @@ -32,16 +32,19 @@ type Client struct { monitorLock sync.Mutex connLock sync.Mutex // Protects connection operations shutdownCh chan struct{} + updateCh chan struct{} packetInterval time.Duration timeout time.Duration maxAttempts int dialer Dialer // Exponential backoff fields - minInterval time.Duration // Minimum interval (initial) - maxInterval time.Duration // Maximum interval (cap for backoff) - backoffMultiplier float64 // Multiplier for each stable check - stableCountToBackoff int // Number of stable checks before backing off + defaultMinInterval time.Duration // Default minimum interval (initial) + defaultMaxInterval time.Duration // Default maximum interval (cap for backoff) + minInterval time.Duration // Minimum interval (initial) + maxInterval time.Duration // Maximum interval (cap for backoff) + backoffMultiplier float64 // Multiplier for each stable check + stableCountToBackoff int // Number of stable checks before backing off } // Dialer is a function that creates a connection @@ -56,43 +59,59 @@ type ConnectionStatus struct { // NewClient creates a new connection test client func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ - serverAddr: serverAddr, - shutdownCh: make(chan struct{}), - packetInterval: 2 * time.Second, - minInterval: 2 * time.Second, - maxInterval: 30 * time.Second, - backoffMultiplier: 1.5, - stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off - timeout: 500 * time.Millisecond, // Timeout for individual packets - maxAttempts: 3, // Default max attempts - dialer: dialer, + serverAddr: serverAddr, + 
shutdownCh: make(chan struct{}), + updateCh: make(chan struct{}, 1), + packetInterval: 2 * time.Second, + defaultMinInterval: 2 * time.Second, + defaultMaxInterval: 30 * time.Second, + minInterval: 2 * time.Second, + maxInterval: 30 * time.Second, + backoffMultiplier: 1.5, + stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off + timeout: 500 * time.Millisecond, // Timeout for individual packets + maxAttempts: 3, // Default max attempts + dialer: dialer, }, nil } // SetPacketInterval changes how frequently packets are sent in monitor mode -func (c *Client) SetPacketInterval(interval time.Duration) { - c.packetInterval = interval - c.minInterval = interval +func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) { + c.monitorLock.Lock() + c.packetInterval = minInterval + c.minInterval = minInterval + c.maxInterval = maxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() + + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } -// SetTimeout changes the timeout for waiting for responses -func (c *Client) SetTimeout(timeout time.Duration) { - c.timeout = timeout -} +func (c *Client) ResetPacketInterval() { + c.monitorLock.Lock() + c.packetInterval = c.defaultMinInterval + c.minInterval = c.defaultMinInterval + c.maxInterval = c.defaultMaxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (c *Client) SetMaxAttempts(attempts int) { - c.maxAttempts = attempts -} - -// SetMaxInterval sets the maximum backoff interval -func (c *Client) SetMaxInterval(interval time.Duration) { - c.maxInterval = interval -} - -// SetBackoffMultiplier sets the multiplier for exponential backoff -func (c *Client) 
SetBackoffMultiplier(multiplier float64) { - c.backoffMultiplier = multiplier + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } // UpdateServerAddr updates the server address and resets the connection @@ -146,9 +165,10 @@ func (c *Client) ensureConnection() error { return nil } -// TestConnection checks if the connection to the server is working +// TestPeerConnection checks if the connection to the server is working // Returns true if connected, false otherwise -func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { +func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) { + logger.Debug("wgtester: testing connection to peer %s", c.serverAddr) if err := c.ensureConnection(); err != nil { logger.Warn("Failed to ensure connection: %v", err) return false, 0 @@ -232,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - return c.TestConnection(ctx) + return c.TestPeerConnection(ctx) } // MonitorCallback is the function type for connection status change callbacks @@ -269,9 +289,20 @@ func (c *Client) StartMonitor(callback MonitorCallback) error { select { case <-c.shutdownCh: return + case <-c.updateCh: + // Interval settings changed, reset to minimum + c.monitorLock.Lock() + currentInterval = c.minInterval + c.monitorLock.Unlock() + + // Reset backoff state + stableCount = 0 + + timer.Reset(currentInterval) + logger.Debug("Packet interval updated, reset to %v", currentInterval) case <-timer.C: ctx, cancel := context.WithTimeout(context.Background(), c.timeout) - connected, rtt := c.TestConnection(ctx) + connected, rtt := c.TestPeerConnection(ctx) cancel() statusChanged := 
connected != lastConnected @@ -321,4 +352,4 @@ func (c *Client) StopMonitor() { close(c.shutdownCh) c.monitorRunning = false -} \ No newline at end of file +} diff --git a/websocket/client.go b/websocket/client.go index 34eea35..f040aa4 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -236,7 +236,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { Data: data, } - logger.Debug("Sending message: %s, data: %+v", messageType, data) + logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data) c.writeMux.Lock() defer c.writeMux.Unlock() @@ -258,7 +258,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter } err := c.SendMessage(messageType, currentData) if err != nil { - logger.Error("Failed to send message: %v", err) + logger.Error("websocket: Failed to send message: %v", err) } count++ } @@ -271,7 +271,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter select { case <-ticker.C: if maxAttempts != -1 && count >= maxAttempts { - logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) + logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } dataMux.Lock() @@ -353,7 +353,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { tlsConfig = &tls.Config{} } tlsConfig.InsecureSkipVerify = true - logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } tokenData := map[string]interface{}{ @@ -382,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { req.Header.Set("X-CSRF-Token", "x-csrf-protection") // print out the request for debugging - logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) + logger.Debug("websocket: 
Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) // Make the request client := &http.Client{} @@ -399,7 +399,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) // Return AuthError for 401/403 status codes if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { @@ -415,7 +415,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - logger.Error("Failed to decode token response.") + logger.Error("websocket: Failed to decode token response.") return "", nil, fmt.Errorf("failed to decode token response: %w", err) } @@ -427,7 +427,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { return "", nil, fmt.Errorf("received empty token from server") } - logger.Debug("Received token: %s", tokenResp.Data.Token) + logger.Debug("websocket: Received token: %s", tokenResp.Data.Token) return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil } @@ -442,7 +442,7 @@ func (c *Client) connectWithRetry() { if err != nil { // Check if this is an auth error (401/403) if authErr, ok := err.(*AuthError); ok { - logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr) + logger.Error("websocket: Authentication failed: %v. Terminating tunnel and retrying...", authErr) // Trigger auth error callback if set (this should terminate the tunnel) if c.onAuthError != nil { c.onAuthError(authErr.StatusCode, authErr.Message) @@ -452,7 +452,7 @@ func (c *Client) connectWithRetry() { continue } // For other errors (5xx, network issues), continue retrying - logger.Error("Failed to connect: %v. 
Retrying in %v...", err, c.reconnectInterval) + logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) time.Sleep(c.reconnectInterval) continue } @@ -505,7 +505,7 @@ func (c *Client) establishConnection() error { // Use new TLS configuration method if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { - logger.Info("Setting up TLS configuration for WebSocket connection") + logger.Info("websocket: Setting up TLS configuration for WebSocket connection") tlsConfig, err := c.setupTLS() if err != nil { return fmt.Errorf("failed to setup TLS configuration: %w", err) @@ -519,7 +519,7 @@ func (c *Client) establishConnection() error { dialer.TLSClientConfig = &tls.Config{} } dialer.TLSClientConfig.InsecureSkipVerify = true - logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } conn, _, err := dialer.Dial(u.String(), nil) @@ -537,7 +537,7 @@ func (c *Client) establishConnection() error { if c.onConnect != nil { if err := c.onConnect(); err != nil { - logger.Error("OnConnect callback failed: %v", err) + logger.Error("websocket: OnConnect callback failed: %v", err) } } @@ -550,9 +550,9 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Handle new separate certificate configuration if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" { - logger.Info("Loading separate certificate files for mTLS") - logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile) - logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile) + logger.Info("websocket: Loading separate certificate files for mTLS") + logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile) + logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile) // Load client certificate and key 
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile) @@ -563,7 +563,7 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Load CA certificates for remote validation if specified if len(c.tlsConfig.CAFiles) > 0 { - logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles) + logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles) caCertPool := x509.NewCertPool() for _, caFile := range c.tlsConfig.CAFiles { caCert, err := os.ReadFile(caFile) @@ -589,13 +589,13 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Fallback to existing PKCS12 implementation for backward compatibility if c.tlsConfig.PKCS12File != "" { - logger.Info("Loading PKCS12 certificate for mTLS (deprecated)") + logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)") return c.setupPKCS12TLS() } // Legacy fallback using config.TlsClientCert if c.config.TlsClientCert != "" { - logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)") + logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)") return loadClientCertificate(c.config.TlsClientCert) } @@ -630,7 +630,7 @@ func (c *Client) pingMonitor() { // Expected during shutdown return default: - logger.Error("Ping failed: %v", err) + logger.Error("websocket: Ping failed: %v", err) c.reconnect() return } @@ -663,18 +663,23 @@ func (c *Client) readPumpWithDisconnectDetection() { var msg WSMessage err := c.conn.ReadJSON(&msg) if err != nil { - // Check if we're shutting down before logging error + // Check if we're shutting down or explicitly disconnected before logging error select { case <-c.done: // Expected during shutdown, don't log as error - logger.Debug("WebSocket connection closed during shutdown") + logger.Debug("websocket: connection closed during shutdown") return default: + // Check if explicitly disconnected + if c.isDisconnected { + logger.Debug("websocket: connection closed: client was explicitly disconnected") + 
return + } // Unexpected error during normal operation if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { - logger.Error("WebSocket read error: %v", err) + logger.Error("websocket: read error: %v", err) } else { - logger.Debug("WebSocket connection closed: %v", err) + logger.Debug("websocket: connection closed: %v", err) } return // triggers reconnect via defer } @@ -696,6 +701,12 @@ func (c *Client) reconnect() { c.conn = nil } + // Don't reconnect if explicitly disconnected + if c.isDisconnected { + logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected") + return + } + // Only reconnect if we're not shutting down select { case <-c.done: @@ -713,7 +724,7 @@ func (c *Client) setConnected(status bool) { // LoadClientCertificate Helper method to load client certificates (PKCS12 format) func loadClientCertificate(p12Path string) (*tls.Config, error) { - logger.Info("Loading tls-client-cert %s", p12Path) + logger.Info("websocket: Loading tls-client-cert %s", p12Path) // Read the PKCS12 file p12Data, err := os.ReadFile(p12Path) if err != nil { From 17b75bf58f48345ac10bef2b486006cd2c9aa481 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 16:51:04 -0800 Subject: [PATCH 27/27] Dont get token each time Former-commit-id: 07dfc651f19a767a44d880fd27200bfc91a54cc7 --- websocket/client.go | 45 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/websocket/client.go b/websocket/client.go index f040aa4..b50cf31 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -88,6 +88,10 @@ type Client struct { clientType string // Type of client (e.g., "newt", "olm") tlsConfig TLSConfig configNeedsSave bool // Flag to track if config needs to be saved + token string // Cached authentication token + exitNodes []ExitNode // Cached exit nodes from token response + tokenMux sync.RWMutex // Protects token and 
exitNodes + forceNewToken bool // Flag to force fetching a new token on next connection } type ClientOption func(*Client) @@ -462,15 +466,25 @@ func (c *Client) connectWithRetry() { } func (c *Client) establishConnection() error { - // Get token for authentication - token, exitNodes, err := c.getToken() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } - - if c.onTokenUpdate != nil { - c.onTokenUpdate(token, exitNodes) + // Get token for authentication - reuse cached token unless forced to get new one + c.tokenMux.Lock() + needNewToken := c.token == "" || c.forceNewToken + if needNewToken { + token, exitNodes, err := c.getToken() + if err != nil { + c.tokenMux.Unlock() + return fmt.Errorf("failed to get token: %w", err) + } + c.token = token + c.exitNodes = exitNodes + c.forceNewToken = false + + if c.onTokenUpdate != nil { + c.onTokenUpdate(token, exitNodes) + } } + token := c.token + c.tokenMux.Unlock() // Parse the base URL to determine protocol and hostname baseURL, err := url.Parse(c.baseURL) @@ -522,8 +536,20 @@ func (c *Client) establishConnection() error { logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } - conn, _, err := dialer.Dial(u.String(), nil) + conn, resp, err := dialer.Dial(u.String(), nil) if err != nil { + // Check if this is an unauthorized error (401) + if resp != nil && resp.StatusCode == http.StatusUnauthorized { + logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized") + // Force getting a new token on next reconnect attempt + c.tokenMux.Lock() + c.forceNewToken = true + c.tokenMux.Unlock() + return &AuthError{ + StatusCode: http.StatusUnauthorized, + Message: "WebSocket connection unauthorized", + } + } return fmt.Errorf("failed to connect to WebSocket: %w", err) } @@ -675,6 +701,7 @@ func (c *Client) readPumpWithDisconnectDetection() { logger.Debug("websocket: connection closed: client was explicitly disconnected") return } + 
// Unexpected error during normal operation if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { logger.Error("websocket: read error: %v", err)