Compare commits


1 Commit

Author SHA1 Message Date
dependabot[bot]
e7ef3a9b37 Bump actions/setup-go from 6.1.0 to 6.2.0
Bumps [actions/setup-go](https://github.com/actions/setup-go) from 6.1.0 to 6.2.0.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](4dc6199c7b...7a3fe6cf4c)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: 6.2.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Former-commit-id: 3d3f32d95f
2026-01-21 20:56:09 +00:00
34 changed files with 1361 additions and 3287 deletions

View File

@@ -196,7 +196,7 @@ jobs:
shell: bash
- name: Install Go
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
+ uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0
with:
go-version-file: go.mod

View File

@@ -18,7 +18,7 @@ jobs:
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- name: Set up Go
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
+ uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0
with:
go-version: 1.25

89
API.md
View File

@@ -46,18 +46,7 @@ Initiates a new connection request to a Pangolin server.
"tlsClientCert": "string",
"pingInterval": "3s",
"pingTimeout": "5s",
"orgId": "string",
"fingerprint": {
"username": "string",
"hostname": "string",
"platform": "string",
"osVersion": "string",
"kernelVersion": "string",
"arch": "string",
"deviceModel": "string",
"serialNumber": "string"
},
"postures": {}
"orgId": "string"
}
```
@@ -78,16 +67,6 @@ Initiates a new connection request to a Pangolin server.
- `pingInterval`: Interval for pinging the server (default: 3s)
- `pingTimeout`: Timeout for each ping (default: 5s)
- `orgId`: Organization ID to connect to
- - `fingerprint`: Device fingerprinting information (should be set before connecting)
-   - `username`: Current username on the device
-   - `hostname`: Device hostname
-   - `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
-   - `osVersion`: Operating system version
-   - `kernelVersion`: Kernel version
-   - `arch`: System architecture (e.g., amd64, arm64)
-   - `deviceModel`: Device model identifier
-   - `serialNumber`: Device serial number
- - `postures`: Device posture/security information
**Response:**
- **Status Code:** `202 Accepted`
@@ -226,56 +205,6 @@ Switches to a different organization while maintaining the connection.
---
### PUT /metadata
Updates device fingerprinting and posture information. This endpoint can be called at any time to update metadata, but it's recommended to provide this information in the initial `/connect` request or immediately before connecting.
**Request Body:**
```json
{
"fingerprint": {
"username": "string",
"hostname": "string",
"platform": "string",
"osVersion": "string",
"kernelVersion": "string",
"arch": "string",
"deviceModel": "string",
"serialNumber": "string"
},
"postures": {}
}
```
**Optional Fields:**
- `fingerprint`: Device fingerprinting information
- `username`: Current username on the device
- `hostname`: Device hostname
- `platform`: Operating system platform (macos, windows, linux, ios, android, unknown)
- `osVersion`: Operating system version
- `kernelVersion`: Kernel version
- `arch`: System architecture (e.g., amd64, arm64)
- `deviceModel`: Device model identifier
- `serialNumber`: Device serial number
- `postures`: Device posture/security information (object with arbitrary key-value pairs)
**Response:**
- **Status Code:** `200 OK`
- **Content-Type:** `application/json`
```json
{
"status": "metadata updated"
}
```
**Error Responses:**
- `405 Method Not Allowed` - Non-PUT requests
- `400 Bad Request` - Invalid JSON
**Note:** It's recommended to call this endpoint BEFORE `/connect` to ensure fingerprinting information is available during the initial connection handshake.
---
### POST /exit
Initiates a graceful shutdown of the Olm process.
@@ -318,22 +247,6 @@ Simple health check endpoint to verify the API server is running.
## Usage Examples
### Update metadata before connecting (recommended)
```bash
curl -X PUT http://localhost:9452/metadata \
-H "Content-Type: application/json" \
-d '{
"fingerprint": {
"username": "john",
"hostname": "johns-laptop",
"platform": "macos",
"osVersion": "14.2.1",
"arch": "arm64",
"deviceModel": "MacBookPro18,3"
}
}'
```
### Connect to a peer
```bash
curl -X POST http://localhost:9452/connect \

View File

@@ -5,9 +5,6 @@ all: local
local:
CGO_ENABLED=0 go build -o ./bin/olm
- docker-build:
- 	docker build -t fosrl/olm:latest .
docker-build-release:
@if [ -z "$(tag)" ]; then \
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"sync"
"time"
@@ -33,12 +32,7 @@ type ConnectionRequest struct {
// SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct {
OrgID string `json:"org_id"`
}
-	// PowerModeRequest represents a request to change power mode
-	type PowerModeRequest struct {
-		Mode  string `json:"mode"` // "normal" or "low"
-		OrgID string `json:"orgId"`
-	}
// PeerStatus represents the status of a peer connection
@@ -54,18 +48,11 @@ type PeerStatus struct {
HolepunchConnected bool `json:"holepunchConnected"`
}
-	// OlmError holds error information from registration failures
-	type OlmError struct {
-		Code    string `json:"code"`
-		Message string `json:"message"`
-	}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Connected bool `json:"connected"`
Registered bool `json:"registered"`
Terminated bool `json:"terminated"`
-	OlmError *OlmError `json:"error,omitempty"`
Version string `json:"version,omitempty"`
Agent string `json:"agent,omitempty"`
OrgID string `json:"orgId,omitempty"`
@@ -73,37 +60,25 @@ type StatusResponse struct {
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
}
-	type MetadataChangeRequest struct {
-		Fingerprint map[string]any `json:"fingerprint"`
-		Postures    map[string]any `json:"postures"`
-	}
// API represents the HTTP server and its state
type API struct {
-	addr             string
-	socketPath       string
-	listener         net.Listener
-	server           *http.Server
-	onConnect        func(ConnectionRequest) error
-	onSwitchOrg      func(SwitchOrgRequest) error
-	onMetadataChange func(MetadataChangeRequest) error
-	onDisconnect     func() error
-	onExit           func() error
-	onRebind         func() error
-	onPowerMode      func(PowerModeRequest) error
+	addr        string
+	socketPath  string
+	listener    net.Listener
+	server      *http.Server
+	onConnect   func(ConnectionRequest) error
+	onSwitchOrg func(SwitchOrgRequest) error
+	onDisconnect func() error
+	onExit       func() error
	statusMu sync.RWMutex
	peerStatuses map[int]*PeerStatus
	connectedAt time.Time
	isConnected bool
	isRegistered bool
	isTerminated bool
-	olmError *OlmError
-	version  string
-	agent    string
-	orgID    string
+	version string
+	agent   string
+	orgID   string
}
// NewAPI creates a new HTTP server that listens on a TCP address
@@ -126,49 +101,28 @@ func NewAPISocket(socketPath string) *API {
return s
}
func NewAPIStub() *API {
s := &API{
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// SetHandlers sets the callback functions for handling API requests
func (s *API) SetHandlers(
onConnect func(ConnectionRequest) error,
onSwitchOrg func(SwitchOrgRequest) error,
onMetadataChange func(MetadataChangeRequest) error,
onDisconnect func() error,
onExit func() error,
onRebind func() error,
onPowerMode func(PowerModeRequest) error,
) {
s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg
s.onMetadataChange = onMetadataChange
s.onDisconnect = onDisconnect
s.onExit = onExit
s.onRebind = onRebind
s.onPowerMode = onPowerMode
}
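
For orientation, a minimal sketch of wiring the fuller seven-handler variant of `SetHandlers` shown in this hunk. The `NewAPI` argument is an assumption based on its doc comment below, and the no-op closures are placeholders rather than this repository's actual startup code:

```go
// startAPI is an illustrative helper, not code from this repository.
func startAPI() error {
	api := NewAPI("127.0.0.1:9452") // assumed to take a TCP address, per NewAPI's doc comment
	api.SetHandlers(
		func(ConnectionRequest) error { return nil },     // onConnect
		func(SwitchOrgRequest) error { return nil },      // onSwitchOrg
		func(MetadataChangeRequest) error { return nil }, // onMetadataChange
		func() error { return nil },                      // onDisconnect
		func() error { return nil },                      // onExit
		func() error { return nil },                      // onRebind
		func(PowerModeRequest) error { return nil },      // onPowerMode
	)
	return api.Start()
}
```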
// Start starts the HTTP server
func (s *API) Start() error {
if s.socketPath == "" && s.addr == "" {
return fmt.Errorf("either socketPath or addr must be provided to start the API server")
}
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/metadata", s.handleMetadataChange)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit)
mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode)
s.server = &http.Server{
Handler: mux,
@@ -206,7 +160,7 @@ func (s *API) Stop() error {
// Close the server first, which will also close the listener gracefully
if s.server != nil {
-		_ = s.server.Close()
+		s.server.Close()
}
// Clean up socket file if using Unix socket
@@ -282,27 +236,6 @@ func (s *API) SetRegistered(registered bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isRegistered = registered
-	// Clear any registration error when successfully registered
-	if registered {
-		s.olmError = nil
-	}
}
-	// SetOlmError sets the registration error
-	func (s *API) SetOlmError(code string, message string) {
-		s.statusMu.Lock()
-		defer s.statusMu.Unlock()
-		s.olmError = &OlmError{
-			Code:    code,
-			Message: message,
-		}
-	}
-	// ClearOlmError clears any registration error
-	func (s *API) ClearOlmError() {
-		s.statusMu.Lock()
-		defer s.statusMu.Unlock()
-		s.olmError = nil
-	}
func (s *API) SetTerminated(terminated bool) {
@@ -412,7 +345,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
-	_ = json.NewEncoder(w).Encode(map[string]string{
+	json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
@@ -425,12 +358,12 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
}
s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
Terminated: s.isTerminated,
-		OlmError: s.olmError,
Version: s.version,
Agent: s.agent,
OrgID: s.orgID,
@@ -438,18 +371,8 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
NetworkSettings: network.GetSettings(),
}
s.statusMu.RUnlock()
data, err := json.Marshal(resp)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
json.NewEncoder(w).Encode(resp)
}
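
Client-side, the `/status` payload can be decoded into a trimmed-down struct whose tags mirror `StatusResponse` above. A hedged sketch (the 9452 address comes from the API.md examples):

```go
package main

import (
	"encoding/json"
	"log"
	"net/http"
)

func main() {
	resp, err := http.Get("http://localhost:9452/status")
	if err != nil {
		log.Fatalf("status request failed: %v", err)
	}
	defer resp.Body.Close()

	// Decode only the fields we care about; tags mirror StatusResponse.
	var st struct {
		Connected  bool   `json:"connected"`
		Registered bool   `json:"registered"`
		Terminated bool   `json:"terminated"`
		Version    string `json:"version"`
		OrgID      string `json:"orgId"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&st); err != nil {
		log.Fatalf("decoding status failed: %v", err)
	}
	log.Printf("connected=%v registered=%v org=%s", st.Connected, st.Registered, st.OrgID)
}
```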
// handleHealth handles the /health endpoint
@@ -461,7 +384,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
-	_ = json.NewEncoder(w).Encode(map[string]string{
+	json.NewEncoder(w).Encode(map[string]string{
"status": "ok",
})
}
@@ -478,7 +401,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
// Return a success response first
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
-	_ = json.NewEncoder(w).Encode(map[string]string{
+	json.NewEncoder(w).Encode(map[string]string{
"status": "shutdown initiated",
})
@@ -527,7 +450,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
-	_ = json.NewEncoder(w).Encode(map[string]string{
+	json.NewEncoder(w).Encode(map[string]string{
"status": "org switch request accepted",
})
}
@@ -561,43 +484,16 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
-	_ = json.NewEncoder(w).Encode(map[string]string{
+	json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}
// handleMetadataChange handles the /metadata endpoint
func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req MetadataChangeRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
logger.Info("Received metadata change request via API: %v", req)
_ = s.onMetadataChange(req)
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "metadata updated",
})
}
func (s *API) GetStatus() StatusResponse {
return StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
Terminated: s.isTerminated,
-		OlmError: s.olmError,
Version: s.version,
Agent: s.agent,
OrgID: s.orgID,
@@ -605,74 +501,3 @@ func (s *API) GetStatus() StatusResponse {
NetworkSettings: network.GetSettings(),
}
}
// handleRebind handles the /rebind endpoint
// This triggers a socket rebind, which is necessary when network connectivity changes
// (e.g., WiFi to cellular transition on macOS/iOS) and the old socket becomes stale.
func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received rebind request via API")
// Call the rebind handler if set
if s.onRebind != nil {
if err := s.onRebind(); err != nil {
http.Error(w, fmt.Sprintf("Rebind failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "Rebind handler not configured", http.StatusNotImplemented)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "socket rebound successfully",
})
}
// handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req PowerModeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate power mode
if req.Mode != "normal" && req.Mode != "low" {
http.Error(w, "Invalid power mode: must be 'normal' or 'low'", http.StatusBadRequest)
return
}
logger.Info("Received power mode change request via API: mode=%s", req.Mode)
// Call the power mode handler if set
if s.onPowerMode != nil {
if err := s.onPowerMode(req); err != nil {
http.Error(w, fmt.Sprintf("Power mode change failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "Power mode handler not configured", http.StatusNotImplemented)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": fmt.Sprintf("power mode changed to %s successfully", req.Mode),
})
}
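
A hedged client sketch exercising the `/health`, `/rebind`, and `/power-mode` handlers above (base address taken from the API.md examples; `/rebind` and `/power-mode` require POST per their method checks, and the power-mode body mirrors `PowerModeRequest`'s JSON tags):

```go
package main

import (
	"log"
	"net/http"
	"strings"
)

func main() {
	base := "http://localhost:9452" // address taken from the API.md examples

	// Health check: the handler replies {"status": "ok"}.
	if resp, err := http.Get(base + "/health"); err == nil {
		log.Println("health:", resp.Status)
		resp.Body.Close()
	}

	// Rebind after a network transition (e.g. WiFi to cellular): POST, empty body.
	if resp, err := http.Post(base+"/rebind", "application/json", nil); err == nil {
		log.Println("rebind:", resp.Status)
		resp.Body.Close()
	}

	// Switch to low power mode: the body matches PowerModeRequest's JSON tags.
	body := strings.NewReader(`{"mode": "low"}`)
	if resp, err := http.Post(base+"/power-mode", "application/json", body); err == nil {
		log.Println("power-mode:", resp.Status)
		resp.Body.Close()
	}
}
```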

View File

@@ -89,7 +89,6 @@ func DefaultConfig() *OlmConfig {
PingInterval: "3s",
PingTimeout: "5s",
DisableHolepunch: false,
OverrideDNS: false,
TunnelDNS: false,
// DoNotCreateNewClient: false,
sources: make(map[string]string),
@@ -325,9 +324,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
- serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "When enabled, the client uses custom DNS servers to resolve internal resources and aliases. This overrides your system's default DNS settings. Queries that cannot be resolved as a Pangolin resource will be forwarded to your configured Upstream DNS Server. (default false)")
+ serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
- serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "When enabled, DNS queries are routed through the tunnel for remote resolution. To ensure queries are tunneled correctly, you must define the DNS server as a Pangolin resource and enter its address as an Upstream DNS Server. (default false)")
+ serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "Use tunnel for DNS traffic")
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
version := serviceFlags.Bool("version", false, "Print the version")

View File

@@ -1,12 +1,9 @@
package device
import (
"io"
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun"
@@ -21,68 +18,14 @@ type FilterRule struct {
Handler PacketHandler
}
// closeAwareDevice wraps a tun.Device along with a flag
// indicating whether its Close method was called.
type closeAwareDevice struct {
isClosed atomic.Bool
// MiddleDevice wraps a TUN device with packet filtering capabilities
type MiddleDevice struct {
tun.Device
closeEventCh chan struct{}
wg sync.WaitGroup
closeOnce sync.Once
}
func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice {
return &closeAwareDevice{
Device: tunDevice,
isClosed: atomic.Bool{},
closeEventCh: make(chan struct{}),
}
}
// redirectEvents redirects the Events() method of the underlying tun.Device
// to the given channel.
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case ev, ok := <-c.Device.Events():
if !ok {
return
}
if ev == tun.EventDown {
continue
}
select {
case out <- ev:
case <-c.closeEventCh:
return
}
case <-c.closeEventCh:
return
}
}
}()
}
// Close calls the underlying Device's Close method
// after setting isClosed to true.
func (c *closeAwareDevice) Close() (err error) {
c.closeOnce.Do(func() {
c.isClosed.Store(true)
close(c.closeEventCh)
err = c.Device.Close()
c.wg.Wait()
})
return err
}
func (c *closeAwareDevice) IsClosed() bool {
return c.isClosed.Load()
rules []FilterRule
mutex sync.RWMutex
readCh chan readResult
injectCh chan []byte
closed chan struct{}
}
type readResult struct {
@@ -93,136 +36,58 @@ type readResult struct {
err error
}
// MiddleDevice wraps a TUN device with packet filtering capabilities
// and supports swapping the underlying device.
type MiddleDevice struct {
devices []*closeAwareDevice
mu sync.Mutex
cond *sync.Cond
rules []FilterRule
rulesMutex sync.RWMutex
readCh chan readResult
injectCh chan []byte
closed atomic.Bool
events chan tun.Event
}
// NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice {
	d := &MiddleDevice{
-		devices:  make([]*closeAwareDevice, 0),
+		Device:   device,
		rules:    make([]FilterRule, 0),
-		readCh:   make(chan readResult, 16),
+		readCh:   make(chan readResult),
		injectCh: make(chan []byte, 100),
-		events:   make(chan tun.Event, 16),
+		closed:   make(chan struct{}),
	}
-	d.cond = sync.NewCond(&d.mu)
-	if device != nil {
-		d.AddDevice(device)
-	}
+	go d.pump()
return d
}
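
As a usage sketch, assuming `PacketHandler` is `func([]byte) bool` as the `rule.Handler(packet)` calls in `Read` and `Write` below imply; `wireFilter` and the proxy address are hypothetical, not code from this file:

```go
// wireFilter wraps an existing TUN device and registers one filter rule.
func wireFilter(tunDev tun.Device) *MiddleDevice {
	md := NewMiddleDevice(tunDev)

	proxyAddr := netip.MustParseAddr("100.89.140.1") // hypothetical DNS proxy IP
	md.AddRule(proxyAddr, func(pkt []byte) bool {
		// Consume packets destined for the proxy; returning true tells
		// MiddleDevice the packet was handled and must not reach WireGuard.
		return true
	})
	return md
}
```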
// AddDevice adds a new underlying TUN device, closing any previous one
func (d *MiddleDevice) AddDevice(device tun.Device) {
d.mu.Lock()
if d.closed.Load() {
d.mu.Unlock()
_ = device.Close()
return
}
var toClose *closeAwareDevice
if len(d.devices) > 0 {
toClose = d.devices[len(d.devices)-1]
}
cad := newCloseAwareDevice(device)
cad.redirectEvents(d.events)
d.devices = []*closeAwareDevice{cad}
// Start pump for the new device
go d.pump(cad)
d.cond.Broadcast()
d.mu.Unlock()
if toClose != nil {
logger.Debug("MiddleDevice: Closing previous device")
if err := toClose.Close(); err != nil {
logger.Debug("MiddleDevice: Error closing previous device: %v", err)
}
}
}
func (d *MiddleDevice) pump(dev *closeAwareDevice) {
func (d *MiddleDevice) pump() {
const defaultOffset = 16
batchSize := dev.BatchSize()
logger.Debug("MiddleDevice: pump started for device")
// Recover from panic if readCh is closed while we're trying to send
defer func() {
if r := recover(); r != nil {
logger.Debug("MiddleDevice: pump recovered from panic (channel closed)")
}
}()
batchSize := d.Device.BatchSize()
logger.Debug("MiddleDevice: pump started")
for {
// Check if this device is closed
if dev.IsClosed() {
logger.Debug("MiddleDevice: pump exiting, device is closed")
return
}
// Check if MiddleDevice itself is closed
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed")
// Check closed first with priority
select {
case <-d.closed:
logger.Debug("MiddleDevice: pump exiting due to closed channel")
return
default:
}
// Allocate buffers for reading
// We allocate new buffers for each read to avoid race conditions
// since we pass them to the channel
bufs := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for i := range bufs {
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
}
n, err := dev.Read(bufs, sizes, defaultOffset)
n, err := d.Device.Read(bufs, sizes, defaultOffset)
// Check if device was closed during read
if dev.IsClosed() {
logger.Debug("MiddleDevice: pump exiting, device closed during read")
return
}
// Check if MiddleDevice was closed during read
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read")
return
}
// Try to send the result - check closed state first to avoid sending on closed channel
if d.closed.Load() {
logger.Debug("MiddleDevice: pump exiting, device closed before send")
// Check closed again after read returns
select {
case <-d.closed:
logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
return
default:
}
// Now try to send the result
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
default:
// Channel full, check if we should exit
if dev.IsClosed() || d.closed.Load() {
return
}
// Try again with blocking
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
case <-dev.closeEventCh:
return
}
case <-d.closed:
logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
return
}
if err != nil {
@@ -234,28 +99,16 @@ func (d *MiddleDevice) pump(dev *closeAwareDevice) {
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
func (d *MiddleDevice) InjectOutbound(packet []byte) {
if d.closed.Load() {
return
}
// Use defer/recover to handle panic from sending on closed channel
// This can happen during shutdown race conditions
defer func() {
if r := recover(); r != nil {
logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)")
}
}()
select {
case d.injectCh <- packet:
default:
// Channel full, drop packet
logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full")
case <-d.closed:
}
}
// AddRule adds a packet filtering rule
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
d.rulesMutex.Lock()
defer d.rulesMutex.Unlock()
d.mutex.Lock()
defer d.mutex.Unlock()
d.rules = append(d.rules, FilterRule{
DestIP: destIP,
Handler: handler,
@@ -264,8 +117,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
// RemoveRule removes all rules for a given destination IP
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.rulesMutex.Lock()
defer d.rulesMutex.Unlock()
d.mutex.Lock()
defer d.mutex.Unlock()
newRules := make([]FilterRule, 0, len(d.rules))
for _, rule := range d.rules {
if rule.DestIP != destIP {
@@ -277,120 +130,18 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
// Close stops the device
func (d *MiddleDevice) Close() error {
if !d.closed.CompareAndSwap(false, true) {
return nil // already closed
select {
case <-d.closed:
// Already closed
return nil
default:
logger.Debug("MiddleDevice: Closing, signaling closed channel")
close(d.closed)
}
d.mu.Lock()
devices := d.devices
d.devices = nil
d.cond.Broadcast()
d.mu.Unlock()
// Close underlying devices first - this causes the pump goroutines to exit
// when their read operations return errors
var lastErr error
logger.Debug("MiddleDevice: Closing %d devices", len(devices))
for _, device := range devices {
if err := device.Close(); err != nil {
logger.Debug("MiddleDevice: Error closing device: %v", err)
lastErr = err
}
}
// Now close channels to unblock any remaining readers
// The pump should have exited by now, but close channels to be safe
close(d.readCh)
close(d.injectCh)
close(d.events)
return lastErr
}
// Events returns the events channel
func (d *MiddleDevice) Events() <-chan tun.Event {
return d.events
}
// File returns the underlying file descriptor
func (d *MiddleDevice) File() *os.File {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return nil
}
continue
}
file := dev.File()
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return file
}
}
// MTU returns the MTU of the underlying device
func (d *MiddleDevice) MTU() (int, error) {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
mtu, err := dev.MTU()
if err == nil {
return mtu, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return 0, err
}
}
// Name returns the name of the underlying device
func (d *MiddleDevice) Name() (string, error) {
for {
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return "", io.EOF
}
continue
}
name, err := dev.Name()
if err == nil {
return name, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return "", err
}
}
// BatchSize returns the batch size
func (d *MiddleDevice) BatchSize() int {
dev := d.peekLast()
if dev == nil {
return 1
}
return dev.BatchSize()
logger.Debug("MiddleDevice: Closing underlying TUN device")
err := d.Device.Close()
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
return err
}
// extractDestIP extracts destination IP from packet (fast path)
@@ -425,239 +176,156 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
for {
if d.closed.Load() {
logger.Debug("MiddleDevice: Read returning io.EOF, device closed")
return 0, io.EOF
// Check if already closed first (non-blocking)
select {
case <-d.closed:
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
return 0, os.ErrClosed
default:
}
// Now block waiting for data
select {
case res := <-d.readCh:
if res.err != nil {
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
return 0, res.err
}
// Wait for a device to be available
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
// Copy packets from result to provided buffers
count := 0
for i := 0; i < res.n && i < len(bufs); i++ {
// Handle offset mismatch if necessary
// We assume the pump used defaultOffset (16)
// If caller asks for different offset, we need to shift
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
// Calculate where the packet data starts and ends in src
pktData := src[srcOffset : srcOffset+srcSize]
// Ensure dest buffer is large enough
if len(bufs[i]) < offset+len(pktData) {
continue // Skip if buffer too small
}
copy(bufs[i][offset:], pktData)
sizes[i] = len(pktData)
count++
}
n = count
case pkt := <-d.injectCh:
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil // Buffer too small
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
case <-d.closed:
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
return 0, os.ErrClosed // Signal that device is closed
}
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
if len(rules) == 0 {
return n, nil
}
// Process packets and filter out handled ones
writeIdx := 0
for readIdx := 0; readIdx < n; readIdx++ {
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
continue
}
// Now block waiting for data from readCh or injectCh
select {
case res, ok := <-d.readCh:
if !ok {
// Channel closed, device is shutting down
return 0, io.EOF
}
if res.err != nil {
// Check if device was swapped
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
}
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
return 0, res.err
}
// Copy packets from result to provided buffers
count := 0
for i := 0; i < res.n && i < len(bufs); i++ {
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
pktData := src[srcOffset : srcOffset+srcSize]
if len(bufs[i]) < offset+len(pktData) {
continue
}
copy(bufs[i][offset:], pktData)
sizes[i] = len(pktData)
count++
}
n = count
case pkt, ok := <-d.injectCh:
if !ok {
// Channel closed, device is shutting down
return 0, io.EOF
}
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
}
// Apply filtering rules
d.rulesMutex.RLock()
rules := d.rules
d.rulesMutex.RUnlock()
if len(rules) == 0 {
return n, nil
}
// Process packets and filter out handled ones
writeIdx := 0
for readIdx := 0; readIdx < n; readIdx++ {
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
destIP, ok := extractDestIP(packet)
if !ok {
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
continue
}
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
handled = true
break
}
}
}
if !handled {
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
}
}
return writeIdx, nil
if !handled {
// Keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
}
}
return writeIdx, err
}
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
for {
if d.closed.Load() {
return 0, io.EOF
}
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
if len(rules) == 0 {
return d.Device.Write(bufs, offset)
}
// Filter packets going down
filteredBufs := make([][]byte, 0, len(bufs))
for _, buf := range bufs {
if len(buf) <= offset {
continue
}
d.rulesMutex.RLock()
rules := d.rules
d.rulesMutex.RUnlock()
packet := buf[offset:]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
filteredBufs = append(filteredBufs, buf)
continue
}
var filteredBufs [][]byte
if len(rules) == 0 {
filteredBufs = bufs
} else {
filteredBufs = make([][]byte, 0, len(bufs))
for _, buf := range bufs {
if len(buf) <= offset {
continue
}
packet := buf[offset:]
destIP, ok := extractDestIP(packet)
if !ok {
filteredBufs = append(filteredBufs, buf)
continue
}
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
handled = true
break
}
}
}
if !handled {
filteredBufs = append(filteredBufs, buf)
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
}
}
}
if len(filteredBufs) == 0 {
return len(bufs), nil
if !handled {
filteredBufs = append(filteredBufs, buf)
}
n, err := dev.Write(filteredBufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
if len(filteredBufs) == 0 {
return len(bufs), nil // All packets were handled
}
return d.Device.Write(filteredBufs, offset)
}
func (d *MiddleDevice) waitForDevice() bool {
d.mu.Lock()
defer d.mu.Unlock()
for len(d.devices) == 0 && !d.closed.Load() {
d.cond.Wait()
}
return !d.closed.Load()
}
func (d *MiddleDevice) peekLast() *closeAwareDevice {
d.mu.Lock()
defer d.mu.Unlock()
if len(d.devices) == 0 {
return nil
}
return d.devices[len(d.devices)-1]
}
// WriteToTun writes packets directly to the underlying TUN device,
// bypassing WireGuard. This is useful for sending packets that should
// appear to come from the TUN interface (e.g., DNS responses from a proxy).
// Unlike Write(), this does not go through packet filtering rules.
func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) {
for {
if d.closed.Load() {
return 0, io.EOF
}
dev := d.peekLast()
if dev == nil {
if !d.waitForDevice() {
return 0, io.EOF
}
continue
}
n, err := dev.Write(bufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}
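
The body of `extractDestIP` is elided by the hunks above. As a point of reference only, a fast-path parser typically branches on the version nibble and reads the fixed-offset destination field of the IPv4/IPv6 header; this is an assumption about the elided code, not a copy of it:

```go
package device

import "net/netip"

// extractDestIPSketch is a hypothetical stand-in for the elided extractDestIP.
func extractDestIPSketch(packet []byte) (netip.Addr, bool) {
	if len(packet) == 0 {
		return netip.Addr{}, false
	}
	switch packet[0] >> 4 {
	case 4: // IPv4: destination address occupies bytes 16-19
		if len(packet) < 20 {
			return netip.Addr{}, false
		}
		return netip.AddrFrom4([4]byte(packet[16:20])), true
	case 6: // IPv6: destination address occupies bytes 24-39
		if len(packet) < 40 {
			return netip.Addr{}, false
		}
		return netip.AddrFrom16([16]byte(packet[24:40])), true
	}
	return netip.Addr{}, false
}
```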

View File

@@ -1,50 +0,0 @@
//go:build linux
package device
import (
"net"
"os"
"runtime"
"github.com/fosrl/newt/logger"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
if runtime.GOOS == "android" { // otherwise we get a permission denied
theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
return theTun, err
}
dupTunFd, err := unix.Dup(int(tunFd))
if err != nil {
logger.Error("Unable to dup tun fd: %v", err)
return nil, err
}
err = unix.SetNonblock(dupTunFd, true)
if err != nil {
unix.Close(dupTunFd)
return nil, err
}
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
device, err := tun.CreateTUNFromFile(file, mtuInt)
if err != nil {
file.Close()
return nil, err
}
return device, nil
}
func UapiOpen(interfaceName string) (*os.File, error) {
return ipc.UAPIOpen(interfaceName)
}
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, fileUAPI)
}
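
For context, a hedged sketch of how the helper above would be called with a descriptor handed over from a mobile VPN layer (`attachTUN`, the fd source, and the 1420 MTU are all assumptions for illustration):

```go
// attachTUN is a hypothetical caller; tunFd would arrive over IPC,
// e.g. from Android's VpnService, and 1420 is an assumed MTU.
func attachTUN(tunFd uint32) (tun.Device, error) {
	dev, err := CreateTUNFromFD(tunFd, 1420)
	if err != nil {
		logger.Error("TUN creation from fd failed: %v", err)
		return nil, err
	}
	return dev, nil
}
```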

View File

@@ -1,4 +1,4 @@
-//go:build darwin
+//go:build !windows
package device
@@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
}
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
-	device, err := tun.CreateTUNFromFile(file, 0)
+	device, err := tun.CreateTUNFromFile(file, mtuInt)
if err != nil {
file.Close()
return nil, err

View File

@@ -12,6 +12,7 @@ import (
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/device"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@@ -33,17 +34,18 @@ type DNSProxy struct {
ep *channel.Endpoint
proxyIP netip.Addr
upstreamDNS []string
-	tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
+	tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
	mtu int
-	middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
+	tunDevice tun.Device // Direct reference to underlying TUN device for responses
+	middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
	recordStore *DNSRecordStore // Local DNS records
	// Tunnel DNS fields - for sending queries over WireGuard
-	tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
-	tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
-	tunnelEp *channel.Endpoint
+	tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
+	tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
+	tunnelEp *channel.Endpoint
	tunnelActivePorts map[uint16]bool
-	tunnelPortsLock sync.Mutex
+	tunnelPortsLock sync.Mutex
ctx context.Context
cancel context.CancelFunc
@@ -51,7 +53,7 @@ type DNSProxy struct {
}
// NewDNSProxy creates a new DNS proxy
-func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
+func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
proxyIP, err := PickIPFromSubnet(utilitySubnet)
if err != nil {
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
@@ -66,6 +68,7 @@ func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet strin
proxy := &DNSProxy{
proxyIP: proxyIP,
mtu: mtu,
+		tunDevice:    tunDevice,
middleDevice: middleDevice,
upstreamDNS: upstreamDns,
tunnelDNS: tunnelDns,
@@ -599,12 +602,12 @@ func (p *DNSProxy) runTunnelPacketSender() {
defer p.wg.Done()
logger.Debug("DNS tunnel packet sender goroutine started")
ticker := time.NewTicker(1 * time.Millisecond)
defer ticker.Stop()
for {
// Use blocking ReadContext instead of polling - much more CPU efficient
// This will block until a packet is available or context is cancelled
pkt := p.tunnelEp.ReadContext(p.ctx)
if pkt == nil {
// Context was cancelled or endpoint closed
select {
case <-p.ctx.Done():
logger.Debug("DNS tunnel packet sender exiting")
// Drain any remaining packets
for {
@@ -615,28 +618,36 @@ func (p *DNSProxy) runTunnelPacketSender() {
pkt.DecRef()
}
return
}
case <-ticker.C:
// Try to read packets
for i := 0; i < 10; i++ {
pkt := p.tunnelEp.Read()
if pkt == nil {
break
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
p.middleDevice.InjectOutbound(buf)
}
pkt.DecRef()
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
p.middleDevice.InjectOutbound(buf)
}
pkt.DecRef()
}
}
@@ -649,12 +660,18 @@ func (p *DNSProxy) runPacketSender() {
const offset = 16
for {
// Use blocking ReadContext instead of polling - much more CPU efficient
// This will block until a packet is available or context is cancelled
pkt := p.ep.ReadContext(p.ctx)
if pkt == nil {
// Context was cancelled or endpoint closed
select {
case <-p.ctx.Done():
return
default:
}
// Read packets from netstack endpoint
pkt := p.ep.Read()
if pkt == nil {
// No packet available, small sleep to avoid busy loop
time.Sleep(1 * time.Millisecond)
continue
}
// Extract packet data as slices
@@ -677,9 +694,9 @@ func (p *DNSProxy) runPacketSender() {
pos += len(slice)
}
-	// Write packet to TUN device via MiddleDevice
+	// Write packet to TUN device
	// offset=16 indicates packet data starts at position 16 in the buffer
-	_, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
+	_, err := p.tunDevice.Write([][]byte{buf}, offset)
if err != nil {
logger.Error("Failed to write DNS response to TUN: %v", err)
}

View File

@@ -48,8 +48,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
domain = domain + "."
}
-	// Normalize domain to lowercase FQDN
-	domain = strings.ToLower(dns.Fqdn(domain))
+	// Normalize domain to lowercase
+	domain = dns.Fqdn(domain)
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
@@ -86,8 +86,8 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
domain = domain + "."
}
-	// Normalize domain to lowercase FQDN
-	domain = strings.ToLower(dns.Fqdn(domain))
+	// Normalize domain to lowercase
+	domain = dns.Fqdn(domain)
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
@@ -148,7 +148,7 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
-	domain = strings.ToLower(dns.Fqdn(domain))
+	domain = dns.Fqdn(domain)
var records []net.IP
switch recordType {
@@ -205,7 +205,7 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
-	domain = strings.ToLower(dns.Fqdn(domain))
+	domain = dns.Fqdn(domain)
switch recordType {
case RecordTypeA:
@@ -322,4 +322,4 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool {
}
return matchWildcardInternal(pattern, domain, pi+1, di+1)
}
}
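
One behavioral note on the normalization hunks above: `dns.Fqdn` from miekg/dns only appends the trailing dot and never folds case, so dropping the `strings.ToLower` wrapper makes record lookups case-sensitive, consistent with the case-insensitivity test disappearing from the test-file hunk below. A small illustration:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/miekg/dns"
)

func main() {
	fmt.Println(dns.Fqdn("MyHost.AutoCo.Internal"))
	// "MyHost.AutoCo.Internal."  — trailing dot appended, case preserved

	fmt.Println(strings.ToLower(dns.Fqdn("MyHost.AutoCo.Internal")))
	// "myhost.autoco.internal."  — the form the lowercasing variant stores
}
```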

View File

@@ -37,7 +37,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "autoco.internal.",
expected: false,
},
// Question mark wildcard tests
{
name: "host-0?.autoco.internal matches host-01.autoco.internal",
@@ -63,7 +63,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-012.autoco.internal.",
expected: false,
},
// Combined wildcard tests
{
name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal",
@@ -83,7 +83,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-01.autoco.internal.",
expected: false,
},
// Multiple asterisks
{
name: "*.*. autoco.internal matches any.thing.autoco.internal",
@@ -97,7 +97,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "single.autoco.internal.",
expected: false,
},
// Asterisk in middle
{
name: "host-*.autoco.internal matches host-anything.autoco.internal",
@@ -111,7 +111,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-.autoco.internal.",
expected: true,
},
// Multiple question marks
{
name: "host-??.autoco.internal matches host-01.autoco.internal",
@@ -125,7 +125,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "host-1.autoco.internal.",
expected: false,
},
// Exact match (no wildcards)
{
name: "exact.autoco.internal matches exact.autoco.internal",
@@ -139,7 +139,7 @@ func TestWildcardMatching(t *testing.T) {
domain: "other.autoco.internal.",
expected: false,
},
// Edge cases
{
name: "* matches anything",
@@ -154,7 +154,7 @@ func TestWildcardMatching(t *testing.T) {
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchWildcard(tt.pattern, tt.domain)
@@ -167,21 +167,21 @@ func TestWildcardMatching(t *testing.T) {
func TestDNSRecordStoreWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Add exact record
exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP)
if err != nil {
t.Fatalf("Failed to add exact record: %v", err)
}
// Test exact match takes precedence
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
@@ -190,7 +190,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
if !ips[0].Equal(exactIP) {
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
}
// Test wildcard match
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
@@ -199,7 +199,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
if !ips[0].Equal(wildcardIP) {
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
}
// Test non-match (base domain)
ips = store.GetRecords("autoco.internal.", RecordTypeA)
if len(ips) != 0 {
@@ -209,14 +209,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
func TestDNSRecordStoreComplexWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Test matching domain
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
@@ -225,13 +225,13 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
if len(ips) > 0 && !ips[0].Equal(ip1) {
t.Errorf("Expected IP %v, got %v", ip1, ips[0])
}
// Test non-matching domain (missing prefix)
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
}
// Test non-matching domain (wrong ? position)
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
@@ -241,23 +241,23 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Verify it exists
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
}
// Remove wildcard record
store.RemoveRecord("*.autoco.internal", nil)
// Verify it's gone
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
@@ -267,40 +267,40 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
store := NewDNSRecordStore()
// Add multiple wildcard patterns that don't overlap
ip1 := net.ParseIP("10.0.0.1")
ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1)
if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err)
}
err = store.AddRecord("*.dev.autoco.internal", ip2)
if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err)
}
// Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3)
if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err)
}
// Test domain matching only the prod pattern and the broad pattern
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
}
// Test domain matching only the dev pattern and the broad pattern
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
}
// Test domain matching only the broad pattern
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
@@ -310,14 +310,14 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
}
// Test wildcard match for IPv6
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
if len(ips) != 1 {
@@ -330,86 +330,21 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
func TestHasRecordWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Test HasRecord with wildcard match
if !store.HasRecord("host.autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return true for wildcard match")
}
// Test HasRecord with non-match
if store.HasRecord("autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return false for base domain")
}
}
func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
store := NewDNSRecordStore()
// Add record with mixed case
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip)
if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err)
}
// Test lookup with different cases
testCases := []string{
"myhost.autoco.internal.",
"MYHOST.AUTOCO.INTERNAL.",
"MyHost.AutoCo.Internal.",
"mYhOsT.aUtOcO.iNtErNaL.",
}
for _, domain := range testCases {
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
}
if len(ips) > 0 && !ips[0].Equal(ip) {
t.Errorf("Expected IP %v for domain %q, got %v", ip, domain, ips[0])
}
}
// Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP)
if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err)
}
wildcardTestCases := []string{
"host.example.com.",
"HOST.EXAMPLE.COM.",
"Host.Example.Com.",
"HoSt.ExAmPlE.CoM.",
}
for _, domain := range wildcardTestCases {
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
}
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
t.Errorf("Expected IP %v for wildcard domain %q, got %v", wildcardIP, domain, ips[0])
}
}
// Test removal with different case
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
}
// Test HasRecord with different case
if !store.HasRecord("HOST.EXAMPLE.COM.", RecordTypeA) {
t.Error("Expected HasRecord to return true for mixed case wildcard match")
}
}
}

View File

@@ -1,16 +0,0 @@
//go:build android
package olm
import "net/netip"
// SetupDNSOverride is a no-op on Android
// Android handles DNS through the VpnService API at the Java/Kotlin layer
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
return nil
}
// RestoreDNSOverride is a no-op on Android
func RestoreDNSOverride() error {
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
@@ -14,7 +15,11 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
// Uses scutil for DNS configuration
-func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
+func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
+	if dnsProxy == nil {
+		return fmt.Errorf("DNS proxy is nil")
+	}
var err error
configurator, err = platform.NewDarwinDNSConfigurator()
if err != nil {
@@ -33,7 +38,7 @@ func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
-		proxyIp,
+		dnsProxy.GetProxyIP(),
}
logger.Info("Setting DNS servers to: %v", newDNS)

View File

@@ -1,15 +0,0 @@
//go:build ios
package olm
import "net/netip"
// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
return nil
}
// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system
func RestoreDNSOverride() error {
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
@@ -14,7 +15,11 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
-func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
+func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
+	if dnsProxy == nil {
+		return fmt.Errorf("DNS proxy is nil")
+	}
var err error
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
@@ -27,7 +32,7 @@ func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using systemd-resolved DNS configurator")
-		return setDNS(proxyIp, configurator)
+		return setDNS(dnsProxy, configurator)
}
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
@@ -35,7 +40,7 @@ func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using NetworkManager DNS configurator")
-		return setDNS(proxyIp, configurator)
+		return setDNS(dnsProxy, configurator)
}
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
@@ -43,7 +48,7 @@ func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using resolvconf DNS configurator")
-		return setDNS(proxyIp, configurator)
+		return setDNS(dnsProxy, configurator)
}
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
}
@@ -55,11 +60,11 @@ func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
}
logger.Info("Using file-based DNS configurator")
-	return setDNS(proxyIp, configurator)
+	return setDNS(dnsProxy, configurator)
}
// setDNS is a helper function to set DNS and log the results
-func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error {
+func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
// Get current DNS servers before changing
currentDNS, err := conf.GetCurrentDNS()
if err != nil {
@@ -70,7 +75,7 @@ func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error {
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
-		proxyIp,
+		dnsProxy.GetProxyIP(),
}
logger.Info("Setting DNS servers to: %v", newDNS)

View File

@@ -7,6 +7,7 @@ import (
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
@@ -14,7 +15,11 @@ var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
// Uses registry-based configuration (automatically extracts interface GUID)
-func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
+func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
+	if dnsProxy == nil {
+		return fmt.Errorf("DNS proxy is nil")
+	}
var err error
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
if err != nil {
@@ -33,7 +38,7 @@ func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
-		proxyIp,
+		dnsProxy.GetProxyIP(),
}
logger.Info("Setting DNS servers to: %v", newDNS)

View File

@@ -416,4 +416,4 @@ func (d *DarwinDNSConfigurator) clearState() error {
logger.Debug("Cleared DNS state file")
return nil
}
}

5
go.mod
View File

@@ -4,7 +4,7 @@ go 1.25
require (
github.com/Microsoft/go-winio v0.6.2
-	github.com/fosrl/newt v1.9.0
+	github.com/fosrl/newt v1.8.1
github.com/godbus/dbus/v5 v5.2.2
github.com/gorilla/websocket v1.5.3
github.com/miekg/dns v1.1.70
@@ -30,6 +30,3 @@ require (
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
)
-// To be used ONLY for local development
-// replace github.com/fosrl/newt => ../newt

4
go.sum
View File

@@ -1,7 +1,7 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
-github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
-github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
+github.com/fosrl/newt v1.8.1 h1:oP3xBEISoO/TENsHccqqs6LXpoOWCt6aiP75CfIWpvk=
+github.com/fosrl/newt v1.8.1/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=

13
main.go
View File

@@ -10,7 +10,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
olmpkg "github.com/fosrl/olm/olm"
"github.com/fosrl/olm/olm"
)
func main() {
@@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
}
// Create a new olm.Config struct and copy values from the main config
-	olmConfig := olmpkg.OlmConfig{
+	olmConfig := olm.GlobalConfig{
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
@@ -219,20 +219,15 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
Agent: "Olm CLI",
OnExit: cancel, // Pass cancel function directly to trigger shutdown
OnTerminated: cancel,
-		PprofAddr:    ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
	}
-	olm, err := olmpkg.Init(ctx, olmConfig)
-	if err != nil {
-		logger.Fatal("Failed to initialize olm: %v", err)
-	}
+	olm.Init(ctx, olmConfig)
if err := olm.StartApi(); err != nil {
logger.Fatal("Failed to start API server: %v", err)
}
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
-		tunnelConfig := olmpkg.TunnelConfig{
+		tunnelConfig := olm.TunnelConfig{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,

View File

@@ -1,299 +0,0 @@
package olm
import (
"encoding/json"
"fmt"
"os"
"runtime"
"strconv"
"strings"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
olmDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
dnsOverride "github.com/fosrl/olm/dns/override"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
)
// OlmErrorData represents the error data sent from the server
type OlmErrorData struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring connect message")
return
}
var wgData WgData
if o.registered {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if o.stopRegister != nil {
o.stopRegister()
o.stopRegister = nil
}
if o.updateRegister != nil {
o.updateRegister = nil
}
// if there is an existing tunnel then close it
if o.dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
o.dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
o.tdev, err = func() (tun.Device, error) {
if o.tunnelConfig.FileDescriptorTun != 0 {
return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU)
}
ifName := o.tunnelConfig.InterfaceName
if runtime.GOOS == "darwin" { // this is if we don't pass an fd
ifName, err = network.FindUnusedUTUN()
if err != nil {
return nil, err
}
}
return tun.CreateTUN(ifName, o.tunnelConfig.MTU)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
// if config.FileDescriptorTun == 0 {
if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
o.tunnelConfig.InterfaceName = realInterfaceName
}
// }
// Wrap TUN device with packet filter for DNS proxy
o.middleDev = olmDevice.NewMiddleDevice(o.tdev)
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
// Use filtered device instead of raw TUN device
o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger))
if o.tunnelConfig.EnableUAPI {
fileUAPI, err := func() (*os.File, error) {
if o.tunnelConfig.FileDescriptorUAPI != 0 {
fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
}
return os.NewFile(uintptr(fd), ""), nil
}
return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := o.uapiListener.Accept()
if err != nil {
return
}
go o.dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
}
if err = o.dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
// Extract interface IP (strip CIDR notation if present)
interfaceIP := wgData.TunnelIP
if strings.Contains(interfaceIP, "/") {
interfaceIP = strings.Split(interfaceIP, "/")[0]
}
// Create and start DNS proxy
o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
}
if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil {
logger.Error("Failed to o.tunnelConfigure interface: %v", err)
}
if err = network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet
logger.Error("Failed to add route for utility subnet: %v", err)
}
// Create peer manager with integrated peer monitoring
o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
Device: o.dev,
DNSProxy: o.dnsProxy,
InterfaceName: o.tunnelConfig.InterfaceName,
PrivateKey: o.privateKey,
MiddleDev: o.middleDev,
LocalIP: interfaceIP,
SharedBind: o.sharedBind,
WSClient: o.websocket,
APIServer: o.apiServer,
})
for i := range wgData.Sites {
site := wgData.Sites[i]
var siteEndpoint string
// use the relay endpoint if it exists, which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
if err := o.peerManager.AddPeer(site); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
o.peerManager.Start()
if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime
logger.Error("Failed to start DNS proxy: %v", err)
}
if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
logger.Error("Failed to setup DNS override: %v", err)
return
}
network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()})
}
o.apiServer.SetRegistered(true)
o.registered = true
// Start ping monitor now that we are registered and connected
o.websocket.StartPingMonitor()
// Invoke onConnected callback if configured
if o.olmConfig.OnConnected != nil {
go o.olmConfig.OnConnected()
}
logger.Info("WireGuard device created.")
}
func (o *Olm) handleOlmError(msg websocket.WSMessage) {
logger.Debug("Received olm error message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring olm error message")
return
}
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling olm error data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling olm error data: %v", err)
return
}
logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message)
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
// Invoke onOlmError callback if configured
if o.olmConfig.OnOlmError != nil {
go o.olmConfig.OnOlmError(errorData.Code, errorData.Message)
}
}
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
logger.Info("Received terminate message")
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring terminate message")
return
}
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling terminate error data: %v", err)
} else {
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling terminate error data: %v", err)
} else {
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
if errorData.Code == "TERMINATED_INACTIVITY" {
logger.Info("Ignoring...")
return
}
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
}
}
o.apiServer.SetTerminated(true)
o.apiServer.SetConnectionStatus(false)
o.apiServer.SetRegistered(false)
o.apiServer.ClearPeerStatuses()
network.ClearNetworkSettings()
o.Close()
if o.olmConfig.OnTerminated != nil {
go o.olmConfig.OnTerminated()
}
}
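A fragile spot in `handleConnect` above is the manual CIDR strip on `wgData.TunnelIP` via `strings.Split`. A minimal sketch of a stricter alternative using `net/netip`, assuming the field carries either a bare address or address/prefix; the helper name is illustrative, not from the codebase:

```go
package main

import (
	"fmt"
	"net/netip"
	"strings"
)

// hostFromTunnelIP extracts the bare address from a value that may or may
// not carry CIDR notation ("100.64.0.2/24" or "100.64.0.2"), validating it
// instead of trusting strings.Split.
func hostFromTunnelIP(tunnelIP string) (netip.Addr, error) {
	if strings.Contains(tunnelIP, "/") {
		prefix, err := netip.ParsePrefix(tunnelIP)
		if err != nil {
			return netip.Addr{}, fmt.Errorf("invalid tunnel IP %q: %w", tunnelIP, err)
		}
		return prefix.Addr(), nil
	}
	return netip.ParseAddr(tunnelIP)
}

func main() {
	addr, err := hostFromTunnelIP("100.64.0.2/24")
	fmt.Println(addr, err) // 100.64.0.2 <nil>
}
```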


@@ -1,365 +0,0 @@
package olm
import (
"encoding/json"
"time"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
)
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring add-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addSubnetsData peers.PeerAdd
if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil {
logger.Error("Error unmarshaling add-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId)
return
}
// Add new subnets
for _, subnet := range addSubnetsData.RemoteSubnets {
if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
}
}
// Add new aliases
for _, alias := range addSubnetsData.Aliases {
if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil {
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
}
}
}
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring remove-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeSubnetsData peers.RemovePeerData
if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil {
logger.Error("Error unmarshaling remove-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId)
return
}
// Remove subnets
for _, subnet := range removeSubnetsData.RemoteSubnets {
if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
}
}
// Remove aliases
for _, alias := range removeSubnetsData.Aliases {
if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil {
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
}
}
}
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring update-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateSubnetsData peers.UpdatePeerData
if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil {
logger.Error("Error unmarshaling update-remote-subnets data: %v", err)
return
}
if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists {
logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId)
return
}
// Add new subnets BEFORE removing old ones to preserve shared subnets
// This ensures that if an old and new subnet are the same on different peers,
// the route won't be temporarily removed
for _, subnet := range updateSubnetsData.NewRemoteSubnets {
if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
}
}
// Remove old subnets after new ones are added
for _, subnet := range updateSubnetsData.OldRemoteSubnets {
if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
}
}
// Add new aliases BEFORE removing old ones to preserve shared IP addresses
// This ensures that if an old and new alias share the same IP, the IP won't be
// temporarily removed from the allowed IPs list
for _, alias := range updateSubnetsData.NewAliases {
if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil {
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
}
}
// Remove old aliases after new ones are added
for _, alias := range updateSubnetsData.OldAliases {
if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil {
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
}
}
logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
}
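The add-before-remove ordering above only prevents a transient route withdrawal because adds and removes are reference-counted per CIDR (compare `allowedIPClaims` in the peer manager diff further down). A toy sketch of that invariant, with `routeTable` standing in for the real bookkeeping:

```go
package main

import "fmt"

// routeTable is a toy stand-in for claim-counted allowed-IP bookkeeping:
// a route is installed on the first claim and withdrawn on the last release.
type routeTable struct {
	claims map[string]int // CIDR -> number of peers claiming it
}

func (rt *routeTable) add(cidr string) {
	if rt.claims[cidr] == 0 {
		fmt.Println("install", cidr)
	}
	rt.claims[cidr]++
}

func (rt *routeTable) remove(cidr string) {
	rt.claims[cidr]--
	if rt.claims[cidr] <= 0 {
		delete(rt.claims, cidr)
		fmt.Println("withdraw", cidr)
	}
}

func main() {
	rt := &routeTable{claims: map[string]int{"10.0.0.0/24": 1}}
	// Update from {10.0.0.0/24} to {10.0.0.0/24, 10.0.2.0/24}:
	// adding first keeps the shared route's claim count above zero throughout.
	rt.add("10.0.0.0/24")
	rt.add("10.0.2.0/24")
	rt.remove("10.0.0.0/24")
	// "withdraw 10.0.0.0/24" is never printed.
}
```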
// Handler for syncing peer configuration - reconciles expected state with actual state
func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Debug("Received sync message: %v", msg.Data)
if !o.registered {
logger.Warn("Not connected, ignoring sync request")
return
}
if o.peerManager == nil {
logger.Warn("Peer manager not initialized, ignoring sync request")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
var syncData SyncData
if err := json.Unmarshal(jsonData, &syncData); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
// Sync exit nodes for hole punching
o.syncExitNodes(syncData.ExitNodes)
// Build a map of expected peers from the incoming data
expectedPeers := make(map[int]peers.SiteConfig)
for _, site := range syncData.Sites {
expectedPeers[site.SiteId] = site
}
// Get all current peers
currentPeers := o.peerManager.GetAllPeers()
currentPeerMap := make(map[int]peers.SiteConfig)
for _, peer := range currentPeers {
currentPeerMap[peer.SiteId] = peer
}
// Find peers to remove (in current but not in expected)
for siteId := range currentPeerMap {
if _, exists := expectedPeers[siteId]; !exists {
logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId)
if err := o.peerManager.RemovePeer(siteId); err != nil {
logger.Error("Sync: Failed to remove peer %d: %v", siteId, err)
} else {
// Remove any exit nodes associated with this peer from hole punching
if o.holePunchManager != nil {
removed := o.holePunchManager.RemoveExitNodesByPeer(siteId)
if removed > 0 {
logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId)
}
}
}
}
}
// Find peers to add (in expected but not in current) and peers to update
for siteId, expectedSite := range expectedPeers {
if _, exists := currentPeerMap[siteId]; !exists {
// New peer - add it using the add flow (with holepunch)
logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch()
// // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
// logger.Error("Sync: Failed to add peer %d: %v", siteId, err)
// } else {
// logger.Info("Sync: Successfully added peer for site %d", siteId)
// }
// add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
}, 1*time.Second, 10)
} else {
// Existing peer - check if update is needed
currentSite := currentPeerMap[siteId]
needsUpdate := false
// Check if any fields have changed
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
needsUpdate = true
}
if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint {
needsUpdate = true
}
if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey {
needsUpdate = true
}
if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP {
needsUpdate = true
}
if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort {
needsUpdate = true
}
// Check remote subnets
if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) {
needsUpdate = true
}
// Check aliases
if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) {
needsUpdate = true
}
if needsUpdate {
logger.Info("Sync: Updating peer for site %d", siteId)
// Merge expected data with current data
siteConfig := currentSite
if expectedSite.Endpoint != "" {
siteConfig.Endpoint = expectedSite.Endpoint
}
if expectedSite.RelayEndpoint != "" {
siteConfig.RelayEndpoint = expectedSite.RelayEndpoint
}
if expectedSite.PublicKey != "" {
siteConfig.PublicKey = expectedSite.PublicKey
}
if expectedSite.ServerIP != "" {
siteConfig.ServerIP = expectedSite.ServerIP
}
if expectedSite.ServerPort != 0 {
siteConfig.ServerPort = expectedSite.ServerPort
}
if expectedSite.RemoteSubnets != nil {
siteConfig.RemoteSubnets = expectedSite.RemoteSubnets
}
if expectedSite.Aliases != nil {
siteConfig.Aliases = expectedSite.Aliases
}
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
logger.Error("Sync: Failed to update peer %d: %v", siteId, err)
} else {
// If the endpoint changed, trigger holepunch to refresh NAT mappings
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId)
o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval()
}
logger.Info("Sync: Successfully updated peer for site %d", siteId)
}
}
}
}
logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers))
}
// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager
func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) {
if o.holePunchManager == nil {
logger.Warn("Hole punch manager not initialized, skipping exit node sync")
return
}
// Build a map of expected exit nodes by endpoint
expectedExitNodeMap := make(map[string]SyncExitNode)
for _, exitNode := range expectedExitNodes {
expectedExitNodeMap[exitNode.Endpoint] = exitNode
}
// Get current exit nodes from hole punch manager
currentExitNodes := o.holePunchManager.GetExitNodes()
currentExitNodeMap := make(map[string]holepunch.ExitNode)
for _, exitNode := range currentExitNodes {
currentExitNodeMap[exitNode.Endpoint] = exitNode
}
// Find exit nodes to remove (in current but not in expected)
for endpoint := range currentExitNodeMap {
if _, exists := expectedExitNodeMap[endpoint]; !exists {
logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint)
o.holePunchManager.RemoveExitNode(endpoint)
}
}
// Find exit nodes to add (in expected but not in current)
for endpoint, expectedExitNode := range expectedExitNodeMap {
if _, exists := currentExitNodeMap[endpoint]; !exists {
logger.Info("Sync: Adding new exit node %s", endpoint)
relayPort := expectedExitNode.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
hpExitNode := holepunch.ExitNode{
Endpoint: expectedExitNode.Endpoint,
RelayPort: relayPort,
PublicKey: expectedExitNode.PublicKey,
SiteIds: expectedExitNode.SiteIds,
}
if o.holePunchManager.AddExitNode(hpExitNode) {
logger.Info("Sync: Successfully added exit node %s", endpoint)
}
o.holePunchManager.TriggerHolePunch()
}
}
logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap))
}
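The peer reconciliation in `handleSync` is the usual two-pass map difference. A self-contained sketch of just that computation, with site payloads elided:

```go
package main

import "fmt"

// reconcile computes which site IDs must be added and removed to move the
// current peer set to the expected one.
func reconcile(current, expected map[int]struct{}) (toAdd, toRemove []int) {
	for id := range expected {
		if _, ok := current[id]; !ok {
			toAdd = append(toAdd, id)
		}
	}
	for id := range current {
		if _, ok := expected[id]; !ok {
			toRemove = append(toRemove, id)
		}
	}
	return toAdd, toRemove
}

func main() {
	current := map[int]struct{}{1: {}, 2: {}}
	expected := map[int]struct{}{2: {}, 3: {}}
	add, remove := reconcile(current, expected)
	fmt.Println(add, remove) // [3] [1]
}
```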

olm/olm.go

File diff suppressed because it is too large.

@@ -1,10 +0,0 @@
//go:build !windows
package olm
import "syscall"
// closeFD closes a file descriptor in a platform-specific way
func closeFD(fd uint32) error {
return syscall.Close(int(fd))
}


@@ -1,10 +0,0 @@
//go:build windows
package olm
import "syscall"
// closeFD closes a file descriptor in a platform-specific way
func closeFD(fd uint32) error {
return syscall.Close(syscall.Handle(fd))
}


@@ -1,282 +0,0 @@
package olm
import (
"encoding/json"
"time"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
)
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring add-peer message")
return
}
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var siteConfig peers.SiteConfig
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
}
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring remove-peer message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData peers.PeerRemove
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil {
logger.Error("Failed to remove peer: %v", err)
return
}
// Remove any exit nodes associated with this peer from hole punching
if o.holePunchManager != nil {
removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId)
if removed > 0 {
logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId)
}
}
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
}
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring update-peer message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData peers.SiteConfig
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Get existing peer from PeerManager
existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId)
if !exists {
logger.Warn("Peer with site ID %d not found", updateData.SiteId)
return
}
// Create updated site config by merging with existing data
siteConfig := existingPeer
if updateData.Endpoint != "" {
siteConfig.Endpoint = updateData.Endpoint
}
if updateData.RelayEndpoint != "" {
siteConfig.RelayEndpoint = updateData.RelayEndpoint
}
if updateData.PublicKey != "" {
siteConfig.PublicKey = updateData.PublicKey
}
if updateData.ServerIP != "" {
siteConfig.ServerIP = updateData.ServerIP
}
if updateData.ServerPort != 0 {
siteConfig.ServerPort = updateData.ServerPort
}
if updateData.RemoteSubnets != nil {
siteConfig.RemoteSubnets = updateData.RemoteSubnets
}
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// If the endpoint changed, trigger holepunch to refresh NAT mappings
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
_ = o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval()
}
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
}
func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if o.peerManager == nil {
logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData peers.RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err)
return
}
// Update HTTP server to mark this peer as using relay
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
}
func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
logger.Debug("Received unrelay-peer message: %v", msg.Data)
// Check if peerManager is still valid (may be nil during shutdown)
if o.peerManager == nil {
logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData peers.UnRelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
}
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
logger.Debug("Received peer-handshake message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring peer-handshake message")
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling handshake data: %v", err)
return
}
var handshakeData struct {
SiteId int `json:"siteId"`
ExitNode struct {
PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
} `json:"exitNode"`
}
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
logger.Error("Error unmarshaling handshake data: %v", err)
return
}
// Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists {
logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
return
}
relayPort := handshakeData.ExitNode.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
siteId := handshakeData.SiteId
exitNode := holepunch.ExitNode{
Endpoint: handshakeData.ExitNode.Endpoint,
RelayPort: relayPort,
PublicKey: handshakeData.ExitNode.PublicKey,
SiteIds: []int{siteId},
}
added := o.holePunchManager.AddExitNode(exitNode)
if added {
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
} else {
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
}
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
}, 1*time.Second, 10)
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}
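Several handlers in this file share a partial-update convention: the zero value (`""` or `0`) means "leave this field unchanged". A minimal illustration, where `siteConfig` is a trimmed, hypothetical stand-in for `peers.SiteConfig`:

```go
package main

import "fmt"

// siteConfig is a trimmed, hypothetical stand-in for peers.SiteConfig.
type siteConfig struct {
	Endpoint   string
	PublicKey  string
	ServerPort uint16
}

// mergeUpdate overlays only the fields the update actually carries; the
// zero value ("" or 0) is read as "unchanged".
func mergeUpdate(existing, update siteConfig) siteConfig {
	merged := existing
	if update.Endpoint != "" {
		merged.Endpoint = update.Endpoint
	}
	if update.PublicKey != "" {
		merged.PublicKey = update.PublicKey
	}
	if update.ServerPort != 0 {
		merged.ServerPort = update.ServerPort
	}
	return merged
}

func main() {
	existing := siteConfig{Endpoint: "a.example:51820", PublicKey: "k1", ServerPort: 51820}
	fmt.Println(mergeUpdate(existing, siteConfig{Endpoint: "b.example:51820"}))
}
```

The cost of this convention is that an update can never clear a field back to its zero value; that would require presence flags or pointer fields.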


@@ -12,22 +12,9 @@ type WgData struct {
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
}
type SyncData struct {
Sites []peers.SiteConfig `json:"sites"`
ExitNodes []SyncExitNode `json:"exitNodes"`
}
type SyncExitNode struct {
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
PublicKey string `json:"publicKey"`
SiteIds []int `json:"siteIds"`
}
type OlmConfig struct {
type GlobalConfig struct {
// Logging
LogLevel string
LogFilePath string
LogLevel string
// HTTP server
EnableAPI bool
@@ -36,17 +23,11 @@ type OlmConfig struct {
Version string
Agent string
WakeUpDebounce time.Duration
// Debugging
PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")
// Callbacks
OnRegistered func()
OnConnected func()
OnTerminated func()
OnAuthError func(statusCode int, message string) // Called when auth fails (401/403)
OnOlmError func(code string, message string) // Called when registration fails
OnExit func() // Called when exit is requested via API
}
@@ -82,8 +63,5 @@ type TunnelConfig struct {
OverrideDNS bool
TunnelDNS bool
InitialFingerprint map[string]any
InitialPostures map[string]any
DisableRelay bool
}


@@ -1,47 +1,55 @@
package olm
import (
"github.com/fosrl/olm/peers"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/olm/websocket"
)
// slicesEqual compares two string slices for equality (order-independent)
func slicesEqual(a, b []string) bool {
if len(a) != len(b) {
return false
func sendPing(olm *websocket.Client) error {
err := olm.SendMessage("olm/ping", map[string]interface{}{
"timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
return err
}
// Create a map to count occurrences in slice a
counts := make(map[string]int)
for _, v := range a {
counts[v]++
}
// Check if slice b has the same elements
for _, v := range b {
counts[v]--
if counts[v] < 0 {
return false
}
}
return true
logger.Debug("Sent ping message")
return nil
}
// aliasesEqual compares two Alias slices for equality (order-independent)
func aliasesEqual(a, b []peers.Alias) bool {
if len(a) != len(b) {
return false
func keepSendingPing(olm *websocket.Client) {
// Send ping immediately on startup
if err := sendPing(olm); err != nil {
logger.Error("Failed to send initial ping: %v", err)
} else {
logger.Info("Sent initial ping message")
}
// Create a map to count occurrences in slice a (using alias+address as key)
counts := make(map[string]int)
for _, v := range a {
key := v.Alias + "|" + v.AliasAddress
counts[key]++
}
// Check if slice b has the same elements
for _, v := range b {
key := v.Alias + "|" + v.AliasAddress
counts[key]--
if counts[key] < 0 {
return false
// Set up ticker for one minute intervals
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-stopPing:
logger.Info("Stopping ping messages")
return
case <-ticker.C:
if err := sendPing(olm); err != nil {
logger.Error("Failed to send periodic ping: %v", err)
}
}
}
return true
}
func GetNetworkSettingsJSON() (string, error) {
return network.GetJSON()
}
func GetNetworkSettingsIncrementor() int {
return network.GetIncrementor()
}
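The `slicesEqual`/`aliasesEqual` helpers in this hunk compare slices as multisets with a count map, which is order-independent and runs in O(n). A standalone version of the same idea:

```go
package main

import "fmt"

// equalMultiset reports whether a and b contain the same elements with the
// same multiplicities, ignoring order.
func equalMultiset(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	counts := make(map[string]int, len(a))
	for _, v := range a {
		counts[v]++
	}
	for _, v := range b {
		counts[v]--
		if counts[v] < 0 {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(equalMultiset(
		[]string{"10.0.0.0/24", "10.0.1.0/24"},
		[]string{"10.0.1.0/24", "10.0.0.0/24"},
	)) // true
}
```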


@@ -50,8 +50,6 @@ type PeerManager struct {
// key is the CIDR string, value is a set of siteIds that want this IP
allowedIPClaims map[string]map[int]bool
APIServer *api.API
PersistentKeepalive int
}
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
@@ -86,13 +84,6 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
return peer, ok
}
// GetPeerMonitor returns the internal peer monitor instance
func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor {
pm.mu.RLock()
defer pm.mu.RUnlock()
return pm.peerMonitor
}
func (pm *PeerManager) GetAllPeers() []SiteConfig {
pm.mu.RLock()
defer pm.mu.RUnlock()
@@ -129,7 +120,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
return err
}
@@ -168,29 +159,6 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
return nil
}
// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once
// without recreating them. Returns a map of siteId to error for any peers that failed to update.
func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
pm.mu.RLock()
defer pm.mu.RUnlock()
pm.PersistentKeepalive = interval
errors := make(map[int]error)
for siteId, peer := range pm.peers {
err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
if err != nil {
errors[siteId] = err
}
}
if len(errors) == 0 {
return nil
}
return errors
}
func (pm *PeerManager) RemovePeer(siteId int) error {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -270,7 +238,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
wgConfig := promotedPeer
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
@@ -346,7 +314,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
return err
}
@@ -356,7 +324,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
promotedWgConfig := promotedPeer
promotedWgConfig.AllowedIps = promotedOwnedIPs
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
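The removed `UpdateAllPeersPersistentKeepalive` returns a `map[siteId]error` so that one failing peer does not abort the whole batch. A generic sketch of that collect-rather-than-bail shape; the function and its arguments are illustrative:

```go
package main

import (
	"errors"
	"fmt"
)

// updateAll applies fn to every peer and collects per-peer failures rather
// than stopping at the first; nil means everything succeeded.
func updateAll(peers map[int]string, fn func(publicKey string) error) map[int]error {
	errs := make(map[int]error)
	for siteID, key := range peers {
		if err := fn(key); err != nil {
			errs[siteID] = err
		}
	}
	if len(errs) == 0 {
		return nil
	}
	return errs
}

func main() {
	peers := map[int]string{1: "k1", 2: "k2"}
	errs := updateAll(peers, func(k string) error {
		if k == "k2" {
			return errors.New("ipc set failed")
		}
		return nil
	})
	fmt.Println(errs) // map[2:ipc set failed]
}
```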


@@ -31,7 +31,8 @@ type PeerMonitor struct {
monitors map[int]*Client
mutex sync.Mutex
running bool
timeout time.Duration
interval time.Duration
timeout time.Duration
maxAttempts int
wsClient *websocket.Client
@@ -41,7 +42,7 @@ type PeerMonitor struct {
stack *stack.Stack
ep *channel.Endpoint
activePorts map[uint16]bool
portsLock sync.RWMutex
portsLock sync.Mutex
nsCtx context.Context
nsCancel context.CancelFunc
nsWg sync.WaitGroup
@@ -49,26 +50,17 @@ type PeerMonitor struct {
// Holepunch testing fields
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchInterval time.Duration
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
holepunchStopChan chan struct{}
// Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
holepunchMaxAttempts int // max consecutive failures before triggering relay
holepunchFailures map[int]int // siteID -> consecutive failure count
// Exponential backoff fields for holepunch monitor
defaultHolepunchMinInterval time.Duration // Minimum interval (initial)
defaultHolepunchMaxInterval time.Duration
holepunchMinInterval time.Duration // Minimum interval (initial)
holepunchMaxInterval time.Duration // Maximum interval (cap for backoff)
holepunchBackoffMultiplier float64 // Multiplier for each stable check
holepunchStableCount map[int]int // siteID -> consecutive stable status count
holepunchCurrentInterval time.Duration // Current interval with backoff applied
// Rapid initial test fields
rapidTestInterval time.Duration // interval between rapid test attempts
rapidTestTimeout time.Duration // timeout for each rapid test attempt
@@ -86,6 +78,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*Client),
interval: 2 * time.Second, // Default check interval (faster)
timeout: 3 * time.Second,
maxAttempts: 3,
wsClient: wsClient,
@@ -95,6 +88,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
nsCtx: ctx,
nsCancel: cancel,
sharedBind: sharedBind,
holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
holepunchTimeout: 2 * time.Second, // Faster timeout
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
@@ -107,15 +101,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
apiServer: apiServer,
wgConnectionStatus: make(map[int]bool),
// Exponential backoff settings for holepunch monitor
defaultHolepunchMinInterval: 2 * time.Second,
defaultHolepunchMaxInterval: 30 * time.Second,
holepunchMinInterval: 2 * time.Second,
holepunchMaxInterval: 30 * time.Second,
holepunchBackoffMultiplier: 1.5,
holepunchStableCount: make(map[int]int),
holepunchCurrentInterval: 2 * time.Second,
holepunchUpdateChan: make(chan struct{}, 1),
}
if err := pm.initNetstack(); err != nil {
@@ -131,75 +116,41 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
}
// SetInterval changes how frequently peers are checked
func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.interval = interval
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetPacketInterval(minInterval, maxInterval)
client.SetPacketInterval(interval)
}
logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
}
func (pm *PeerMonitor) ResetPeerInterval() {
// SetTimeout changes the timeout for waiting for responses
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
// Update interval for all existing monitors
pm.timeout = timeout
// Update timeout for all existing monitors
for _, client := range pm.monitors {
client.ResetPacketInterval()
client.SetTimeout(timeout)
}
}
// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
pm.holepunchMinInterval = minInterval
pm.holepunchMaxInterval = maxInterval
// Reset current interval to the new minimum
pm.holepunchCurrentInterval = minInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
return pm.holepunchMinInterval, pm.holepunchMaxInterval
}
pm.maxAttempts = attempts
func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
pm.mutex.Lock()
pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
// Update max attempts for all existing monitors
for _, client := range pm.monitors {
client.SetMaxAttempts(attempts)
}
}
@@ -218,6 +169,10 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
return err
}
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
pm.monitors[siteID] = client
pm.holepunchEndpoints[siteID] = holepunchEndpoint
@@ -515,59 +470,31 @@ func (pm *PeerMonitor) stopHolepunchMonitor() {
logger.Info("Stopped holepunch connection monitor")
}
// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff
// runHolepunchMonitor runs the holepunch monitoring loop
func (pm *PeerMonitor) runHolepunchMonitor() {
pm.mutex.Lock()
pm.holepunchCurrentInterval = pm.holepunchMinInterval
pm.mutex.Unlock()
ticker := time.NewTicker(pm.holepunchInterval)
defer ticker.Stop()
timer := time.NewTimer(0) // Fire immediately for initial check
defer timer.Stop()
// Do initial check immediately
pm.checkHolepunchEndpoints()
for {
select {
case <-pm.holepunchStopChan:
return
case <-pm.holepunchUpdateChan:
// Interval settings changed, reset to minimum
pm.mutex.Lock()
pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C:
anyStatusChanged := pm.checkHolepunchEndpoints()
pm.mutex.Lock()
if anyStatusChanged {
// Reset to minimum interval on any status change
pm.holepunchCurrentInterval = pm.holepunchMinInterval
} else {
// Apply exponential backoff when stable
newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier)
if newInterval > pm.holepunchMaxInterval {
newInterval = pm.holepunchMaxInterval
}
pm.holepunchCurrentInterval = newInterval
}
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
case <-ticker.C:
pm.checkHolepunchEndpoints()
}
}
}
// checkHolepunchEndpoints tests all holepunch endpoints
// Returns true if any endpoint's status changed
func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Lock()
// Check if we're still running before doing any work
if !pm.running {
pm.mutex.Unlock()
return false
return
}
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
for siteID, endpoint := range pm.holepunchEndpoints {
@@ -577,10 +504,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
maxAttempts := pm.holepunchMaxAttempts
pm.mutex.Unlock()
anyStatusChanged := false
for siteID, endpoint := range endpoints {
// logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
pm.mutex.Lock()
@@ -604,9 +529,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
pm.mutex.Unlock()
// Log status changes
statusChanged := !exists || previousStatus != result.Success
if statusChanged {
anyStatusChanged = true
if !exists || previousStatus != result.Success {
if result.Success {
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
} else {
@@ -639,7 +562,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
pm.mutex.Unlock()
if !stillRunning {
return anyStatusChanged // Stop processing if shutdown is in progress
return // Stop processing if shutdown is in progress
}
if !result.Success && !isRelayed && failureCount >= maxAttempts {
@@ -656,8 +579,6 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
}
}
}
return anyStatusChanged
}
// GetHolepunchStatus returns the current holepunch status for all endpoints
@@ -729,55 +650,55 @@ func (pm *PeerMonitor) Close() {
logger.Debug("PeerMonitor: Cleanup complete")
}
// // TestPeer tests connectivity to a specific peer
// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
// pm.mutex.Lock()
// client, exists := pm.monitors[siteID]
// pm.mutex.Unlock()
// TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
pm.mutex.Lock()
client, exists := pm.monitors[siteID]
pm.mutex.Unlock()
// if !exists {
// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
// }
if !exists {
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
}
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
// defer cancel()
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
defer cancel()
// connected, rtt := client.TestPeerConnection(ctx)
// return connected, rtt, nil
// }
connected, rtt := client.TestConnection(ctx)
return connected, rtt, nil
}
// // TestAllPeers tests connectivity to all peers
// func (pm *PeerMonitor) TestAllPeers() map[int]struct {
// Connected bool
// RTT time.Duration
// } {
// pm.mutex.Lock()
// peers := make(map[int]*Client, len(pm.monitors))
// for siteID, client := range pm.monitors {
// peers[siteID] = client
// }
// pm.mutex.Unlock()
// TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
Connected bool
RTT time.Duration
} {
pm.mutex.Lock()
peers := make(map[int]*Client, len(pm.monitors))
for siteID, client := range pm.monitors {
peers[siteID] = client
}
pm.mutex.Unlock()
// results := make(map[int]struct {
// Connected bool
// RTT time.Duration
// })
// for siteID, client := range peers {
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
// connected, rtt := client.TestPeerConnection(ctx)
// cancel()
results := make(map[int]struct {
Connected bool
RTT time.Duration
})
for siteID, client := range peers {
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
connected, rtt := client.TestConnection(ctx)
cancel()
// results[siteID] = struct {
// Connected bool
// RTT time.Duration
// }{
// Connected: connected,
// RTT: rtt,
// }
// }
results[siteID] = struct {
Connected bool
RTT time.Duration
}{
Connected: connected,
RTT: rtt,
}
}
// return results
// }
return results
}
// initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error {
@@ -849,9 +770,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool {
}
// Check if we are listening on this port
pm.portsLock.RLock()
pm.portsLock.Lock()
active := pm.activePorts[uint16(port)]
pm.portsLock.RUnlock()
pm.portsLock.Unlock()
if !active {
return false
@@ -882,12 +803,13 @@ func (pm *PeerMonitor) runPacketSender() {
defer pm.nsWg.Done()
logger.Debug("PeerMonitor: Packet sender goroutine started")
// Use a ticker to periodically check for packets without blocking indefinitely
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
// Use blocking ReadContext instead of polling - much more CPU efficient
// This will block until a packet is available or context is cancelled
pkt := pm.ep.ReadContext(pm.nsCtx)
if pkt == nil {
// Context was cancelled or endpoint closed
select {
case <-pm.nsCtx.Done():
logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
// Drain any remaining packets before exiting
for {
@@ -899,28 +821,36 @@ func (pm *PeerMonitor) runPacketSender() {
}
logger.Debug("PeerMonitor: Packet sender goroutine exiting")
return
}
case <-ticker.C:
// Try to read packets in batches
for i := 0; i < 10; i++ {
pkt := pm.ep.Read()
if pkt == nil {
break
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
}
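The two variants interleaved in this hunk differ in how the packet sender waits: a blocking `ReadContext` that wakes only on a packet or cancellation, versus a 10 ms ticker that polls `Read` in batches. A channel-based sketch of the blocking-with-drain shape, independent of the gvisor types:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// consume blocks until a packet arrives or ctx is cancelled, then drains
// the queue on shutdown. A blocking receive wakes only when there is work,
// which is the efficiency argument for ep.ReadContext over a polling ticker.
func consume(ctx context.Context, packets <-chan []byte) {
	for {
		select {
		case <-ctx.Done():
			for {
				select {
				case <-packets: // drop remaining packets
				default:
					fmt.Println("drained, exiting")
					return
				}
			}
		case p := <-packets:
			fmt.Printf("forwarding %d bytes\n", len(p))
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	packets := make(chan []byte, 4)
	packets <- []byte{1, 2, 3}
	done := make(chan struct{})
	go func() { consume(ctx, packets); close(done) }()
	time.Sleep(20 * time.Millisecond)
	cancel()
	<-done
}
```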


@@ -32,19 +32,10 @@ type Client struct {
monitorLock sync.Mutex
connLock sync.Mutex // Protects connection operations
shutdownCh chan struct{}
updateCh chan struct{}
packetInterval time.Duration
timeout time.Duration
maxAttempts int
dialer Dialer
// Exponential backoff fields
defaultMinInterval time.Duration // Default minimum interval (initial)
defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
minInterval time.Duration // Minimum interval (initial)
maxInterval time.Duration // Maximum interval (cap for backoff)
backoffMultiplier float64 // Multiplier for each stable check
stableCountToBackoff int // Number of stable checks before backing off
}
// Dialer is a function that creates a connection
@@ -59,59 +50,28 @@ type ConnectionStatus struct {
// NewClient creates a new connection test client
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
return &Client{
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
updateCh: make(chan struct{}, 1),
packetInterval: 2 * time.Second,
defaultMinInterval: 2 * time.Second,
defaultMaxInterval: 30 * time.Second,
minInterval: 2 * time.Second,
maxInterval: 30 * time.Second,
backoffMultiplier: 1.5,
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
packetInterval: 2 * time.Second,
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
}, nil
}
// SetPacketInterval changes how frequently packets are sent in monitor mode
func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
c.monitorLock.Lock()
c.packetInterval = minInterval
c.minInterval = minInterval
c.maxInterval = maxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
func (c *Client) SetPacketInterval(interval time.Duration) {
c.packetInterval = interval
}
func (c *Client) ResetPacketInterval() {
c.monitorLock.Lock()
c.packetInterval = c.defaultMinInterval
c.minInterval = c.defaultMinInterval
c.maxInterval = c.defaultMaxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// SetTimeout changes the timeout for waiting for responses
func (c *Client) SetTimeout(timeout time.Duration) {
c.timeout = timeout
}
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (c *Client) SetMaxAttempts(attempts int) {
c.maxAttempts = attempts
}
// UpdateServerAddr updates the server address and resets the connection
@@ -165,10 +125,9 @@ func (c *Client) ensureConnection() error {
return nil
}
// TestPeerConnection checks if the connection to the server is working
// TestConnection checks if the connection to the server is working
// Returns true if connected, false otherwise
func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
// logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
if err := c.ensureConnection(); err != nil {
logger.Warn("Failed to ensure connection: %v", err)
return false, 0
@@ -179,9 +138,6 @@ func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
packet[4] = packetTypeRequest
// Reusable response buffer
responseBuffer := make([]byte, packetSize)
// Send multiple attempts as specified
for attempt := 0; attempt < c.maxAttempts; attempt++ {
select {
@@ -201,17 +157,20 @@ func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
return false, 0
}
// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
_, err := c.conn.Write(packet)
if err != nil {
c.connLock.Unlock()
logger.Info("Error sending packet: %v", err)
continue
}
// logger.Debug("Successfully sent monitor packet")
// Set read deadline
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
// Wait for response
responseBuffer := make([]byte, packetSize)
n, err := c.conn.Read(responseBuffer)
c.connLock.Unlock()
@@ -252,7 +211,7 @@ func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return c.TestPeerConnection(ctx)
return c.TestConnection(ctx)
}
// MonitorCallback is the function type for connection status change callbacks
@@ -279,61 +238,28 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
go func() {
var lastConnected bool
firstRun := true
stableCount := 0
currentInterval := c.minInterval
timer := time.NewTimer(currentInterval)
defer timer.Stop()
ticker := time.NewTicker(c.packetInterval)
defer ticker.Stop()
for {
select {
case <-c.shutdownCh:
return
case <-c.updateCh:
// Interval settings changed, reset to minimum
c.monitorLock.Lock()
currentInterval = c.minInterval
c.monitorLock.Unlock()
// Reset backoff state
stableCount = 0
timer.Reset(currentInterval)
logger.Debug("Packet interval updated, reset to %v", currentInterval)
case <-timer.C:
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
connected, rtt := c.TestPeerConnection(ctx)
connected, rtt := c.TestConnection(ctx)
cancel()
statusChanged := connected != lastConnected
// Callback if status changed or it's the first check
if statusChanged || firstRun {
if connected != lastConnected || firstRun {
callback(ConnectionStatus{
Connected: connected,
RTT: rtt,
})
lastConnected = connected
firstRun = false
// Reset backoff on status change
stableCount = 0
currentInterval = c.minInterval
} else {
// Status is stable, increment counter
stableCount++
// Apply exponential backoff after stable threshold
if stableCount >= c.stableCountToBackoff {
newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier)
if newInterval > c.maxInterval {
newInterval = c.maxInterval
}
currentInterval = newInterval
}
}
// Reset timer with current interval
timer.Reset(currentInterval)
}
}
}()
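The monitor backoff removed in this file reduces to a small schedule: grow the interval by a multiplier while the status stays stable (after a short stable-count threshold), cap it at a maximum, and snap back to the minimum whenever the status flips. An isolated sketch with invented names:

```go
package main

import (
	"fmt"
	"time"
)

// backoff captures the interval schedule: grow by a multiplier while the
// observed status is stable, cap at max, snap back to min on any change.
// (The real monitor also waits for a few stable checks before growing.)
type backoff struct {
	min, max   time.Duration
	multiplier float64
	current    time.Duration
}

func (b *backoff) next(statusChanged bool) time.Duration {
	if statusChanged {
		b.current = b.min
		return b.current
	}
	grown := time.Duration(float64(b.current) * b.multiplier)
	if grown > b.max {
		grown = b.max
	}
	b.current = grown
	return b.current
}

func main() {
	b := &backoff{min: 2 * time.Second, max: 30 * time.Second, multiplier: 1.5, current: 2 * time.Second}
	for i := 0; i < 8; i++ {
		fmt.Println(b.next(false)) // 3s, 4.5s, 6.75s, ... capped at 30s
	}
	fmt.Println(b.next(true)) // back to 2s after a status change
}
```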


@@ -11,7 +11,7 @@ import (
)
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
var endpoint string
if relay && siteConfig.RelayEndpoint != "" {
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
@@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
}
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive))
configBuilder.WriteString("persistent_keepalive_interval=5\n")
config := configBuilder.String()
logger.Debug("Configuring peer with config: %s", config)
@@ -134,24 +134,6 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [
return nil
}
-// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it
-func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error {
-var configBuilder strings.Builder
-configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
-configBuilder.WriteString("update_only=true\n")
-configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval))
-config := configBuilder.String()
-logger.Debug("Updating persistent keepalive for peer with config: %s", config)
-err := dev.IpcSet(config)
-if err != nil {
-return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err)
-}
-return nil
-}
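
The removed helper above speaks wireguard-go's UAPI directly: `update_only=true` tells the device to modify an existing peer rather than create one, and `util.FixKey` presumably normalizes the key into the hex encoding UAPI expects. A minimal sketch of the same config string, with a placeholder key:

```go
package main

import (
	"fmt"
	"strings"
)

// buildKeepaliveUpdate mirrors the removed UpdatePersistentKeepalive body:
// a UAPI "set" fragment addressing one peer by its hex-encoded public key.
func buildKeepaliveUpdate(publicKeyHex string, seconds int) string {
	var b strings.Builder
	b.WriteString(fmt.Sprintf("public_key=%s\n", publicKeyHex))
	b.WriteString("update_only=true\n")
	b.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", seconds))
	return b.String()
}

func main() {
	fmt.Print(buildKeepaliveUpdate("0011aabb...", 25)) // key is a placeholder
}
```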
func formatEndpoint(endpoint string) string {
if strings.Contains(endpoint, ":") {
return endpoint


@@ -5,7 +5,6 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -55,9 +54,8 @@ type ExitNode struct {
}
type WSMessage struct {
-Type string `json:"type"`
-Data interface{} `json:"data"`
-ConfigVersion int `json:"configVersion,omitempty"`
+Type string `json:"type"`
+Data interface{} `json:"data"`
}
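
With `configVersion` gone, a message marshals to just the two remaining fields; a quick check using a local copy of the trimmed struct:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the trimmed WSMessage, for illustration only.
type WSMessage struct {
	Type string      `json:"type"`
	Data interface{} `json:"data"`
}

func main() {
	b, _ := json.Marshal(WSMessage{Type: "olm/ping", Data: map[string]any{"timestamp": 1700000000}})
	fmt.Println(string(b)) // {"type":"olm/ping","data":{"timestamp":1700000000}}
}
```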
// this is not json anymore
@@ -79,7 +77,6 @@ type Client struct {
handlersMux sync.RWMutex
reconnectInterval time.Duration
isConnected bool
-isDisconnected bool // Flag to track if client is intentionally disconnected
reconnectMux sync.RWMutex
pingInterval time.Duration
pingTimeout time.Duration
@@ -90,19 +87,6 @@ type Client struct {
clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig
configNeedsSave bool // Flag to track if config needs to be saved
-configVersion int // Latest config version received from server
-configVersionMux sync.RWMutex
-token string // Cached authentication token
-exitNodes []ExitNode // Cached exit nodes from token response
-tokenMux sync.RWMutex // Protects token and exitNodes
-forceNewToken bool // Flag to force fetching a new token on next connection
-processingMessage bool // Flag to track if a message is currently being processed
-processingMux sync.RWMutex // Protects processingMessage
-processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
-getPingData func() map[string]any // Callback to get additional ping data
-pingStarted bool // Flag to track if ping monitor has been started
-pingStartedMux sync.Mutex // Protects pingStarted
-pingDone chan struct{} // Channel to stop the ping monitor independently
}
type ClientOption func(*Client)
@@ -138,13 +122,6 @@ func WithTLSConfig(config TLSConfig) ClientOption {
}
}
-// WithPingDataProvider sets a callback to provide additional data for ping messages
-func WithPingDataProvider(fn func() map[string]any) ClientOption {
-return func(c *Client) {
-c.getPingData = fn
-}
-}
func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
@@ -177,7 +154,6 @@ func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.
pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: "olm",
-pingDone: make(chan struct{}),
}
// Apply options before loading config
@@ -197,9 +173,6 @@ func (c *Client) GetConfig() *Config {
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
-if c.isDisconnected {
-c.isDisconnected = false
-}
go c.connectWithRetry()
return nil
}
@@ -232,31 +205,9 @@ func (c *Client) Close() error {
return nil
}
-// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
-func (c *Client) Disconnect() error {
-c.isDisconnected = true
-c.setConnected(false)
-// Stop the ping monitor
-c.stopPingMonitor()
-// Wait for any message currently being processed to complete
-c.processingWg.Wait()
-if c.conn != nil {
-c.writeMux.Lock()
-c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
-c.writeMux.Unlock()
-err := c.conn.Close()
-c.conn = nil
-return err
-}
-return nil
-}
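
The deleted Disconnect performed gorilla/websocket's polite close handshake before tearing the connection down. A minimal sketch of that handshake, assuming an established `*websocket.Conn`:

```go
package ws

import (
	"log"

	"github.com/gorilla/websocket"
)

// sendClose writes a NormalClosure control frame, then closes the socket,
// matching what the removed Disconnect did under its write lock.
func sendClose(conn *websocket.Conn) error {
	msg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
	if err := conn.WriteMessage(websocket.CloseMessage, msg); err != nil {
		log.Printf("close frame failed: %v", err)
	}
	return conn.Close()
}
```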
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessage(messageType string, data interface{}) error {
-if c.isDisconnected || c.conn == nil {
+if c.conn == nil {
return fmt.Errorf("not connected")
}
@@ -265,14 +216,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
Data: data,
}
logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
logger.Debug("Sending message: %s, data: %+v", messageType, data)
c.writeMux.Lock()
defer c.writeMux.Unlock()
return c.conn.WriteJSON(msg)
}
-func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
+func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
stopChan := make(chan struct{})
updateChan := make(chan interface{})
var dataMux sync.Mutex
@@ -280,32 +231,30 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
go func() {
count := 0
+maxAttempts := 10
-send := func() {
-if c.isDisconnected || c.conn == nil {
-return
-}
-err := c.SendMessage(messageType, currentData)
-if err != nil {
-logger.Error("websocket: Failed to send message: %v", err)
-}
-count++
+err := c.SendMessage(messageType, currentData) // Send immediately
+if err != nil {
+logger.Error("Failed to send initial message: %v", err)
+}
-send() // Send immediately
+count++
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
-if maxAttempts != -1 && count >= maxAttempts {
-logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
+if count >= maxAttempts {
+logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
dataMux.Lock()
-send()
+err = c.SendMessage(messageType, currentData)
dataMux.Unlock()
+if err != nil {
+logger.Error("Failed to send message: %v", err)
+}
+count++
case newData := <-updateChan:
dataMux.Lock()
// Merge newData into currentData if both are maps
@@ -328,14 +277,6 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
case <-stopChan:
return
}
-// Suspend sending if disconnected
-for c.isDisconnected {
-select {
-case <-stopChan:
-return
-case <-time.After(500 * time.Millisecond):
-}
-}
}
}()
return func() {
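
A hedged usage sketch of the simplified interval sender (the message type and payload are placeholders; `c` is assumed to be a connected Client from this package):

```go
// stop ends the resend loop; update merges new fields into the payload
// when both are maps, as the select branch above shows.
stop, update := c.SendMessageInterval("olm/register", map[string]any{"attempt": 1}, 2*time.Second)
update(map[string]any{"attempt": 2}) // merged into the next send
stop()                               // e.g. once the server acknowledges
```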
@@ -382,7 +323,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
tlsConfig = &tls.Config{}
}
tlsConfig.InsecureSkipVerify = true
logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
tokenData := map[string]interface{}{
@@ -411,7 +352,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// print out the request for debugging
logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
// Make the request
client := &http.Client{}
@@ -428,7 +369,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
// Return AuthError for 401/403 status codes
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
@@ -444,7 +385,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("websocket: Failed to decode token response.")
logger.Error("Failed to decode token response.")
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
}
@@ -456,7 +397,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
return "", nil, fmt.Errorf("received empty token from server")
}
logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
logger.Debug("Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
}
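
The decode above only relies on `Data.Token` and `Data.ExitNodes`; a self-contained sketch with an assumed JSON shape (the field names are guesses, and the real TokenResponse may carry more):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Assumed response shape inferred from the fields used in getToken.
type tokenResponse struct {
	Data struct {
		Token     string            `json:"token"`
		ExitNodes []json.RawMessage `json:"exitNodes"`
	} `json:"data"`
}

func main() {
	body := []byte(`{"data":{"token":"abc123","exitNodes":[]}}`)
	var tr tokenResponse
	if err := json.Unmarshal(body, &tr); err != nil {
		panic(err)
	}
	fmt.Println(tr.Data.Token, len(tr.Data.ExitNodes))
}
```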
@@ -470,8 +411,7 @@ func (c *Client) connectWithRetry() {
err := c.establishConnection()
if err != nil {
// Check if this is an auth error (401/403)
-var authErr *AuthError
-if errors.As(err, &authErr) {
+if authErr, ok := err.(*AuthError); ok {
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
// Trigger auth error callback if set (this should terminate the tunnel)
if c.onAuthError != nil {
@@ -482,7 +422,7 @@ func (c *Client) connectWithRetry() {
continue
}
// For other errors (5xx, network issues), continue retrying
logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval)
continue
}
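
One behavioral note on this hunk: a plain type assertion only matches when the error is returned unwrapped, while `errors.As` walks the `%w` chain, and `establishConnection` below wraps the token error with `fmt.Errorf("failed to get token: %w", err)`. A runnable illustration of the difference:

```go
package main

import (
	"errors"
	"fmt"
)

type AuthError struct{ StatusCode int }

func (e *AuthError) Error() string { return fmt.Sprintf("auth failed: %d", e.StatusCode) }

func main() {
	wrapped := fmt.Errorf("failed to get token: %w", &AuthError{StatusCode: 401})

	_, ok := wrapped.(*AuthError)      // direct assertion: does not unwrap
	fmt.Println("type assertion:", ok) // false

	var ae *AuthError
	fmt.Println("errors.As:", errors.As(wrapped, &ae)) // true
}
```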
@@ -492,25 +432,15 @@ func (c *Client) connectWithRetry() {
}
func (c *Client) establishConnection() error {
-// Get token for authentication - reuse cached token unless forced to get new one
-c.tokenMux.Lock()
-needNewToken := c.token == "" || c.forceNewToken
-if needNewToken {
-token, exitNodes, err := c.getToken()
-if err != nil {
-c.tokenMux.Unlock()
-return fmt.Errorf("failed to get token: %w", err)
-}
-c.token = token
-c.exitNodes = exitNodes
-c.forceNewToken = false
-if c.onTokenUpdate != nil {
-c.onTokenUpdate(token, exitNodes)
-}
+// Get token for authentication
+token, exitNodes, err := c.getToken()
+if err != nil {
+return fmt.Errorf("failed to get token: %w", err)
+}
+if c.onTokenUpdate != nil {
+c.onTokenUpdate(token, exitNodes)
+}
-token := c.token
-c.tokenMux.Unlock()
// Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL)
@@ -545,7 +475,7 @@ func (c *Client) establishConnection() error {
// Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
logger.Info("Setting up TLS configuration for WebSocket connection")
tlsConfig, err := c.setupTLS()
if err != nil {
return fmt.Errorf("failed to setup TLS configuration: %w", err)
@@ -559,38 +489,25 @@ func (c *Client) establishConnection() error {
dialer.TLSClientConfig = &tls.Config{}
}
dialer.TLSClientConfig.InsecureSkipVerify = true
logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
-conn, resp, err := dialer.Dial(u.String(), nil)
+conn, _, err := dialer.Dial(u.String(), nil)
if err != nil {
-// Check if this is an unauthorized error (401)
-if resp != nil && resp.StatusCode == http.StatusUnauthorized {
-logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
-// Force getting a new token on next reconnect attempt
-c.tokenMux.Lock()
-c.forceNewToken = true
-c.tokenMux.Unlock()
-return &AuthError{
-StatusCode: http.StatusUnauthorized,
-Message: "WebSocket connection unauthorized",
-}
-}
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
c.conn = conn
c.setConnected(true)
-// Note: ping monitor is NOT started here - it will be started when
-// StartPingMonitor() is called after registration completes
+// Start the ping monitor
+go c.pingMonitor()
// Start the read pump with disconnect detection
go c.readPumpWithDisconnectDetection()
if c.onConnect != nil {
if err := c.onConnect(); err != nil {
logger.Error("websocket: OnConnect callback failed: %v", err)
logger.Error("OnConnect callback failed: %v", err)
}
}
@@ -603,9 +520,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Handle new separate certificate configuration
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
logger.Info("websocket: Loading separate certificate files for mTLS")
logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
logger.Info("Loading separate certificate files for mTLS")
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
// Load client certificate and key
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
@@ -616,7 +533,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Load CA certificates for remote validation if specified
if len(c.tlsConfig.CAFiles) > 0 {
logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
caCertPool := x509.NewCertPool()
for _, caFile := range c.tlsConfig.CAFiles {
caCert, err := os.ReadFile(caFile)
@@ -642,13 +559,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Fallback to existing PKCS12 implementation for backward compatibility
if c.tlsConfig.PKCS12File != "" {
logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
return c.setupPKCS12TLS()
}
// Legacy fallback using config.TlsClientCert
if c.config.TlsClientCert != "" {
logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
return loadClientCertificate(c.config.TlsClientCert)
}
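
For context, the separate-files branch above composes a standard mTLS client config: a certificate/key pair plus an optional CA pool for validating the server. A condensed sketch of that flow (paths are placeholders; error wrapping trimmed for brevity):

```go
package ws

import (
	"crypto/tls"
	"crypto/x509"
	"os"
)

// buildTLS assembles a client-side mTLS config from PEM files on disk.
func buildTLS(certFile, keyFile string, caFiles []string) (*tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
	if len(caFiles) > 0 {
		pool := x509.NewCertPool()
		for _, f := range caFiles {
			pem, err := os.ReadFile(f)
			if err != nil {
				return nil, err
			}
			pool.AppendCertsFromPEM(pem) // ignoring the bool result for brevity
		}
		cfg.RootCAs = pool
	}
	return cfg, nil
}
```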
@@ -660,59 +577,6 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
return loadClientCertificate(c.tlsConfig.PKCS12File)
}
-// sendPing sends a single ping message
-func (c *Client) sendPing() {
-if c.isDisconnected || c.conn == nil {
-return
-}
-// Skip ping if a message is currently being processed
-c.processingMux.RLock()
-isProcessing := c.processingMessage
-c.processingMux.RUnlock()
-if isProcessing {
-logger.Debug("websocket: Skipping ping, message is being processed")
-return
-}
-// Send application-level ping with config version
-c.configVersionMux.RLock()
-configVersion := c.configVersion
-c.configVersionMux.RUnlock()
-pingData := map[string]any{
-"timestamp": time.Now().Unix(),
-"userToken": c.config.UserToken,
-}
-if c.getPingData != nil {
-for k, v := range c.getPingData() {
-pingData[k] = v
-}
-}
-pingMsg := WSMessage{
-Type: "olm/ping",
-Data: pingData,
-ConfigVersion: configVersion,
-}
-logger.Debug("websocket: Sending ping: %+v", pingMsg)
-c.writeMux.Lock()
-err := c.conn.WriteJSON(pingMsg)
-c.writeMux.Unlock()
-if err != nil {
-// Check if we're shutting down before logging error and reconnecting
-select {
-case <-c.done:
-// Expected during shutdown
-return
-default:
-logger.Error("websocket: Ping failed: %v", err)
-c.reconnect()
-return
-}
-}
-}
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) pingMonitor() {
ticker := time.NewTicker(c.pingInterval)
@@ -722,65 +586,29 @@ func (c *Client) pingMonitor() {
select {
case <-c.done:
return
-case <-c.pingDone:
-return
case <-ticker.C:
-c.sendPing()
+if c.conn == nil {
+return
+}
+c.writeMux.Lock()
+err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
+c.writeMux.Unlock()
+if err != nil {
+// Check if we're shutting down before logging error and reconnecting
+select {
+case <-c.done:
+// Expected during shutdown
+return
+default:
+logger.Error("Ping failed: %v", err)
+c.reconnect()
+return
+}
+}
}
}
}
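
The new monitor pings with WebSocket control frames rather than the removed application-level `olm/ping` message. This diff only shows the write side; the usual companion (an assumption about intent, not shown here) is a pong handler that extends the read deadline:

```go
package ws

import (
	"time"

	"github.com/gorilla/websocket"
)

// keepAlive sketches the control-frame ping idiom: ping on a ticker, and
// push the read deadline forward each time the peer answers with a pong.
func keepAlive(conn *websocket.Conn, interval, timeout time.Duration, done <-chan struct{}) {
	conn.SetReadDeadline(time.Now().Add(interval + timeout))
	conn.SetPongHandler(func(string) error {
		return conn.SetReadDeadline(time.Now().Add(interval + timeout))
	})
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			deadline := time.Now().Add(timeout)
			if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
				return // let the read pump's error path drive reconnection
			}
		}
	}
}
```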
-// StartPingMonitor starts the ping monitor goroutine.
-// This should be called after the client is registered and connected.
-// It is safe to call multiple times - only the first call will start the monitor.
-func (c *Client) StartPingMonitor() {
-c.pingStartedMux.Lock()
-defer c.pingStartedMux.Unlock()
-if c.pingStarted {
-return
-}
-c.pingStarted = true
-// Create a new pingDone channel for this ping monitor instance
-c.pingDone = make(chan struct{})
-// Send an initial ping immediately
-go func() {
-c.sendPing()
-c.pingMonitor()
-}()
-}
-// stopPingMonitor stops the ping monitor goroutine if it's running.
-func (c *Client) stopPingMonitor() {
-c.pingStartedMux.Lock()
-defer c.pingStartedMux.Unlock()
-if !c.pingStarted {
-return
-}
-// Close the pingDone channel to stop the monitor
-close(c.pingDone)
-c.pingStarted = false
-}
-// GetConfigVersion returns the current config version
-func (c *Client) GetConfigVersion() int {
-c.configVersionMux.RLock()
-defer c.configVersionMux.RUnlock()
-return c.configVersion
-}
-// setConfigVersion updates the config version if the new version is higher
-func (c *Client) setConfigVersion(version int) {
-c.configVersionMux.Lock()
-defer c.configVersionMux.Unlock()
-logger.Debug("websocket: setting config version to %d", version)
-c.configVersion = version
-}
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
func (c *Client) readPumpWithDisconnectDetection() {
defer func() {
@@ -805,47 +633,26 @@ func (c *Client) readPumpWithDisconnectDetection() {
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
-// Check if we're shutting down or explicitly disconnected before logging error
+// Check if we're shutting down before logging error
select {
case <-c.done:
// Expected during shutdown, don't log as error
logger.Debug("websocket: connection closed during shutdown")
logger.Debug("WebSocket connection closed during shutdown")
return
default:
-// Check if explicitly disconnected
-if c.isDisconnected {
-logger.Debug("websocket: connection closed: client was explicitly disconnected")
-return
-}
// Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("websocket: read error: %v", err)
logger.Error("WebSocket read error: %v", err)
} else {
logger.Debug("websocket: connection closed: %v", err)
logger.Debug("WebSocket connection closed: %v", err)
}
return // triggers reconnect via defer
}
}
-// Update config version from incoming message
-c.setConfigVersion(msg.ConfigVersion)
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
-// Mark that we're processing a message
-c.processingMux.Lock()
-c.processingMessage = true
-c.processingMux.Unlock()
-c.processingWg.Add(1)
handler(msg)
-// Mark that we're done processing
-c.processingWg.Done()
-c.processingMux.Lock()
-c.processingMessage = false
-c.processingMux.Unlock()
}
c.handlersMux.RUnlock()
}
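
Handlers are dispatched from the `handlers` map keyed by message type. A hypothetical registration, since the registering method itself is outside this diff:

```go
// RegisterHandler is assumed; only the handlers map and the handler(msg)
// call site are visible in this hunk, and "some/type" is a placeholder.
c.RegisterHandler("some/type", func(msg WSMessage) {
	logger.Info("got %s: %+v", msg.Type, msg.Data)
})
```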
@@ -859,12 +666,6 @@ func (c *Client) reconnect() {
c.conn = nil
}
-// Don't reconnect if explicitly disconnected
-if c.isDisconnected {
-logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
-return
-}
// Only reconnect if we're not shutting down
select {
case <-c.done:
@@ -882,7 +683,7 @@ func (c *Client) setConnected(status bool) {
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("websocket: Loading tls-client-cert %s", p12Path)
logger.Info("Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path)
if err != nil {