diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 989e68c..c44a2d7 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -11,7 +11,9 @@ permissions: on: push: tags: - - "*" + - "[0-9]+.[0-9]+.[0-9]+" + - "[0-9]+.[0-9]+.[0-9]+-rc.[0-9]+" + workflow_dispatch: inputs: version: @@ -273,7 +275,7 @@ jobs: tags: | type=semver,pattern={{version}},value=${{ env.TAG }} type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }} - type=raw,value=latest,enable=${{ env.PUBLISH_LATEST == 'true' && env.IS_RC != 'true' }} + type=raw,value=latest,enable=${{ env.IS_RC != 'true' }} flavor: | latest=false labels: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2349f3a..6fe7514 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,4 +22,4 @@ jobs: run: make go-build-release - name: Build Docker image - run: make docker-build-release + run: make docker-build diff --git a/Makefile b/Makefile index 8eed5c2..55ebf81 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,9 @@ all: local local: CGO_ENABLED=0 go build -o ./bin/olm +docker-build: + docker build -t fosrl/olm:latest . + docker-build-release: @if [ -z "$(tag)" ]; then \ echo "Error: tag is required. Usage: make docker-build-release tag="; \ diff --git a/README.md b/README.md index 97d0f66..0d7847e 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,7 @@ When Olm receives WireGuard control messages, it will use the information encode ## Hole Punching -In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to newt. If you want to disable hole punching, use the `--disable-holepunch` flag. Hole punching attempts to orchestrate a NAT hole punch between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil. - -Right now, basic NAT hole punching is supported. We plan to add: - -- [ ] Birthday paradox -- [ ] UPnP -- [ ] LAN detection +In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to Newt. Hole punching attempts to orchestrate a NAT traversal between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil. 
## Build diff --git a/api/api.go b/api/api.go index 787f958..a6ac9cd 100644 --- a/api/api.go +++ b/api/api.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "strconv" "sync" "time" @@ -62,23 +63,26 @@ type StatusResponse struct { // API represents the HTTP server and its state type API struct { - addr string - socketPath string - listener net.Listener - server *http.Server + addr string + socketPath string + listener net.Listener + server *http.Server + onConnect func(ConnectionRequest) error onSwitchOrg func(SwitchOrgRequest) error onDisconnect func() error onExit func() error + statusMu sync.RWMutex peerStatuses map[int]*PeerStatus connectedAt time.Time isConnected bool isRegistered bool isTerminated bool - version string - agent string - orgID string + + version string + agent string + orgID string } // NewAPI creates a new HTTP server that listens on a TCP address @@ -101,6 +105,14 @@ func NewAPISocket(socketPath string) *API { return s } +func NewAPIStub() *API { + s := &API{ + peerStatuses: make(map[int]*PeerStatus), + } + + return s +} + // SetHandlers sets the callback functions for handling API requests func (s *API) SetHandlers( onConnect func(ConnectionRequest) error, @@ -116,6 +128,10 @@ func (s *API) SetHandlers( // Start starts the HTTP server func (s *API) Start() error { + if s.socketPath == "" && s.addr == "" { + return fmt.Errorf("either socketPath or addr must be provided to start the API server") + } + mux := http.NewServeMux() mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) @@ -160,7 +176,7 @@ func (s *API) Stop() error { // Close the server first, which will also close the listener gracefully if s.server != nil { - s.server.Close() + _ = s.server.Close() } // Clean up socket file if using Unix socket @@ -345,7 +361,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusAccepted) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "connection request accepted", }) } @@ -358,7 +374,6 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { } s.statusMu.RLock() - defer s.statusMu.RUnlock() resp := StatusResponse{ Connected: s.isConnected, @@ -371,8 +386,18 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { NetworkSettings: network.GetSettings(), } + s.statusMu.RUnlock() + + data, err := json.Marshal(resp) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Length", strconv.Itoa(len(data))) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data) } // handleHealth handles the /health endpoint @@ -384,7 +409,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "ok", }) } @@ -401,7 +426,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { // Return a success response first w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "shutdown initiated", }) @@ -450,7 +475,7 @@ func (s *API) handleSwitchOrg(w 
http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "org switch request accepted", }) } @@ -484,7 +509,7 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { // Return a success response w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{ + _ = json.NewEncoder(w).Encode(map[string]string{ "status": "disconnect initiated", }) } diff --git a/device/middle_device.go b/device/middle_device.go index b031871..7dfbec8 100644 --- a/device/middle_device.go +++ b/device/middle_device.go @@ -1,9 +1,12 @@ package device import ( + "io" "net/netip" "os" "sync" + "sync/atomic" + "time" "github.com/fosrl/newt/logger" "golang.zx2c4.com/wireguard/tun" @@ -18,14 +21,68 @@ type FilterRule struct { Handler PacketHandler } -// MiddleDevice wraps a TUN device with packet filtering capabilities -type MiddleDevice struct { +// closeAwareDevice wraps a tun.Device along with a flag +// indicating whether its Close method was called. +type closeAwareDevice struct { + isClosed atomic.Bool tun.Device - rules []FilterRule - mutex sync.RWMutex - readCh chan readResult - injectCh chan []byte - closed chan struct{} + closeEventCh chan struct{} + wg sync.WaitGroup + closeOnce sync.Once +} + +func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice { + return &closeAwareDevice{ + Device: tunDevice, + isClosed: atomic.Bool{}, + closeEventCh: make(chan struct{}), + } +} + +// redirectEvents redirects the Events() method of the underlying tun.Device +// to the given channel. +func (c *closeAwareDevice) redirectEvents(out chan tun.Event) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + select { + case ev, ok := <-c.Device.Events(): + if !ok { + return + } + + if ev == tun.EventDown { + continue + } + + select { + case out <- ev: + case <-c.closeEventCh: + return + } + case <-c.closeEventCh: + return + } + } + }() +} + +// Close calls the underlying Device's Close method +// after setting isClosed to true. +func (c *closeAwareDevice) Close() (err error) { + c.closeOnce.Do(func() { + c.isClosed.Store(true) + close(c.closeEventCh) + err = c.Device.Close() + c.wg.Wait() + }) + + return err +} + +func (c *closeAwareDevice) IsClosed() bool { + return c.isClosed.Load() } type readResult struct { @@ -36,58 +93,136 @@ type readResult struct { err error } +// MiddleDevice wraps a TUN device with packet filtering capabilities +// and supports swapping the underlying device. 
+type MiddleDevice struct { + devices []*closeAwareDevice + mu sync.Mutex + cond *sync.Cond + rules []FilterRule + rulesMutex sync.RWMutex + readCh chan readResult + injectCh chan []byte + closed atomic.Bool + events chan tun.Event +} + // NewMiddleDevice creates a new filtered TUN device wrapper func NewMiddleDevice(device tun.Device) *MiddleDevice { d := &MiddleDevice{ - Device: device, + devices: make([]*closeAwareDevice, 0), rules: make([]FilterRule, 0), - readCh: make(chan readResult), + readCh: make(chan readResult, 16), injectCh: make(chan []byte, 100), - closed: make(chan struct{}), + events: make(chan tun.Event, 16), } - go d.pump() + d.cond = sync.NewCond(&d.mu) + + if device != nil { + d.AddDevice(device) + } + return d } -func (d *MiddleDevice) pump() { +// AddDevice adds a new underlying TUN device, closing any previous one +func (d *MiddleDevice) AddDevice(device tun.Device) { + d.mu.Lock() + if d.closed.Load() { + d.mu.Unlock() + _ = device.Close() + return + } + + var toClose *closeAwareDevice + if len(d.devices) > 0 { + toClose = d.devices[len(d.devices)-1] + } + + cad := newCloseAwareDevice(device) + cad.redirectEvents(d.events) + + d.devices = []*closeAwareDevice{cad} + + // Start pump for the new device + go d.pump(cad) + + d.cond.Broadcast() + d.mu.Unlock() + + if toClose != nil { + logger.Debug("MiddleDevice: Closing previous device") + if err := toClose.Close(); err != nil { + logger.Debug("MiddleDevice: Error closing previous device: %v", err) + } + } +} + +func (d *MiddleDevice) pump(dev *closeAwareDevice) { const defaultOffset = 16 - batchSize := d.Device.BatchSize() - logger.Debug("MiddleDevice: pump started") + batchSize := dev.BatchSize() + logger.Debug("MiddleDevice: pump started for device") + + // Recover from panic if readCh is closed while we're trying to send + defer func() { + if r := recover(); r != nil { + logger.Debug("MiddleDevice: pump recovered from panic (channel closed)") + } + }() for { - // Check closed first with priority - select { - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel") + // Check if this device is closed + if dev.IsClosed() { + logger.Debug("MiddleDevice: pump exiting, device is closed") + return + } + + // Check if MiddleDevice itself is closed + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed") return - default: } // Allocate buffers for reading - // We allocate new buffers for each read to avoid race conditions - // since we pass them to the channel bufs := make([][]byte, batchSize) sizes := make([]int, batchSize) for i := range bufs { bufs[i] = make([]byte, 2048) // Standard MTU + headroom } - n, err := d.Device.Read(bufs, sizes, defaultOffset) + n, err := dev.Read(bufs, sizes, defaultOffset) - // Check closed again after read returns - select { - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)") + // Check if device was closed during read + if dev.IsClosed() { + logger.Debug("MiddleDevice: pump exiting, device closed during read") + return + } + + // Check if MiddleDevice was closed during read + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read") + return + } + + // Try to send the result - check closed state first to avoid sending on closed channel + if d.closed.Load() { + logger.Debug("MiddleDevice: pump exiting, device closed before send") return - default: } - // Now try to send the result select { case d.readCh <- readResult{bufs: bufs, sizes: sizes, 
offset: defaultOffset, n: n, err: err}: - case <-d.closed: - logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)") - return + default: + // Channel full, check if we should exit + if dev.IsClosed() || d.closed.Load() { + return + } + // Try again with blocking + select { + case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: + case <-dev.closeEventCh: + return + } } if err != nil { @@ -99,16 +234,28 @@ func (d *MiddleDevice) pump() { // InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN) func (d *MiddleDevice) InjectOutbound(packet []byte) { + if d.closed.Load() { + return + } + // Use defer/recover to handle panic from sending on closed channel + // This can happen during shutdown race conditions + defer func() { + if r := recover(); r != nil { + logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)") + } + }() select { case d.injectCh <- packet: - case <-d.closed: + default: + // Channel full, drop packet + logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full") } } // AddRule adds a packet filtering rule func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { - d.mutex.Lock() - defer d.mutex.Unlock() + d.rulesMutex.Lock() + defer d.rulesMutex.Unlock() d.rules = append(d.rules, FilterRule{ DestIP: destIP, Handler: handler, @@ -117,8 +264,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { // RemoveRule removes all rules for a given destination IP func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { - d.mutex.Lock() - defer d.mutex.Unlock() + d.rulesMutex.Lock() + defer d.rulesMutex.Unlock() newRules := make([]FilterRule, 0, len(d.rules)) for _, rule := range d.rules { if rule.DestIP != destIP { @@ -130,18 +277,120 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { // Close stops the device func (d *MiddleDevice) Close() error { - select { - case <-d.closed: - // Already closed - return nil - default: - logger.Debug("MiddleDevice: Closing, signaling closed channel") - close(d.closed) + if !d.closed.CompareAndSwap(false, true) { + return nil // already closed } - logger.Debug("MiddleDevice: Closing underlying TUN device") - err := d.Device.Close() - logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err) - return err + + d.mu.Lock() + devices := d.devices + d.devices = nil + d.cond.Broadcast() + d.mu.Unlock() + + // Close underlying devices first - this causes the pump goroutines to exit + // when their read operations return errors + var lastErr error + logger.Debug("MiddleDevice: Closing %d devices", len(devices)) + for _, device := range devices { + if err := device.Close(); err != nil { + logger.Debug("MiddleDevice: Error closing device: %v", err) + lastErr = err + } + } + + // Now close channels to unblock any remaining readers + // The pump should have exited by now, but close channels to be safe + close(d.readCh) + close(d.injectCh) + close(d.events) + + return lastErr +} + +// Events returns the events channel +func (d *MiddleDevice) Events() <-chan tun.Event { + return d.events +} + +// File returns the underlying file descriptor +func (d *MiddleDevice) File() *os.File { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return nil + } + continue + } + + file := dev.File() + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return file + } +} + +// MTU returns the MTU of the underlying device +func (d *MiddleDevice) MTU() 
(int, error) { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } + continue + } + + mtu, err := dev.MTU() + if err == nil { + return mtu, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return 0, err + } +} + +// Name returns the name of the underlying device +func (d *MiddleDevice) Name() (string, error) { + for { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return "", io.EOF + } + continue + } + + name, err := dev.Name() + if err == nil { + return name, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return "", err + } +} + +// BatchSize returns the batch size +func (d *MiddleDevice) BatchSize() int { + dev := d.peekLast() + if dev == nil { + return 1 + } + return dev.BatchSize() } // extractDestIP extracts destination IP from packet (fast path) @@ -176,156 +425,239 @@ func extractDestIP(packet []byte) (netip.Addr, bool) { // Read intercepts packets going UP from the TUN device (towards WireGuard) func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { - // Check if already closed first (non-blocking) - select { - case <-d.closed: - logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)") - return 0, os.ErrClosed - default: - } - - // Now block waiting for data - select { - case res := <-d.readCh: - if res.err != nil { - logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err) - return 0, res.err + for { + if d.closed.Load() { + logger.Debug("MiddleDevice: Read returning io.EOF, device closed") + return 0, io.EOF } - // Copy packets from result to provided buffers - count := 0 - for i := 0; i < res.n && i < len(bufs); i++ { - // Handle offset mismatch if necessary - // We assume the pump used defaultOffset (16) - // If caller asks for different offset, we need to shift - src := res.bufs[i] - srcOffset := res.offset - srcSize := res.sizes[i] - - // Calculate where the packet data starts and ends in src - pktData := src[srcOffset : srcOffset+srcSize] - - // Ensure dest buffer is large enough - if len(bufs[i]) < offset+len(pktData) { - continue // Skip if buffer too small + // Wait for a device to be available + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF } - - copy(bufs[i][offset:], pktData) - sizes[i] = len(pktData) - count++ - } - n = count - - case pkt := <-d.injectCh: - if len(bufs) == 0 { - return 0, nil - } - if len(bufs[0]) < offset+len(pkt) { - return 0, nil // Buffer too small - } - copy(bufs[0][offset:], pkt) - sizes[0] = len(pkt) - n = 1 - - case <-d.closed: - logger.Debug("MiddleDevice: Read returning os.ErrClosed") - return 0, os.ErrClosed // Signal that device is closed - } - - d.mutex.RLock() - rules := d.rules - d.mutex.RUnlock() - - if len(rules) == 0 { - return n, nil - } - - // Process packets and filter out handled ones - writeIdx := 0 - for readIdx := 0; readIdx < n; readIdx++ { - packet := bufs[readIdx][offset : offset+sizes[readIdx]] - - destIP, ok := extractDestIP(packet) - if !ok { - // Can't parse, keep packet - if writeIdx != readIdx { - bufs[writeIdx] = bufs[readIdx] - sizes[writeIdx] = sizes[readIdx] - } - writeIdx++ continue } - // Check if packet matches any rule - handled := false - for _, rule := range rules { - if rule.DestIP == destIP { - if rule.Handler(packet) { - // Packet was handled and should be dropped - handled = true - break + // Now block waiting for data from readCh or injectCh + select { + case res, 
ok := <-d.readCh: + if !ok { + // Channel closed, device is shutting down + return 0, io.EOF + } + if res.err != nil { + // Check if device was swapped + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err) + return 0, res.err + } + + // Copy packets from result to provided buffers + count := 0 + for i := 0; i < res.n && i < len(bufs); i++ { + src := res.bufs[i] + srcOffset := res.offset + srcSize := res.sizes[i] + + pktData := src[srcOffset : srcOffset+srcSize] + + if len(bufs[i]) < offset+len(pktData) { + continue + } + + copy(bufs[i][offset:], pktData) + sizes[i] = len(pktData) + count++ + } + n = count + + case pkt, ok := <-d.injectCh: + if !ok { + // Channel closed, device is shutting down + return 0, io.EOF + } + if len(bufs) == 0 { + return 0, nil + } + if len(bufs[0]) < offset+len(pkt) { + return 0, nil + } + copy(bufs[0][offset:], pkt) + sizes[0] = len(pkt) + n = 1 + } + + // Apply filtering rules + d.rulesMutex.RLock() + rules := d.rules + d.rulesMutex.RUnlock() + + if len(rules) == 0 { + return n, nil + } + + // Process packets and filter out handled ones + writeIdx := 0 + for readIdx := 0; readIdx < n; readIdx++ { + packet := bufs[readIdx][offset : offset+sizes[readIdx]] + + destIP, ok := extractDestIP(packet) + if !ok { + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ + continue + } + + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + handled = true + break + } } } - } - if !handled { - // Keep packet - if writeIdx != readIdx { - bufs[writeIdx] = bufs[readIdx] - sizes[writeIdx] = sizes[readIdx] + if !handled { + if writeIdx != readIdx { + bufs[writeIdx] = bufs[readIdx] + sizes[writeIdx] = sizes[readIdx] + } + writeIdx++ } - writeIdx++ } - } - return writeIdx, err + return writeIdx, nil + } } // Write intercepts packets going DOWN to the TUN device (from WireGuard) func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) { - d.mutex.RLock() - rules := d.rules - d.mutex.RUnlock() + for { + if d.closed.Load() { + return 0, io.EOF + } - if len(rules) == 0 { - return d.Device.Write(bufs, offset) - } - - // Filter packets going down - filteredBufs := make([][]byte, 0, len(bufs)) - for _, buf := range bufs { - if len(buf) <= offset { + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } continue } - packet := buf[offset:] - destIP, ok := extractDestIP(packet) - if !ok { - // Can't parse, keep packet - filteredBufs = append(filteredBufs, buf) - continue - } + d.rulesMutex.RLock() + rules := d.rules + d.rulesMutex.RUnlock() - // Check if packet matches any rule - handled := false - for _, rule := range rules { - if rule.DestIP == destIP { - if rule.Handler(packet) { - // Packet was handled and should be dropped - handled = true - break + var filteredBufs [][]byte + if len(rules) == 0 { + filteredBufs = bufs + } else { + filteredBufs = make([][]byte, 0, len(bufs)) + for _, buf := range bufs { + if len(buf) <= offset { + continue + } + + packet := buf[offset:] + destIP, ok := extractDestIP(packet) + if !ok { + filteredBufs = append(filteredBufs, buf) + continue + } + + handled := false + for _, rule := range rules { + if rule.DestIP == destIP { + if rule.Handler(packet) { + handled = true + break + } + } + } + + if !handled { + filteredBufs = append(filteredBufs, buf) } } } - if !handled { - filteredBufs = append(filteredBufs, buf) 
+ if len(filteredBufs) == 0 { + return len(bufs), nil } - } - if len(filteredBufs) == 0 { - return len(bufs), nil // All packets were handled - } + n, err := dev.Write(filteredBufs, offset) + if err == nil { + return n, nil + } - return d.Device.Write(filteredBufs, offset) + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } } + +func (d *MiddleDevice) waitForDevice() bool { + d.mu.Lock() + defer d.mu.Unlock() + + for len(d.devices) == 0 && !d.closed.Load() { + d.cond.Wait() + } + return !d.closed.Load() +} + +func (d *MiddleDevice) peekLast() *closeAwareDevice { + d.mu.Lock() + defer d.mu.Unlock() + + if len(d.devices) == 0 { + return nil + } + + return d.devices[len(d.devices)-1] +} + +// WriteToTun writes packets directly to the underlying TUN device, +// bypassing WireGuard. This is useful for sending packets that should +// appear to come from the TUN interface (e.g., DNS responses from a proxy). +// Unlike Write(), this does not go through packet filtering rules. +func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) { + for { + if d.closed.Load() { + return 0, io.EOF + } + + dev := d.peekLast() + if dev == nil { + if !d.waitForDevice() { + return 0, io.EOF + } + continue + } + + n, err := dev.Write(bufs, offset) + if err == nil { + return n, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } +} \ No newline at end of file diff --git a/device/tun_unix.go b/device/tun_darwin.go similarity index 91% rename from device/tun_unix.go rename to device/tun_darwin.go index c9bab60..df87d53 100644 --- a/device/tun_unix.go +++ b/device/tun_darwin.go @@ -1,4 +1,4 @@ -//go:build !windows +//go:build darwin package device @@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { } file := os.NewFile(uintptr(dupTunFd), "/dev/tun") - device, err := tun.CreateTUNFromFile(file, mtuInt) + device, err := tun.CreateTUNFromFile(file, 0) if err != nil { file.Close() return nil, err diff --git a/device/tun_linux.go b/device/tun_linux.go new file mode 100644 index 0000000..902f269 --- /dev/null +++ b/device/tun_linux.go @@ -0,0 +1,50 @@ +//go:build linux + +package device + +import ( + "net" + "os" + "runtime" + + "github.com/fosrl/newt/logger" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" +) + +func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { + if runtime.GOOS == "android" { // otherwise we get a permission denied + theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd)) + return theTun, err + } + + dupTunFd, err := unix.Dup(int(tunFd)) + if err != nil { + logger.Error("Unable to dup tun fd: %v", err) + return nil, err + } + + err = unix.SetNonblock(dupTunFd, true) + if err != nil { + unix.Close(dupTunFd) + return nil, err + } + + file := os.NewFile(uintptr(dupTunFd), "/dev/tun") + device, err := tun.CreateTUNFromFile(file, mtuInt) + if err != nil { + file.Close() + return nil, err + } + + return device, nil +} + +func UapiOpen(interfaceName string) (*os.File, error) { + return ipc.UAPIOpen(interfaceName) +} + +func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { + return ipc.UAPIListen(interfaceName, fileUAPI) +} diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 6d56379..f65e923 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -12,7 +12,6 @@ import ( "github.com/fosrl/newt/util" "github.com/fosrl/olm/device" "github.com/miekg/dns" - 
"golang.zx2c4.com/wireguard/tun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -34,18 +33,17 @@ type DNSProxy struct { ep *channel.Endpoint proxyIP netip.Addr upstreamDNS []string - tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally + tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally mtu int - tunDevice tun.Device // Direct reference to underlying TUN device for responses - middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering + middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes recordStore *DNSRecordStore // Local DNS records // Tunnel DNS fields - for sending queries over WireGuard - tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries) - tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries - tunnelEp *channel.Endpoint + tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries) + tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries + tunnelEp *channel.Endpoint tunnelActivePorts map[uint16]bool - tunnelPortsLock sync.Mutex + tunnelPortsLock sync.Mutex ctx context.Context cancel context.CancelFunc @@ -53,7 +51,7 @@ type DNSProxy struct { } // NewDNSProxy creates a new DNS proxy -func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { +func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { proxyIP, err := PickIPFromSubnet(utilitySubnet) if err != nil { return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err) @@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in proxy := &DNSProxy{ proxyIP: proxyIP, mtu: mtu, - tunDevice: tunDevice, middleDevice: middleDevice, upstreamDNS: upstreamDns, tunnelDNS: tunnelDns, @@ -602,12 +599,12 @@ func (p *DNSProxy) runTunnelPacketSender() { defer p.wg.Done() logger.Debug("DNS tunnel packet sender goroutine started") - ticker := time.NewTicker(1 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-p.ctx.Done(): + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := p.tunnelEp.ReadContext(p.ctx) + if pkt == nil { + // Context was cancelled or endpoint closed logger.Debug("DNS tunnel packet sender exiting") // Drain any remaining packets for { @@ -618,36 +615,28 @@ func (p *DNSProxy) runTunnelPacketSender() { pkt.DecRef() } return - case <-ticker.C: - // Try to read packets - for i := 0; i < 10; i++ { - pkt := p.tunnelEp.Read() - if pkt == nil { - break - } - - // Extract packet data - slices := pkt.AsSlices() - if len(slices) > 0 { - var totalSize int - for _, slice := range slices { - totalSize += len(slice) - } - - buf := make([]byte, totalSize) - pos := 0 - for _, slice := range slices { - copy(buf[pos:], slice) - pos += len(slice) - } - - // Inject into MiddleDevice (outbound to WG) - p.middleDevice.InjectOutbound(buf) - } - - pkt.DecRef() - } } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, 
slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + p.middleDevice.InjectOutbound(buf) + } + + pkt.DecRef() } } @@ -660,18 +649,12 @@ func (p *DNSProxy) runPacketSender() { const offset = 16 for { - select { - case <-p.ctx.Done(): - return - default: - } - - // Read packets from netstack endpoint - pkt := p.ep.Read() + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := p.ep.ReadContext(p.ctx) if pkt == nil { - // No packet available, small sleep to avoid busy loop - time.Sleep(1 * time.Millisecond) - continue + // Context was cancelled or endpoint closed + return } // Extract packet data as slices @@ -694,9 +677,9 @@ func (p *DNSProxy) runPacketSender() { pos += len(slice) } - // Write packet to TUN device + // Write packet to TUN device via MiddleDevice // offset=16 indicates packet data starts at position 16 in the buffer - _, err := p.tunDevice.Write([][]byte{buf}, offset) + _, err := p.middleDevice.WriteToTun([][]byte{buf}, offset) if err != nil { logger.Error("Failed to write DNS response to TUN: %v", err) } diff --git a/dns/dns_records.go b/dns/dns_records.go index ed57b77..5308b0e 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -322,4 +322,4 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool { } return matchWildcardInternal(pattern, domain, pi+1, di+1) -} \ No newline at end of file +} diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index 0bb18a1..f922afb 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -37,7 +37,7 @@ func TestWildcardMatching(t *testing.T) { domain: "autoco.internal.", expected: false, }, - + // Question mark wildcard tests { name: "host-0?.autoco.internal matches host-01.autoco.internal", @@ -63,7 +63,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-012.autoco.internal.", expected: false, }, - + // Combined wildcard tests { name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal", @@ -83,7 +83,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-01.autoco.internal.", expected: false, }, - + // Multiple asterisks { name: "*.*. 
autoco.internal matches any.thing.autoco.internal", @@ -97,7 +97,7 @@ func TestWildcardMatching(t *testing.T) { domain: "single.autoco.internal.", expected: false, }, - + // Asterisk in middle { name: "host-*.autoco.internal matches host-anything.autoco.internal", @@ -111,7 +111,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-.autoco.internal.", expected: true, }, - + // Multiple question marks { name: "host-??.autoco.internal matches host-01.autoco.internal", @@ -125,7 +125,7 @@ func TestWildcardMatching(t *testing.T) { domain: "host-1.autoco.internal.", expected: false, }, - + // Exact match (no wildcards) { name: "exact.autoco.internal matches exact.autoco.internal", @@ -139,7 +139,7 @@ func TestWildcardMatching(t *testing.T) { domain: "other.autoco.internal.", expected: false, }, - + // Edge cases { name: "* matches anything", @@ -154,7 +154,7 @@ func TestWildcardMatching(t *testing.T) { expected: true, }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := matchWildcard(tt.pattern, tt.domain) @@ -167,21 +167,21 @@ func TestWildcardMatching(t *testing.T) { func TestDNSRecordStoreWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard records wildcardIP := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", wildcardIP) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Add exact record exactIP := net.ParseIP("10.0.0.2") err = store.AddRecord("exact.autoco.internal", exactIP) if err != nil { t.Fatalf("Failed to add exact record: %v", err) } - + // Test exact match takes precedence ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -190,7 +190,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) { if !ips[0].Equal(exactIP) { t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) } - + // Test wildcard match ips = store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -199,7 +199,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) { if !ips[0].Equal(wildcardIP) { t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) } - + // Test non-match (base domain) ips = store.GetRecords("autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -209,14 +209,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) { func TestDNSRecordStoreComplexWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add complex wildcard pattern ip1 := net.ParseIP("10.0.0.1") err := store.AddRecord("*.host-0?.autoco.internal", ip1) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Test matching domain ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -225,13 +225,13 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { if len(ips) > 0 && !ips[0].Equal(ip1) { t.Errorf("Expected IP %v, got %v", ip1, ips[0]) } - + // Test non-matching domain (missing prefix) ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) } - + // Test non-matching domain (wrong ? 
position) ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -241,23 +241,23 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { func TestDNSRecordStoreRemoveWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard record ip := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Verify it exists ips := store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP before removal, got %d", len(ips)) } - + // Remove wildcard record store.RemoveRecord("*.autoco.internal", nil) - + // Verify it's gone ips = store.GetRecords("host.autoco.internal.", RecordTypeA) if len(ips) != 0 { @@ -267,40 +267,40 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { func TestDNSRecordStoreMultipleWildcards(t *testing.T) { store := NewDNSRecordStore() - + // Add multiple wildcard patterns that don't overlap ip1 := net.ParseIP("10.0.0.1") ip2 := net.ParseIP("10.0.0.2") ip3 := net.ParseIP("10.0.0.3") - + err := store.AddRecord("*.prod.autoco.internal", ip1) if err != nil { t.Fatalf("Failed to add first wildcard: %v", err) } - + err = store.AddRecord("*.dev.autoco.internal", ip2) if err != nil { t.Fatalf("Failed to add second wildcard: %v", err) } - + // Add a broader wildcard that matches both err = store.AddRecord("*.autoco.internal", ip3) if err != nil { t.Fatalf("Failed to add third wildcard: %v", err) } - + // Test domain matching only the prod pattern and the broad pattern ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) } - + // Test domain matching only the dev pattern and the broad pattern ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) } - + // Test domain matching only the broad pattern ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) if len(ips) != 1 { @@ -310,14 +310,14 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add IPv6 wildcard record ip := net.ParseIP("2001:db8::1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add IPv6 wildcard record: %v", err) } - + // Test wildcard match for IPv6 ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) if len(ips) != 1 { @@ -330,21 +330,21 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { func TestHasRecordWildcard(t *testing.T) { store := NewDNSRecordStore() - + // Add wildcard record ip := net.ParseIP("10.0.0.1") err := store.AddRecord("*.autoco.internal", ip) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } - + // Test HasRecord with wildcard match if !store.HasRecord("host.autoco.internal.", RecordTypeA) { t.Error("Expected HasRecord to return true for wildcard match") } - + // Test HasRecord with non-match if store.HasRecord("autoco.internal.", RecordTypeA) { t.Error("Expected HasRecord to return false for base domain") } -} \ No newline at end of file +} diff --git a/dns/override/dns_override_android.go b/dns/override/dns_override_android.go new file mode 100644 index 0000000..d3fd78e --- /dev/null +++ b/dns/override/dns_override_android.go @@ -0,0 +1,16 @@ +//go:build android + +package olm + +import "net/netip" + +// SetupDNSOverride is a no-op on Android +// Android handles 
DNS through the VpnService API at the Java/Kotlin layer +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { + return nil +} + +// RestoreDNSOverride is a no-op on Android +func RestoreDNSOverride() error { + return nil +} \ No newline at end of file diff --git a/dns/override/dns_override_darwin.go b/dns/override/dns_override_darwin.go index 6ccc3fb..c1c3789 100644 --- a/dns/override/dns_override_darwin.go +++ b/dns/override/dns_override_darwin.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on macOS // Uses scutil for DNS configuration -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error configurator, err = platform.NewDarwinDNSConfigurator() if err != nil { @@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/override/dns_override_ios.go b/dns/override/dns_override_ios.go new file mode 100644 index 0000000..6c95c71 --- /dev/null +++ b/dns/override/dns_override_ios.go @@ -0,0 +1,15 @@ +//go:build ios + +package olm + +import "net/netip" + +// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { + return nil +} + +// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system +func RestoreDNSOverride() error { + return nil +} \ No newline at end of file diff --git a/dns/override/dns_override_unix.go b/dns/override/dns_override_unix.go index c3b31e8..12cb692 100644 --- a/dns/override/dns_override_unix.go +++ b/dns/override/dns_override_unix.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD // Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error // Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability @@ -32,7 +27,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName) if err == nil { logger.Info("Using systemd-resolved DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err) @@ -40,7 +35,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName) if err == nil { logger.Info("Using NetworkManager DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, 
configurator) } logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err) @@ -48,7 +43,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName) if err == nil { logger.Info("Using resolvconf DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } logger.Warn("Failed to create resolvconf configurator: %v, falling back", err) } @@ -60,11 +55,11 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { } logger.Info("Using file-based DNS configurator") - return setDNS(dnsProxy, configurator) + return setDNS(proxyIp, configurator) } // setDNS is a helper function to set DNS and log the results -func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { +func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error { // Get current DNS servers before changing currentDNS, err := conf.GetCurrentDNS() if err != nil { @@ -75,7 +70,7 @@ func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/override/dns_override_windows.go b/dns/override/dns_override_windows.go index a564079..16bbca1 100644 --- a/dns/override/dns_override_windows.go +++ b/dns/override/dns_override_windows.go @@ -7,7 +7,6 @@ import ( "net/netip" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/dns" platform "github.com/fosrl/olm/dns/platform" ) @@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator // SetupDNSOverride configures the system DNS to use the DNS proxy on Windows // Uses registry-based configuration (automatically extracts interface GUID) -func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { - if dnsProxy == nil { - return fmt.Errorf("DNS proxy is nil") - } - +func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error { var err error configurator, err = platform.NewWindowsDNSConfigurator(interfaceName) if err != nil { @@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { // Set new DNS servers to point to our proxy newDNS := []netip.Addr{ - dnsProxy.GetProxyIP(), + proxyIp, } logger.Info("Setting DNS servers to: %v", newDNS) diff --git a/dns/platform/darwin.go b/dns/platform/darwin.go index 61cc81b..8054c57 100644 --- a/dns/platform/darwin.go +++ b/dns/platform/darwin.go @@ -416,4 +416,4 @@ func (d *DarwinDNSConfigurator) clearState() error { logger.Debug("Cleared DNS state file") return nil -} \ No newline at end of file +} diff --git a/go.mod b/go.mod index 4844592..0d6bbcb 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.25 require ( github.com/Microsoft/go-winio v0.6.2 - github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 + github.com/fosrl/newt v1.8.0 github.com/godbus/dbus/v5 v5.2.0 github.com/gorilla/websocket v1.5.3 github.com/miekg/dns v1.1.68 @@ -30,3 +30,5 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect ) + +replace github.com/fosrl/newt => ../newt diff --git a/go.sum b/go.sum index 9bf88e2..f6ca61a 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= 
-github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 h1:VpuI42l4enih//6IFFQDln/B7WukfMePxIRIpXsNe/0= -github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= diff --git a/main.go b/main.go index f6c6973..2bf8dcd 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" - "github.com/fosrl/olm/olm" + olmpkg "github.com/fosrl/olm/olm" ) func main() { @@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt } // Create a new olm.Config struct and copy values from the main config - olmConfig := olm.GlobalConfig{ + olmConfig := olmpkg.OlmConfig{ LogLevel: config.LogLevel, EnableAPI: config.EnableAPI, HTTPAddr: config.HTTPAddr, @@ -219,15 +219,20 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt Agent: "Olm CLI", OnExit: cancel, // Pass cancel function directly to trigger shutdown OnTerminated: cancel, + PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE + } + + olm, err := olmpkg.Init(ctx, olmConfig) + if err != nil { + logger.Fatal("Failed to initialize olm: %v", err) } - olm.Init(ctx, olmConfig) if err := olm.StartApi(); err != nil { logger.Fatal("Failed to start API server: %v", err) } if config.ID != "" && config.Secret != "" && config.Endpoint != "" { - tunnelConfig := olm.TunnelConfig{ + tunnelConfig := olmpkg.TunnelConfig{ Endpoint: config.Endpoint, ID: config.ID, Secret: config.Secret, diff --git a/olm/connect.go b/olm/connect.go new file mode 100644 index 0000000..a610ea4 --- /dev/null +++ b/olm/connect.go @@ -0,0 +1,223 @@ +package olm + +import ( + "encoding/json" + "fmt" + "os" + "runtime" + "strconv" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + olmDevice "github.com/fosrl/olm/device" + "github.com/fosrl/olm/dns" + dnsOverride "github.com/fosrl/olm/dns/override" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" +) + +func (o *Olm) handleConnect(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + + var wgData WgData + + if o.connected { + logger.Info("Already connected. Ignoring new connection request.") + return + } + + if o.stopRegister != nil { + o.stopRegister() + o.stopRegister = nil + } + + if o.updateRegister != nil { + o.updateRegister = nil + } + + // if there is an existing tunnel then close it + if o.dev != nil { + logger.Info("Got new message. 
Closing existing tunnel!") + o.dev.Close() + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &wgData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + o.tdev, err = func() (tun.Device, error) { + if o.tunnelConfig.FileDescriptorTun != 0 { + return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU) + } + ifName := o.tunnelConfig.InterfaceName + if runtime.GOOS == "darwin" { // this is if we dont pass a fd + ifName, err = network.FindUnusedUTUN() + if err != nil { + return nil, err + } + } + return tun.CreateTUN(ifName, o.tunnelConfig.MTU) + }() + if err != nil { + logger.Error("Failed to create TUN device: %v", err) + return + } + + // if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything? + o.tunnelConfig.InterfaceName = realInterfaceName + } + // } + + // Wrap TUN device with packet filter for DNS proxy + o.middleDev = olmDevice.NewMiddleDevice(o.tdev) + + wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") + // Use filtered device instead of raw TUN device + o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger)) + + if o.tunnelConfig.EnableUAPI { + fileUAPI, err := func() (*os.File, error) { + if o.tunnelConfig.FileDescriptorUAPI != 0 { + fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) + } + return os.NewFile(uintptr(fd), ""), nil + } + return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName) + }() + if err != nil { + logger.Error("UAPI listen error: %v", err) + os.Exit(1) + return + } + + o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI) + if err != nil { + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) + } + + go func() { + for { + conn, err := o.uapiListener.Accept() + if err != nil { + return + } + go o.dev.IpcHandle(conn) + } + }() + logger.Info("UAPI listener started") + } + + if err = o.dev.Up(); err != nil { + logger.Error("Failed to bring up WireGuard device: %v", err) + } + + // Extract interface IP (strip CIDR notation if present) + interfaceIP := wgData.TunnelIP + if strings.Contains(interfaceIP, "/") { + interfaceIP = strings.Split(interfaceIP, "/")[0] + } + + // Create and start DNS proxy + o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP) + if err != nil { + logger.Error("Failed to create DNS proxy: %v", err) + } + + if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil { + logger.Error("Failed to o.tunnelConfigure interface: %v", err) + } + + if network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet + logger.Error("Failed to add route for utility subnet: %v", err) + } + + // Create peer manager with integrated peer monitoring + o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ + Device: o.dev, + DNSProxy: o.dnsProxy, + InterfaceName: o.tunnelConfig.InterfaceName, + PrivateKey: o.privateKey, + MiddleDev: o.middleDev, + LocalIP: interfaceIP, + SharedBind: o.sharedBind, + WSClient: o.websocket, + APIServer: o.apiServer, + 
}) + for i := range wgData.Sites { + site := wgData.Sites[i] + var siteEndpoint string + // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer + if site.RelayEndpoint != "" { + siteEndpoint = site.RelayEndpoint + } else { + siteEndpoint = site.Endpoint + } + + o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) + + if err := o.peerManager.AddPeer(site); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Configured peer %s", site.PublicKey) + } + + o.peerManager.Start() + + if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime + logger.Error("Failed to start DNS proxy: %v", err) + } + + if o.tunnelConfig.OverrideDNS { + // Set up DNS override to use our DNS proxy + if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { + logger.Error("Failed to setup DNS override: %v", err) + return + } + + network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()}) + } + + o.apiServer.SetRegistered(true) + + o.connected = true + + // Invoke onConnected callback if configured + if o.olmConfig.OnConnected != nil { + go o.olmConfig.OnConnected() + } + + logger.Info("WireGuard device created.") +} + +func (o *Olm) handleTerminate(msg websocket.WSMessage) { + logger.Info("Received terminate message") + o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearPeerStatuses() + + network.ClearNetworkSettings() + + o.Close() + + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() + } +} diff --git a/olm/data.go b/olm/data.go new file mode 100644 index 0000000..80a52fc --- /dev/null +++ b/olm/data.go @@ -0,0 +1,344 @@ +package olm + +import ( + "encoding/json" + "time" + + "github.com/fosrl/newt/holepunch" + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" +) + +func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) { + logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var addSubnetsData peers.PeerAdd + if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { + logger.Error("Error unmarshaling add-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for adding remote subnets and aliases", addSubnetsData.SiteId) + return + } + + // Add new subnets + for _, subnet := range addSubnetsData.RemoteSubnets { + if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases + for _, alias := range addSubnetsData.Aliases { + if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } +} + +func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) { + logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeSubnetsData peers.RemovePeerData + if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { + logger.Error("Error unmarshaling 
remove-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) + return + } + + // Remove subnets + for _, subnet := range removeSubnetsData.RemoteSubnets { + if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Remove aliases + for _, alias := range removeSubnetsData.Aliases { + if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } +} + +func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { + logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateSubnetsData peers.UpdatePeerData + if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { + logger.Error("Error unmarshaling update-remote-subnets data: %v", err) + return + } + + if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists { + logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId) + return + } + + // Add new subnets BEFORE removing old ones to preserve shared subnets + // This ensures that if an old and new subnet are the same on different peers, + // the route won't be temporarily removed + for _, subnet := range updateSubnetsData.NewRemoteSubnets { + if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to add allowed IP %s: %v", subnet, err) + } + } + + // Remove old subnets after new ones are added + for _, subnet := range updateSubnetsData.OldRemoteSubnets { + if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { + logger.Error("Failed to remove allowed IP %s: %v", subnet, err) + } + } + + // Add new aliases BEFORE removing old ones to preserve shared IP addresses + // This ensures that if an old and new alias share the same IP, the IP won't be + // temporarily removed from the allowed IPs list + for _, alias := range updateSubnetsData.NewAliases { + if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { + logger.Error("Failed to add alias %s: %v", alias.Alias, err) + } + } + + // Remove old aliases after new ones are added + for _, alias := range updateSubnetsData.OldAliases { + if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { + logger.Error("Failed to remove alias %s: %v", alias.Alias, err) + } + } + + logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) +} + +func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { + logger.Debug("Received peer-handshake message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling handshake data: %v", err) + return + } + + var handshakeData struct { + SiteId int `json:"siteId"` + ExitNode struct { + PublicKey string `json:"publicKey"` + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + } `json:"exitNode"` + } + + if err := json.Unmarshal(jsonData, &handshakeData); err != nil { + logger.Error("Error unmarshaling handshake data: %v", err) + return + } + + // Get 
existing peer from PeerManager + _, exists := o.peerManager.GetPeer(handshakeData.SiteId) + if exists { + logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) + return + } + + relayPort := handshakeData.ExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + siteId := handshakeData.SiteId + exitNode := holepunch.ExitNode{ + Endpoint: handshakeData.ExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: handshakeData.ExitNode.PublicKey, + SiteIds: []int{siteId}, + } + + added := o.holePunchManager.AddExitNode(exitNode) + if added { + logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) + } else { + logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) + } + + o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud + + // Send handshake acknowledgment back to server with retry + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + }, 1*time.Second, 10) + + logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) +} + +// Handler for syncing peer configuration - reconciles expected state with actual state +func (o *Olm) handleSync(msg websocket.WSMessage) { + logger.Debug("Received sync message: %v", msg.Data) + + if !o.connected { + logger.Warn("Not connected, ignoring sync request") + return + } + + if o.peerManager == nil { + logger.Warn("Peer manager not initialized, ignoring sync request") + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling sync data: %v", err) + return + } + + var wgData WgData + if err := json.Unmarshal(jsonData, &wgData); err != nil { + logger.Error("Error unmarshaling sync data: %v", err) + return + } + + // Build a map of expected peers from the incoming data + expectedPeers := make(map[int]peers.SiteConfig) + for _, site := range wgData.Sites { + expectedPeers[site.SiteId] = site + } + + // Get all current peers + currentPeers := o.peerManager.GetAllPeers() + currentPeerMap := make(map[int]peers.SiteConfig) + for _, peer := range currentPeers { + currentPeerMap[peer.SiteId] = peer + } + + // Find peers to remove (in current but not in expected) + for siteId := range currentPeerMap { + if _, exists := expectedPeers[siteId]; !exists { + logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId) + if err := o.peerManager.RemovePeer(siteId); err != nil { + logger.Error("Sync: Failed to remove peer %d: %v", siteId, err) + } else { + // Remove any exit nodes associated with this peer from hole punching + if o.holePunchManager != nil { + removed := o.holePunchManager.RemoveExitNodesByPeer(siteId) + if removed > 0 { + logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId) + } + } + } + } + } + + // Find peers to add (in expected but not in current) and peers to update + for siteId, expectedSite := range expectedPeers { + if _, exists := currentPeerMap[siteId]; !exists { + // New peer - add it using the add flow (with holepunch) + logger.Info("Sync: Adding new peer for site %d", siteId) + + // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it + 
o.holePunchManager.TriggerHolePunch() + + // TODO: do we need to send the message to the cloud to add the peer that way? + if err := o.peerManager.AddPeer(expectedSite); err != nil { + logger.Error("Sync: Failed to add peer %d: %v", siteId, err) + } else { + logger.Info("Sync: Successfully added peer for site %d", siteId) + } + } else { + // Existing peer - check if update is needed + currentSite := currentPeerMap[siteId] + needsUpdate := false + + // Check if any fields have changed + if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint { + needsUpdate = true + } + if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint { + needsUpdate = true + } + if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey { + needsUpdate = true + } + if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP { + needsUpdate = true + } + if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort { + needsUpdate = true + } + // Check remote subnets + if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) { + needsUpdate = true + } + // Check aliases + if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) { + needsUpdate = true + } + + if needsUpdate { + logger.Info("Sync: Updating peer for site %d", siteId) + + // Merge expected data with current data + siteConfig := currentSite + if expectedSite.Endpoint != "" { + siteConfig.Endpoint = expectedSite.Endpoint + } + if expectedSite.RelayEndpoint != "" { + siteConfig.RelayEndpoint = expectedSite.RelayEndpoint + } + if expectedSite.PublicKey != "" { + siteConfig.PublicKey = expectedSite.PublicKey + } + if expectedSite.ServerIP != "" { + siteConfig.ServerIP = expectedSite.ServerIP + } + if expectedSite.ServerPort != 0 { + siteConfig.ServerPort = expectedSite.ServerPort + } + if expectedSite.RemoteSubnets != nil { + siteConfig.RemoteSubnets = expectedSite.RemoteSubnets + } + if expectedSite.Aliases != nil { + siteConfig.Aliases = expectedSite.Aliases + } + + if err := o.peerManager.UpdatePeer(siteConfig); err != nil { + logger.Error("Sync: Failed to update peer %d: %v", siteId, err) + } else { + // If the endpoint changed, trigger holepunch to refresh NAT mappings + if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint { + logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId) + o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetServerHolepunchInterval() + } + logger.Info("Sync: Successfully updated peer for site %d", siteId) + } + } + } + } + + logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers)) +} diff --git a/olm/olm.go b/olm/olm.go index 4cbb391..85dcbe6 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -2,13 +2,12 @@ package olm import ( "context" - "encoding/json" "fmt" "net" + "net/http" + _ "net/http/pprof" "os" - "runtime" - "strconv" - "strings" + "sync" "time" "github.com/fosrl/newt/bind" @@ -28,40 +27,55 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -var ( - privateKey wgtypes.Key - connected bool - dev *device.Device - uapiListener net.Listener - tdev tun.Device - middleDev *olmDevice.MiddleDevice +type Olm struct { + privateKey wgtypes.Key + logFile *os.File + + connected bool + tunnelRunning bool + + uapiListener net.Listener + dev *device.Device + tdev tun.Device + 
middleDev *olmDevice.MiddleDevice + sharedBind *bind.SharedBind + dnsProxy *dns.DNSProxy apiServer *api.API - olmClient *websocket.Client - tunnelCancel context.CancelFunc - tunnelRunning bool - sharedBind *bind.SharedBind + websocket *websocket.Client holePunchManager *holepunch.Manager - globalConfig GlobalConfig - tunnelConfig TunnelConfig - globalCtx context.Context - stopRegister func() - stopPeerSend func() - updateRegister func(newData interface{}) - stopPing chan struct{} peerManager *peers.PeerManager -) + // Power mode management + currentPowerMode string + powerModeMu sync.Mutex + wakeUpTimer *time.Timer + wakeUpDebounce time.Duration + + olmCtx context.Context + tunnelCancel context.CancelFunc + + olmConfig OlmConfig + tunnelConfig TunnelConfig + + stopRegister func() + updateRegister func(newData any) + + stopServerPing func() + + stopPeerSend func() +} // initTunnelInfo creates the shared UDP socket and holepunch manager. // This is used during initial tunnel setup and when switching organizations. -func initTunnelInfo(clientID string) error { - var err error - privateKey, err = wgtypes.GeneratePrivateKey() +func (o *Olm) initTunnelInfo(clientID string) error { + privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { logger.Error("Failed to generate private key: %v", err) return err } + o.privateKey = privateKey + sourcePort, err := util.FindAvailableUDPPort(49152, 65535) if err != nil { return fmt.Errorf("failed to find available UDP port: %w", err) @@ -77,57 +91,92 @@ func initTunnelInfo(clientID string) error { return fmt.Errorf("failed to create UDP socket: %w", err) } - sharedBind, err = bind.New(udpConn) + sharedBind, err := bind.New(udpConn) if err != nil { - udpConn.Close() + _ = udpConn.Close() return fmt.Errorf("failed to create shared bind: %w", err) } + o.sharedBind = sharedBind + // Add a reference for the hole punch senders (creator already has one reference for WireGuard) sharedBind.AddRef() logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) // Create the holepunch manager - holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) + o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) return nil } -func Init(ctx context.Context, config GlobalConfig) { - globalConfig = config - globalCtx = ctx - +func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) + // Start pprof server if enabled + if config.PprofAddr != "" { + go func() { + logger.Info("Starting pprof server on %s", config.PprofAddr) + if err := http.ListenAndServe(config.PprofAddr, nil); err != nil { + logger.Error("Failed to start pprof server: %v", err) + } + }() + } + + var logFile *os.File + if config.LogFilePath != "" { + file, err := os.OpenFile(config.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + logger.Fatal("Failed to open log file: %v", err) + return nil, err + } + + logger.SetOutput(file) + logFile = file + } + + if config.WakeUpDebounce == 0 { + config.WakeUpDebounce = 3 * time.Second + } + logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() if err != nil { logger.Fatal("Insufficient permissions to create native TUN interface: %v", err) - return + return nil, err } + var apiServer *api.API if config.HTTPAddr != "" { apiServer = api.NewAPI(config.HTTPAddr) } else if 
config.SocketPath != "" { apiServer = api.NewAPISocket(config.SocketPath) + } else { + // this is so is not null but it cant be started without either the socket path or http addr + apiServer = api.NewAPIStub() } apiServer.SetVersion(config.Version) apiServer.SetAgent(config.Agent) - // Set up API handlers - apiServer.SetHandlers( + newOlm := &Olm{ + logFile: logFile, + olmCtx: ctx, + apiServer: apiServer, + olmConfig: config, + } + + newOlm.registerAPICallbacks() + + return newOlm, nil +} + +func (o *Olm) registerAPICallbacks() { + o.apiServer.SetHandlers( // onConnect func(req api.ConnectionRequest) error { logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) - // Stop any existing tunnel before starting a new one - if olmClient != nil { - logger.Info("Stopping existing tunnel before starting new connection") - StopTunnel() - } - tunnelConfig := TunnelConfig{ Endpoint: req.Endpoint, ID: req.ID, @@ -181,7 +230,7 @@ func Init(ctx context.Context, config GlobalConfig) { // Start the tunnel process with the new credentials if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" { logger.Info("Starting tunnel with new credentials") - go StartTunnel(tunnelConfig) + go o.StartTunnel(tunnelConfig) } return nil @@ -189,66 +238,61 @@ func Init(ctx context.Context, config GlobalConfig) { // onSwitchOrg func(req api.SwitchOrgRequest) error { logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) - return SwitchOrg(req.OrgID) + return o.SwitchOrg(req.OrgID) }, // onDisconnect func() error { logger.Info("Processing disconnect request via API") - return StopTunnel() + return o.StopTunnel() }, // onExit func() error { logger.Info("Processing shutdown request via API") - Close() - if globalConfig.OnExit != nil { - globalConfig.OnExit() + o.Close() + if o.olmConfig.OnExit != nil { + o.olmConfig.OnExit() } return nil }, ) } -func StartTunnel(config TunnelConfig) { - if tunnelRunning { +func (o *Olm) StartTunnel(config TunnelConfig) { + if o.tunnelRunning { logger.Info("Tunnel already running") return } - tunnelRunning = true // Also set it here in case it is called externally - tunnelConfig = config + o.tunnelRunning = true // Also set it here in case it is called externally + o.tunnelConfig = config // Reset terminated status when tunnel starts - apiServer.SetTerminated(false) + o.apiServer.SetTerminated(false) // debug print out the whole config logger.Debug("Starting tunnel with config: %+v", config) // Create a cancellable context for this tunnel process - tunnelCtx, cancel := context.WithCancel(globalCtx) - tunnelCancel = cancel - defer func() { - tunnelCancel = nil - }() - - // Recreate channels for this tunnel session - stopPing = make(chan struct{}) + tunnelCtx, cancel := context.WithCancel(o.olmCtx) + o.tunnelCancel = cancel var ( - interfaceName = config.InterfaceName - id = config.ID - secret = config.Secret - userToken = config.UserToken + id = config.ID + secret = config.Secret + userToken = config.UserToken ) - apiServer.SetOrgID(config.OrgID) + o.tunnelConfig.InterfaceName = config.InterfaceName - // Create a new olm client using the provided credentials - olm, err := websocket.NewClient( - id, // Use provided ID - secret, // Use provided secret - userToken, // Use provided user token OPTIONAL + o.apiServer.SetOrgID(config.OrgID) + + // Create a new olmClient client using the provided credentials + olmClient, err := websocket.NewClient( + id, + secret, + userToken, config.OrgID, - 
config.Endpoint, // Use provided endpoint + config.Endpoint, config.PingIntervalDuration, config.PingTimeoutDuration, ) @@ -257,785 +301,77 @@ func StartTunnel(config TunnelConfig) { return } - // Store the client reference globally - olmClient = olm - // Create shared UDP socket and holepunch manager - if err := initTunnelInfo(id); err != nil { + if err := o.initTunnelInfo(id); err != nil { logger.Error("%v", err) return } - olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - - var wgData WgData - - if connected { - logger.Info("Already connected. Ignoring new connection request.") - return - } - - if stopRegister != nil { - stopRegister() - stopRegister = nil - } - - if updateRegister != nil { - updateRegister = nil - } - - // if there is an existing tunnel then close it - if dev != nil { - logger.Info("Got new message. Closing existing tunnel!") - dev.Close() - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &wgData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - tdev, err = func() (tun.Device, error) { - if config.FileDescriptorTun != 0 { - return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) - } - var ifName = interfaceName - if runtime.GOOS == "darwin" { // this is if we dont pass a fd - ifName, err = network.FindUnusedUTUN() - if err != nil { - return nil, err - } - } - return tun.CreateTUN(ifName, config.MTU) - }() - - if err != nil { - logger.Error("Failed to create TUN device: %v", err) - return - } - - if config.FileDescriptorTun == 0 { - if realInterfaceName, err2 := tdev.Name(); err2 == nil { - interfaceName = realInterfaceName - } - } - - // Wrap TUN device with packet filter for DNS proxy - middleDev = olmDevice.NewMiddleDevice(tdev) - - wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") - // Use filtered device instead of raw TUN device - dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger)) - - if config.EnableUAPI { - fileUAPI, err := func() (*os.File, error) { - if config.FileDescriptorUAPI != 0 { - fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32) - if err != nil { - return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) - } - return os.NewFile(uintptr(fd), ""), nil - } - return olmDevice.UapiOpen(interfaceName) - }() - if err != nil { - logger.Error("UAPI listen error: %v", err) - os.Exit(1) - return - } - - uapiListener, err = olmDevice.UapiListen(interfaceName, fileUAPI) - if err != nil { - logger.Error("Failed to listen on uapi socket: %v", err) - os.Exit(1) - } - - go func() { - for { - conn, err := uapiListener.Accept() - if err != nil { - - return - } - go dev.IpcHandle(conn) - } - }() - logger.Info("UAPI listener started") - } - - if err = dev.Up(); err != nil { - logger.Error("Failed to bring up WireGuard device: %v", err) - } - - // Extract interface IP (strip CIDR notation if present) - interfaceIP := wgData.TunnelIP - if strings.Contains(interfaceIP, "/") { - interfaceIP = strings.Split(interfaceIP, "/")[0] - } - - // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) - if err != nil { - logger.Error("Failed to create DNS proxy: %v", err) - } - - if err = network.ConfigureInterface(interfaceName, wgData.TunnelIP, config.MTU); err != 
nil { - logger.Error("Failed to configure interface: %v", err) - } - - if network.AddRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet - logger.Error("Failed to add route for utility subnet: %v", err) - } - - // Create peer manager with integrated peer monitoring - peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ - Device: dev, - DNSProxy: dnsProxy, - InterfaceName: interfaceName, - PrivateKey: privateKey, - MiddleDev: middleDev, - LocalIP: interfaceIP, - SharedBind: sharedBind, - WSClient: olm, - APIServer: apiServer, - }) - - for i := range wgData.Sites { - site := wgData.Sites[i] - var siteEndpoint string - // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer - if site.RelayEndpoint != "" { - siteEndpoint = site.RelayEndpoint - } else { - siteEndpoint = site.Endpoint - } - - apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) - - if err := peerManager.AddPeer(site); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - logger.Info("Configured peer %s", site.PublicKey) - } - - peerManager.Start() - - if err := dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime - logger.Error("Failed to start DNS proxy: %v", err) - } - - if config.OverrideDNS { - // Set up DNS override to use our DNS proxy - if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil { - logger.Error("Failed to setup DNS override: %v", err) - return - } - } - - apiServer.SetRegistered(true) - - connected = true - - // Invoke onConnected callback if configured - if globalConfig.OnConnected != nil { - go globalConfig.OnConnected() - } - - logger.Info("WireGuard device created.") - }) - - // Handler for syncing peer configuration - reconciles expected state with actual state - olm.RegisterHandler("olm/sync", func(msg websocket.WSMessage) { - logger.Debug("Received sync message: %v", msg.Data) - - if !connected { - logger.Warn("Not connected, ignoring sync request") - return - } - - if peerManager == nil { - logger.Warn("Peer manager not initialized, ignoring sync request") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling sync data: %v", err) - return - } - - var wgData WgData - if err := json.Unmarshal(jsonData, &wgData); err != nil { - logger.Error("Error unmarshaling sync data: %v", err) - return - } - - // Build a map of expected peers from the incoming data - expectedPeers := make(map[int]peers.SiteConfig) - for _, site := range wgData.Sites { - expectedPeers[site.SiteId] = site - } - - // Get all current peers - currentPeers := peerManager.GetAllPeers() - currentPeerMap := make(map[int]peers.SiteConfig) - for _, peer := range currentPeers { - currentPeerMap[peer.SiteId] = peer - } - - // Find peers to remove (in current but not in expected) - for siteId := range currentPeerMap { - if _, exists := expectedPeers[siteId]; !exists { - logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId) - if err := peerManager.RemovePeer(siteId); err != nil { - logger.Error("Sync: Failed to remove peer %d: %v", siteId, err) - } else { - // Remove any exit nodes associated with this peer from hole punching - if holePunchManager != nil { - removed := holePunchManager.RemoveExitNodesByPeer(siteId) - if removed > 0 { - logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId) - } - } - } - } - } - 
- // Find peers to add (in expected but not in current) and peers to update - for siteId, expectedSite := range expectedPeers { - if _, exists := currentPeerMap[siteId]; !exists { - // New peer - add it using the add flow (with holepunch) - logger.Info("Sync: Adding new peer for site %d", siteId) - - // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it - holePunchManager.TriggerHolePunch() - - // TODO: do we need to send the message to the cloud to add the peer that way? - if err := peerManager.AddPeer(expectedSite); err != nil { - logger.Error("Sync: Failed to add peer %d: %v", siteId, err) - } else { - logger.Info("Sync: Successfully added peer for site %d", siteId) - } - } else { - // Existing peer - check if update is needed - currentSite := currentPeerMap[siteId] - needsUpdate := false - - // Check if any fields have changed - if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint { - needsUpdate = true - } - if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint { - needsUpdate = true - } - if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey { - needsUpdate = true - } - if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP { - needsUpdate = true - } - if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort { - needsUpdate = true - } - // Check remote subnets - if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) { - needsUpdate = true - } - // Check aliases - if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) { - needsUpdate = true - } - - if needsUpdate { - logger.Info("Sync: Updating peer for site %d", siteId) - - // Merge expected data with current data - siteConfig := currentSite - if expectedSite.Endpoint != "" { - siteConfig.Endpoint = expectedSite.Endpoint - } - if expectedSite.RelayEndpoint != "" { - siteConfig.RelayEndpoint = expectedSite.RelayEndpoint - } - if expectedSite.PublicKey != "" { - siteConfig.PublicKey = expectedSite.PublicKey - } - if expectedSite.ServerIP != "" { - siteConfig.ServerIP = expectedSite.ServerIP - } - if expectedSite.ServerPort != 0 { - siteConfig.ServerPort = expectedSite.ServerPort - } - if expectedSite.RemoteSubnets != nil { - siteConfig.RemoteSubnets = expectedSite.RemoteSubnets - } - if expectedSite.Aliases != nil { - siteConfig.Aliases = expectedSite.Aliases - } - - if err := peerManager.UpdatePeer(siteConfig); err != nil { - logger.Error("Sync: Failed to update peer %d: %v", siteId, err) - } else { - // If the endpoint changed, trigger holepunch to refresh NAT mappings - if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint { - logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId) - holePunchManager.TriggerHolePunch() - holePunchManager.ResetInterval() - } - logger.Info("Sync: Successfully updated peer for site %d", siteId) - } - } - } - } - - logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers)) - }) - - olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateData 
peers.SiteConfig - if err := json.Unmarshal(jsonData, &updateData); err != nil { - logger.Error("Error unmarshaling update data: %v", err) - return - } - - // Get existing peer from PeerManager - existingPeer, exists := peerManager.GetPeer(updateData.SiteId) - if !exists { - logger.Warn("Peer with site ID %d not found", updateData.SiteId) - return - } - - // Create updated site config by merging with existing data - siteConfig := existingPeer - - if updateData.Endpoint != "" { - siteConfig.Endpoint = updateData.Endpoint - } - if updateData.RelayEndpoint != "" { - siteConfig.RelayEndpoint = updateData.RelayEndpoint - } - if updateData.PublicKey != "" { - siteConfig.PublicKey = updateData.PublicKey - } - if updateData.ServerIP != "" { - siteConfig.ServerIP = updateData.ServerIP - } - if updateData.ServerPort != 0 { - siteConfig.ServerPort = updateData.ServerPort - } - if updateData.RemoteSubnets != nil { - siteConfig.RemoteSubnets = updateData.RemoteSubnets - } - - if err := peerManager.UpdatePeer(siteConfig); err != nil { - logger.Error("Failed to update peer: %v", err) - return - } - - // If the endpoint changed, trigger holepunch to refresh NAT mappings - if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { - logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) - holePunchManager.TriggerHolePunch() - holePunchManager.ResetInterval() - } - - // Update successful - logger.Info("Successfully updated peer for site %d", updateData.SiteId) - }) - - // Handler for adding a new peer - olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-peer message: %v", msg.Data) - - if stopPeerSend != nil { - stopPeerSend() - stopPeerSend = nil - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var siteConfig peers.SiteConfig - if err := json.Unmarshal(jsonData, &siteConfig); err != nil { - logger.Error("Error unmarshaling add data: %v", err) - return - } - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it - - if err := peerManager.AddPeer(siteConfig); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - - // Add successful - logger.Info("Successfully added peer for site %d", siteConfig.SiteId) - }) - - // Handler for removing a peer - olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-peer message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeData peers.PeerRemove - if err := json.Unmarshal(jsonData, &removeData); err != nil { - logger.Error("Error unmarshaling remove data: %v", err) - return - } - - if err := peerManager.RemovePeer(removeData.SiteId); err != nil { - logger.Error("Failed to remove peer: %v", err) - return - } - - // Remove any exit nodes associated with this peer from hole punching - if holePunchManager != nil { - removed := holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) - if removed > 0 { - logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) - } - } - - // Remove successful - logger.Info("Successfully removed peer for site %d", removeData.SiteId) - }) - - // Handler for adding remote subnets to a peer - 
olm.RegisterHandler("olm/wg/peer/data/add", func(msg websocket.WSMessage) { - logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var addSubnetsData peers.PeerAdd - if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil { - logger.Error("Error unmarshaling add-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(addSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId) - return - } - - // Add new subnets - for _, subnet := range addSubnetsData.RemoteSubnets { - if err := peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases - for _, alias := range addSubnetsData.Aliases { - if err := peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for removing remote subnets from a peer - olm.RegisterHandler("olm/wg/peer/data/remove", func(msg websocket.WSMessage) { - logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var removeSubnetsData peers.RemovePeerData - if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil { - logger.Error("Error unmarshaling remove-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(removeSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId) - return - } - - // Remove subnets - for _, subnet := range removeSubnetsData.RemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Remove aliases - for _, alias := range removeSubnetsData.Aliases { - if err := peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - }) - - // Handler for updating remote subnets of a peer (remove old, add new in one operation) - olm.RegisterHandler("olm/wg/peer/data/update", func(msg websocket.WSMessage) { - logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var updateSubnetsData peers.UpdatePeerData - if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil { - logger.Error("Error unmarshaling update-remote-subnets data: %v", err) - return - } - - if _, exists := peerManager.GetPeer(updateSubnetsData.SiteId); !exists { - logger.Debug("Peer %d not found for removing remote subnets and aliases", updateSubnetsData.SiteId) - return - } - - // Add new subnets BEFORE removing old ones to preserve shared subnets - // This ensures that if an old and new subnet are the same on different peers, - // the route won't be temporarily removed - for _, subnet := range updateSubnetsData.NewRemoteSubnets { - if err := peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to add allowed IP %s: %v", subnet, err) - } - } - - // Remove old subnets 
after new ones are added - for _, subnet := range updateSubnetsData.OldRemoteSubnets { - if err := peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil { - logger.Error("Failed to remove allowed IP %s: %v", subnet, err) - } - } - - // Add new aliases BEFORE removing old ones to preserve shared IP addresses - // This ensures that if an old and new alias share the same IP, the IP won't be - // temporarily removed from the allowed IPs list - for _, alias := range updateSubnetsData.NewAliases { - if err := peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil { - logger.Error("Failed to add alias %s: %v", alias.Alias, err) - } - } - - // Remove old aliases after new ones are added - for _, alias := range updateSubnetsData.OldAliases { - if err := peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil { - logger.Error("Failed to remove alias %s: %v", alias.Alias, err) - } - } - - logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) - }) - - olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { - logger.Debug("Received relay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.RelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true) - - peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort) - }) - - olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) { - logger.Debug("Received unrelay-peer message: %v", msg.Data) - - // Check if peerManager is still valid (may be nil during shutdown) - if peerManager == nil { - logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)") - return - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling data: %v", err) - return - } - - var relayData peers.UnRelayPeerData - if err := json.Unmarshal(jsonData, &relayData); err != nil { - logger.Error("Error unmarshaling relay data: %v", err) - return - } - - primaryRelay, err := util.ResolveDomain(relayData.Endpoint) - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - // Update HTTP server to mark this peer as using relay - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false) - - peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) - }) + // Handlers for managing connection status + olmClient.RegisterHandler("olm/wg/connect", o.handleConnect) + olmClient.RegisterHandler("olm/terminate", o.handleTerminate) + + // Handlers for managing peers + olmClient.RegisterHandler("olm/wg/peer/add", o.handleWgPeerAdd) + olmClient.RegisterHandler("olm/wg/peer/remove", o.handleWgPeerRemove) + olmClient.RegisterHandler("olm/wg/peer/update", o.handleWgPeerUpdate) + olmClient.RegisterHandler("olm/wg/peer/relay", o.handleWgPeerRelay) + 
olmClient.RegisterHandler("olm/wg/peer/unrelay", o.handleWgPeerUnrelay) + + // Handlers for managing remote subnets to a peer + olmClient.RegisterHandler("olm/wg/peer/data/add", o.handleWgPeerAddData) + olmClient.RegisterHandler("olm/wg/peer/data/remove", o.handleWgPeerRemoveData) + olmClient.RegisterHandler("olm/wg/peer/data/update", o.handleWgPeerUpdateData) // Handler for peer handshake - adds exit node to holepunch rotation and notifies server - olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) { - logger.Debug("Received peer-handshake message: %v", msg.Data) + olmClient.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) + olmClient.RegisterHandler("olm/sync", o.handleSync) - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling handshake data: %v", err) - return - } - - var handshakeData struct { - SiteId int `json:"siteId"` - ExitNode struct { - PublicKey string `json:"publicKey"` - Endpoint string `json:"endpoint"` - RelayPort uint16 `json:"relayPort"` - } `json:"exitNode"` - } - - if err := json.Unmarshal(jsonData, &handshakeData); err != nil { - logger.Error("Error unmarshaling handshake data: %v", err) - return - } - - // Get existing peer from PeerManager - _, exists := peerManager.GetPeer(handshakeData.SiteId) - if exists { - logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) - return - } - - relayPort := handshakeData.ExitNode.RelayPort - if relayPort == 0 { - relayPort = 21820 // default relay port - } - - siteId := handshakeData.SiteId - exitNode := holepunch.ExitNode{ - Endpoint: handshakeData.ExitNode.Endpoint, - RelayPort: relayPort, - PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, - } - - added := holePunchManager.AddExitNode(exitNode) - if added { - logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) - } else { - logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) - } - - holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud - - // Send handshake acknowledgment back to server with retry - stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second) - - logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) - }) - - olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() - network.ClearNetworkSettings() - Close() - - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() - } - }) - - olm.RegisterHandler("pong", func(msg websocket.WSMessage) { - logger.Debug("Received pong message") - }) - - olm.OnConnect(func() error { + olmClient.OnConnect(func() error { logger.Info("Websocket Connected") - apiServer.SetConnectionStatus(true) + o.apiServer.SetConnectionStatus(true) - if connected { + // restart the ping if we need to + if o.stopServerPing == nil { + o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{ + "timestamp": time.Now().Unix(), + "userToken": olmClient.GetConfig().UserToken, + }, 30*time.Second, -1) // -1 means dont time out with the max attempts + 
} + + if o.connected { logger.Debug("Already connected, skipping registration") return nil } - publicKey := privateKey.PublicKey() + publicKey := o.privateKey.PublicKey() // delay for 500ms to allow for time for the hp to get processed time.Sleep(500 * time.Millisecond) - if stopRegister == nil { + if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) - stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + o.stopRegister, o.updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]any{ "publicKey": publicKey.String(), "relay": !config.Holepunch, - "olmVersion": globalConfig.Version, - "olmAgent": globalConfig.Agent, + "olmVersion": o.olmConfig.Version, + "olmAgent": o.olmConfig.Agent, "orgId": config.OrgID, "userToken": userToken, - }, 1*time.Second) + }, 1*time.Second, 10) // Invoke onRegistered callback if configured - if globalConfig.OnRegistered != nil { - go globalConfig.OnRegistered() + if o.olmConfig.OnRegistered != nil { + go o.olmConfig.OnRegistered() } } - go keepSendingPing(olm) - return nil }) - olm.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { - holePunchManager.SetToken(token) + olmClient.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { + o.holePunchManager.SetToken(token) logger.Debug("Got exit nodes for hole punching: %v", exitNodes) @@ -1059,114 +395,113 @@ func StartTunnel(config TunnelConfig) { // Start hole punching using the manager logger.Info("Starting hole punch for %d exit nodes", len(exitNodes)) - if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { + if err := o.holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { logger.Warn("Failed to start hole punch: %v", err) } }) - olm.OnAuthError(func(statusCode int, message string) { + olmClient.OnAuthError(func(statusCode int, message string) { logger.Error("Authentication error (status %d): %s. 
Terminating tunnel.", statusCode, message) - apiServer.SetTerminated(true) - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) - apiServer.ClearPeerStatuses() + o.apiServer.SetTerminated(true) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) + o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() - Close() + o.Close() - if globalConfig.OnAuthError != nil { - go globalConfig.OnAuthError(statusCode, message) + if o.olmConfig.OnAuthError != nil { + go o.olmConfig.OnAuthError(statusCode, message) } - if globalConfig.OnTerminated != nil { - go globalConfig.OnTerminated() + if o.olmConfig.OnTerminated != nil { + go o.olmConfig.OnTerminated() } }) // Connect to the WebSocket server - if err := olm.Connect(); err != nil { + if err := olmClient.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) return } - defer olm.Close() + defer func() { _ = olmClient.Close() }() + + o.websocket = olmClient // Wait for context cancellation <-tunnelCtx.Done() logger.Info("Tunnel process context cancelled, cleaning up") } -func Close() { +func (o *Olm) Close() { // Restore original DNS configuration // we do this first to avoid any DNS issues if something else gets stuck if err := dnsOverride.RestoreDNSOverride(); err != nil { logger.Error("Failed to restore DNS: %v", err) } - // Stop hole punch manager - if holePunchManager != nil { - holePunchManager.Stop() - holePunchManager = nil + if o.holePunchManager != nil { + o.holePunchManager.Stop() + o.holePunchManager = nil } - if stopPing != nil { - select { - case <-stopPing: - // Channel already closed - default: - close(stopPing) - } + if o.stopServerPing != nil { + o.stopServerPing() + o.stopServerPing = nil } - if stopRegister != nil { - stopRegister() - stopRegister = nil + if o.stopRegister != nil { + o.stopRegister() + o.stopRegister = nil } - if updateRegister != nil { - updateRegister = nil + // Close() also calls Stop() internally + if o.peerManager != nil { + o.peerManager.Close() + o.peerManager = nil } - if peerManager != nil { - peerManager.Close() // Close() also calls Stop() internally - peerManager = nil + if o.uapiListener != nil { + _ = o.uapiListener.Close() + o.uapiListener = nil } - if uapiListener != nil { - uapiListener.Close() - uapiListener = nil + if o.logFile != nil { + _ = o.logFile.Close() + o.logFile = nil } // Stop DNS proxy first - it uses the middleDev for packet filtering - logger.Debug("Stopping DNS proxy") - if dnsProxy != nil { - dnsProxy.Stop() - dnsProxy = nil + if o.dnsProxy != nil { + logger.Debug("Stopping DNS proxy") + o.dnsProxy.Stop() + o.dnsProxy = nil } // Close MiddleDevice first - this closes the TUN and signals the closed channel // This unblocks the pump goroutine and allows WireGuard's TUN reader to exit - logger.Debug("Closing MiddleDevice") - if middleDev != nil { - middleDev.Close() - middleDev = nil + // Note: o.tdev is closed by o.middleDev.Close() since middleDev wraps it + if o.middleDev != nil { + logger.Debug("Closing MiddleDevice") + _ = o.middleDev.Close() + o.middleDev = nil } - // Note: tdev is closed by middleDev.Close() since middleDev wraps it - tdev = nil // Now close WireGuard device - its TUN reader should have exited by now - logger.Debug("Closing WireGuard device") - if dev != nil { - dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference - dev = nil + // This will call sharedBind.Close() which releases WireGuard's reference + if o.dev != nil { + logger.Debug("Closing WireGuard 
device") + o.dev.Close() + o.dev = nil } - // Release the hole punch reference to the shared bind - if sharedBind != nil { - // Release hole punch reference (WireGuard already released its reference via dev.Close()) - logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount()) - sharedBind.Release() - sharedBind = nil + // Release the hole punch reference to the shared bind (WireGuard already + // released its reference via dev.Close()) + if o.sharedBind != nil { + logger.Debug("Releasing shared bind (refcount before release: %d)", o.sharedBind.GetRefCount()) + _ = o.sharedBind.Release() logger.Info("Released shared UDP bind") + o.sharedBind = nil } logger.Info("Olm service stopped") @@ -1174,78 +509,239 @@ func Close() { // StopTunnel stops just the tunnel process and websocket connection // without shutting down the entire application -func StopTunnel() error { +func (o *Olm) StopTunnel() error { logger.Info("Stopping tunnel process") + if !o.tunnelRunning { + logger.Debug("Tunnel not running, nothing to stop") + return nil + } + // Cancel the tunnel context if it exists - if tunnelCancel != nil { - tunnelCancel() + if o.tunnelCancel != nil { + o.tunnelCancel() // Give it a moment to clean up time.Sleep(200 * time.Millisecond) } // Close the websocket connection - if olmClient != nil { - olmClient.Close() - olmClient = nil + if o.websocket != nil { + _ = o.websocket.Close() + o.websocket = nil } - Close() + o.Close() // Reset the connected state - connected = false - tunnelRunning = false + o.connected = false + o.tunnelRunning = false // Update API server status - apiServer.SetConnectionStatus(false) - apiServer.SetRegistered(false) + o.apiServer.SetConnectionStatus(false) + o.apiServer.SetRegistered(false) network.ClearNetworkSettings() - apiServer.ClearPeerStatuses() + o.apiServer.ClearPeerStatuses() logger.Info("Tunnel process stopped") return nil } -func StopApi() error { - if apiServer != nil { - err := apiServer.Stop() +func (o *Olm) StopApi() error { + if o.apiServer != nil { + err := o.apiServer.Stop() if err != nil { return fmt.Errorf("failed to stop API server: %w", err) } } + return nil } -func StartApi() error { - if apiServer != nil { - err := apiServer.Start() +func (o *Olm) StartApi() error { + if o.apiServer != nil { + err := o.apiServer.Start() if err != nil { return fmt.Errorf("failed to start API server: %w", err) } } + return nil } -func GetStatus() api.StatusResponse { - return apiServer.GetStatus() +func (o *Olm) GetStatus() api.StatusResponse { + return o.apiServer.GetStatus() } -func SwitchOrg(orgID string) error { +func (o *Olm) SwitchOrg(orgID string) error { logger.Info("Processing org switch request to orgId: %s", orgID) // stop the tunnel - if err := StopTunnel(); err != nil { + if err := o.StopTunnel(); err != nil { return fmt.Errorf("failed to stop existing tunnel: %w", err) } // Update the org ID in the API server and global config - apiServer.SetOrgID(orgID) + o.apiServer.SetOrgID(orgID) - tunnelConfig.OrgID = orgID + o.tunnelConfig.OrgID = orgID // Restart the tunnel with the same config but new org ID - go StartTunnel(tunnelConfig) + go o.StartTunnel(o.tunnelConfig) return nil } + +// SetPowerMode switches between normal and low power modes +// In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes +// In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored +// Wake-up has a 3-second debounce to prevent rapid flip-flopping; 
sleep is immediate +func (o *Olm) SetPowerMode(mode string) error { + // Validate mode + if mode != "normal" && mode != "low" { + return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode) + } + + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + + // If already in the requested mode, return early + if o.currentPowerMode == mode { + // Cancel any pending wake-up timer if we're already in normal mode + if mode == "normal" && o.wakeUpTimer != nil { + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } + logger.Debug("Already in %s power mode", mode) + return nil + } + + if mode == "low" { + // Low Power Mode: Cancel any pending wake-up and immediately go to sleep + + // Cancel pending wake-up timer if any + if o.wakeUpTimer != nil { + logger.Debug("Cancelling pending wake-up timer") + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } + + logger.Info("Switching to low power mode") + + if o.websocket != nil { + logger.Info("Disconnecting websocket for low power mode") + if err := o.websocket.Disconnect(); err != nil { + logger.Error("Error disconnecting websocket: %v", err) + } + } + + lowPowerInterval := 10 * time.Minute + + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval) + peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval) + logger.Info("Set monitoring intervals to 10 minutes for low power mode") + } + o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable + } + + if o.holePunchManager != nil { + o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval) + } + + o.currentPowerMode = "low" + logger.Info("Switched to low power mode") + + } else { + // Normal Power Mode: Start debounce timer before actually waking up + + // If there's already a pending wake-up timer, don't start another + if o.wakeUpTimer != nil { + logger.Debug("Wake-up already pending, ignoring duplicate request") + return nil + } + + logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce) + + o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() { + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + + // Clear the timer reference + o.wakeUpTimer = nil + + // Double-check we're still in low power mode (could have changed) + if o.currentPowerMode == "normal" { + logger.Debug("Already in normal mode after debounce, skipping wake-up") + return + } + + logger.Info("Debounce complete, switching to normal power mode") + + logger.Info("Reconnecting websocket for normal power mode") + if o.websocket != nil { + if err := o.websocket.Connect(); err != nil { + logger.Error("Failed to reconnect websocket: %v", err) + return + } + } + + // Restore intervals and reconnect websocket + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.ResetPeerHolepunchInterval() + peerMonitor.ResetPeerInterval() + } + + o.peerManager.UpdateAllPeersPersistentKeepalive(5) + } + + if o.holePunchManager != nil { + o.holePunchManager.ResetServerHolepunchInterval() + } + + o.currentPowerMode = "normal" + logger.Info("Switched to normal power mode") + }) + } + + return nil +} + +func (o *Olm) AddDevice(fd uint32) error { + if o.middleDev == nil { + return fmt.Errorf("middle device is not initialized") + } + + if o.tunnelConfig.MTU == 0 { + return fmt.Errorf("tunnel MTU is not set") + } + + tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU) + if err != nil { + return 
fmt.Errorf("failed to create TUN device from fd: %v", err) + } + + // Update interface name if available + if realInterfaceName, err2 := tdev.Name(); err2 == nil { + o.tunnelConfig.InterfaceName = realInterfaceName + } + + // Replace the existing TUN device in the middle device with the new one + o.middleDev.AddDevice(tdev) + + logger.Info("Added device from file descriptor %d", fd) + + return nil +} + +func GetNetworkSettingsJSON() (string, error) { + return network.GetJSON() +} + +func GetNetworkSettingsIncrementor() int { + return network.GetIncrementor() +} diff --git a/olm/peer.go b/olm/peer.go new file mode 100644 index 0000000..9bc842e --- /dev/null +++ b/olm/peer.go @@ -0,0 +1,195 @@ +package olm + +import ( + "encoding/json" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/peers" + "github.com/fosrl/olm/websocket" +) + +func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { + logger.Debug("Received add-peer message: %v", msg.Data) + + if o.stopPeerSend != nil { + o.stopPeerSend() + o.stopPeerSend = nil + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var siteConfig peers.SiteConfig + if err := json.Unmarshal(jsonData, &siteConfig); err != nil { + logger.Error("Error unmarshaling add data: %v", err) + return + } + + _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it + + if err := o.peerManager.AddPeer(siteConfig); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + logger.Info("Successfully added peer for site %d", siteConfig.SiteId) +} + +func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { + logger.Debug("Received remove-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeData peers.PeerRemove + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return + } + + if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil { + logger.Error("Failed to remove peer: %v", err) + return + } + + // Remove any exit nodes associated with this peer from hole punching + if o.holePunchManager != nil { + removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId) + if removed > 0 { + logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId) + } + } + + logger.Info("Successfully removed peer for site %d", removeData.SiteId) +} + +func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) { + logger.Debug("Received update-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var updateData peers.SiteConfig + if err := json.Unmarshal(jsonData, &updateData); err != nil { + logger.Error("Error unmarshaling update data: %v", err) + return + } + + // Get existing peer from PeerManager + existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId) + if !exists { + logger.Warn("Peer with site ID %d not found", updateData.SiteId) + return + } + + // Create updated site config by merging with existing data + siteConfig := existingPeer + + if updateData.Endpoint != "" { + siteConfig.Endpoint = updateData.Endpoint + } + if updateData.RelayEndpoint != "" { + siteConfig.RelayEndpoint 
= updateData.RelayEndpoint
+	}
+	if updateData.PublicKey != "" {
+		siteConfig.PublicKey = updateData.PublicKey
+	}
+	if updateData.ServerIP != "" {
+		siteConfig.ServerIP = updateData.ServerIP
+	}
+	if updateData.ServerPort != 0 {
+		siteConfig.ServerPort = updateData.ServerPort
+	}
+	if updateData.RemoteSubnets != nil {
+		siteConfig.RemoteSubnets = updateData.RemoteSubnets
+	}
+
+	if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
+		logger.Error("Failed to update peer: %v", err)
+		return
+	}
+
+	// If the endpoint changed, trigger holepunch to refresh NAT mappings
+	if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
+		logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
+		_ = o.holePunchManager.TriggerHolePunch()
+		o.holePunchManager.ResetServerHolepunchInterval()
+	}
+
+	logger.Info("Successfully updated peer for site %d", updateData.SiteId)
+}
+
+func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
+	logger.Debug("Received relay-peer message: %v", msg.Data)
+
+	// Check if peerManager is still valid (may be nil during shutdown)
+	if o.peerManager == nil {
+		logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
+		return
+	}
+
+	jsonData, err := json.Marshal(msg.Data)
+	if err != nil {
+		logger.Error("Error marshaling data: %v", err)
+		return
+	}
+
+	var relayData peers.RelayPeerData
+	if err := json.Unmarshal(jsonData, &relayData); err != nil {
+		logger.Error("Error unmarshaling relay data: %v", err)
+		return
+	}
+
+	primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
+	if err != nil {
+		logger.Error("Failed to resolve primary relay endpoint: %v", err)
+		return
+	}
+
+	// Update HTTP server to mark this peer as using relay
+	o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
+
+	o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
+}
+
+func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
+	logger.Debug("Received unrelay-peer message: %v", msg.Data)
+
+	// Check if peerManager is still valid (may be nil during shutdown)
+	if o.peerManager == nil {
+		logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
+		return
+	}
+
+	jsonData, err := json.Marshal(msg.Data)
+	if err != nil {
+		logger.Error("Error marshaling data: %v", err)
+		return
+	}
+
+	var relayData peers.UnRelayPeerData
+	if err := json.Unmarshal(jsonData, &relayData); err != nil {
+		logger.Error("Error unmarshaling relay data: %v", err)
+		return
+	}
+
+	primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
+	if err != nil {
+		logger.Warn("Failed to resolve primary relay endpoint: %v", err)
+	}
+
+	// Update HTTP server to mark this peer as no longer using relay
+	o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
+
+	o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
+}
diff --git a/olm/types.go b/olm/types.go
index b7153af..397eab9 100644
--- a/olm/types.go
+++ b/olm/types.go
@@ -12,9 +12,10 @@ type WgData struct {
 	UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
 }
 
-type GlobalConfig struct {
+type OlmConfig struct {
 	// Logging
-	LogLevel string
+	LogLevel    string
+	LogFilePath string
 
 	// HTTP server
 	EnableAPI  bool
@@ -22,6 +23,11 @@ type GlobalConfig struct {
 	SocketPath string
 	Version    string
 	Agent      string
+
+	WakeUpDebounce time.Duration
+
+	// Debugging
+	PprofAddr string // Address to serve pprof on (e.g., 
"localhost:6060") // Callbacks OnRegistered func() diff --git a/olm/util.go b/olm/util.go index d138755..73572dc 100644 --- a/olm/util.go +++ b/olm/util.go @@ -1,60 +1,9 @@ package olm import ( - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/network" "github.com/fosrl/olm/peers" - "github.com/fosrl/olm/websocket" ) -func sendPing(olm *websocket.Client) error { - err := olm.SendMessage("olm/ping", map[string]interface{}{ - "timestamp": time.Now().Unix(), - "userToken": olm.GetConfig().UserToken, - }) - if err != nil { - logger.Error("Failed to send ping message: %v", err) - return err - } - logger.Debug("Sent ping message") - return nil -} - -func keepSendingPing(olm *websocket.Client) { - // Send ping immediately on startup - if err := sendPing(olm); err != nil { - logger.Error("Failed to send initial ping: %v", err) - } else { - logger.Info("Sent initial ping message") - } - - // Set up ticker for one minute intervals - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-stopPing: - logger.Info("Stopping ping messages") - return - case <-ticker.C: - if err := sendPing(olm); err != nil { - logger.Error("Failed to send periodic ping: %v", err) - } - } - } -} - -func GetNetworkSettingsJSON() (string, error) { - return network.GetJSON() -} - -func GetNetworkSettingsIncrementor() int { - return network.GetIncrementor() -} - // slicesEqual compares two string slices for equality (order-independent) func slicesEqual(a, b []string) bool { if len(a) != len(b) { diff --git a/peers/manager.go b/peers/manager.go index af781e5..0566775 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -50,6 +50,8 @@ type PeerManager struct { // key is the CIDR string, value is a set of siteIds that want this IP allowedIPClaims map[string]map[int]bool APIServer *api.API + + PersistentKeepalive int } // NewPeerManager creates a new PeerManager with an internal PeerMonitor @@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) { return peer, ok } +// GetPeerMonitor returns the internal peer monitor instance +func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor { + pm.mu.RLock() + defer pm.mu.RUnlock() + return pm.peerMonitor +} + func (pm *PeerManager) GetAllPeers() []SiteConfig { pm.mu.RLock() defer pm.mu.RUnlock() @@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { return nil } +// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once +// without recreating them. Returns a map of siteId to error for any peers that failed to update. 
+func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error { + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.PersistentKeepalive = interval + + errors := make(map[int]error) + + for siteId, peer := range pm.peers { + err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval) + if err != nil { + errors[siteId] = err + } + } + + if len(errors) == 0 { + return nil + } + return errors +} + func (pm *PeerManager) RemovePeer(siteId int) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error { ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId) wgConfig := promotedPeer wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } @@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -324,7 +356,7 @@ promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId) promotedWgConfig := promotedPeer promotedWgConfig.AllowedIps = promotedOwnedIPs - if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 27bc408..387b82f 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -31,8 +31,7 @@ type PeerMonitor struct { monitors map[int]*Client mutex sync.Mutex running bool - interval time.Duration - timeout time.Duration + timeout time.Duration maxAttempts int wsClient *websocket.Client @@ -42,7 +41,7 @@ type PeerMonitor struct { stack *stack.Stack ep *channel.Endpoint activePorts map[uint16]bool - portsLock sync.Mutex + portsLock sync.RWMutex nsCtx context.Context nsCancel context.CancelFunc nsWg sync.WaitGroup @@ -50,17 +49,26 @@ // Holepunch testing fields sharedBind *bind.SharedBind holepunchTester *holepunch.HolepunchTester - holepunchInterval time.Duration holepunchTimeout time.Duration holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchStatus map[int]bool // siteID -> connected status - holepunchStopChan chan struct{} + holepunchStopChan chan struct{} + holepunchUpdateChan chan struct{} // Relay tracking fields relayedPeers map[int]bool // siteID -> whether the peer is currently relayed holepunchMaxAttempts int // max consecutive failures before triggering relay holepunchFailures map[int]int // siteID -> consecutive failure count + // Exponential backoff fields for holepunch monitor + defaultHolepunchMinInterval time.Duration // Minimum interval (initial) + defaultHolepunchMaxInterval time.Duration + holepunchMinInterval 
time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied + // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts rapidTestTimeout time.Duration // timeout for each rapid test attempt @@ -78,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), - interval: 2 * time.Second, // Default check interval (faster) timeout: 3 * time.Second, maxAttempts: 3, wsClient: wsClient, @@ -88,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe nsCtx: ctx, nsCancel: cancel, sharedBind: sharedBind, - holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds holepunchTimeout: 2 * time.Second, // Faster timeout holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), @@ -101,6 +107,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total apiServer: apiServer, wgConnectionStatus: make(map[int]bool), + // Exponential backoff settings for holepunch monitor + defaultHolepunchMinInterval: 2 * time.Second, + defaultHolepunchMaxInterval: 30 * time.Second, + holepunchMinInterval: 2 * time.Second, + holepunchMaxInterval: 30 * time.Second, + holepunchBackoffMultiplier: 1.5, + holepunchStableCount: make(map[int]int), + holepunchCurrentInterval: 2 * time.Second, + holepunchUpdateChan: make(chan struct{}, 1), } if err := pm.initNetstack(); err != nil { @@ -116,41 +131,75 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe } // SetInterval changes how frequently peers are checked -func (pm *PeerMonitor) SetInterval(interval time.Duration) { +func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.interval = interval - // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetPacketInterval(interval) + client.SetPacketInterval(minInterval, maxInterval) } + + logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval) } -// SetTimeout changes the timeout for waiting for responses -func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { +func (pm *PeerMonitor) ResetPeerInterval() { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.timeout = timeout - - // Update timeout for all existing monitors + // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetTimeout(timeout) + client.ResetPacketInterval() } } -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (pm *PeerMonitor) SetMaxAttempts(attempts int) { +// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) { + pm.mutex.Lock() + pm.holepunchMinInterval = minInterval + pm.holepunchMaxInterval = maxInterval + // Reset current interval to the new minimum + pm.holepunchCurrentInterval = minInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Set 
holepunch interval to min: %s, max: %s", minInterval, maxInterval) + + // Signal the goroutine to apply the new interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } +} + +// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.maxAttempts = attempts + return pm.holepunchMinInterval, pm.holepunchMaxInterval +} - // Update max attempts for all existing monitors - for _, client := range pm.monitors { - client.SetMaxAttempts(attempts) +func (pm *PeerMonitor) ResetPeerHolepunchInterval() { + pm.mutex.Lock() + pm.holepunchMinInterval = pm.defaultHolepunchMinInterval + pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval + pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval) + + // Signal the goroutine to apply the new interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } } } @@ -169,10 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st return err } - client.SetPacketInterval(pm.interval) - client.SetTimeout(pm.timeout) - client.SetMaxAttempts(pm.maxAttempts) - pm.monitors[siteID] = client pm.holepunchEndpoints[siteID] = holepunchEndpoint @@ -470,31 +515,59 @@ func (pm *PeerMonitor) stopHolepunchMonitor() { logger.Info("Stopped holepunch connection monitor") } -// runHolepunchMonitor runs the holepunch monitoring loop +// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff func (pm *PeerMonitor) runHolepunchMonitor() { - ticker := time.NewTicker(pm.holepunchInterval) - defer ticker.Stop() + pm.mutex.Lock() + pm.holepunchCurrentInterval = pm.holepunchMinInterval + pm.mutex.Unlock() - // Do initial check immediately - pm.checkHolepunchEndpoints() + timer := time.NewTimer(0) // Fire immediately for initial check + defer timer.Stop() for { select { case <-pm.holepunchStopChan: return - case <-ticker.C: - pm.checkHolepunchEndpoints() + case <-pm.holepunchUpdateChan: + // Interval settings changed, reset to minimum + pm.mutex.Lock() + pm.holepunchCurrentInterval = pm.holepunchMinInterval + currentInterval := pm.holepunchCurrentInterval + pm.mutex.Unlock() + + timer.Reset(currentInterval) + logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) + case <-timer.C: + anyStatusChanged := pm.checkHolepunchEndpoints() + + pm.mutex.Lock() + if anyStatusChanged { + // Reset to minimum interval on any status change + pm.holepunchCurrentInterval = pm.holepunchMinInterval + } else { + // Apply exponential backoff when stable + newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier) + if newInterval > pm.holepunchMaxInterval { + newInterval = pm.holepunchMaxInterval + } + pm.holepunchCurrentInterval = newInterval + } + currentInterval := pm.holepunchCurrentInterval + pm.mutex.Unlock() + + timer.Reset(currentInterval) } } } // checkHolepunchEndpoints tests all holepunch endpoints -func (pm *PeerMonitor) checkHolepunchEndpoints() { +// Returns true if any endpoint's status changed +func (pm *PeerMonitor) 
checkHolepunchEndpoints() bool { pm.mutex.Lock() // Check if we're still running before doing any work if !pm.running { pm.mutex.Unlock() - return + return false } endpoints := make(map[int]string, len(pm.holepunchEndpoints)) for siteID, endpoint := range pm.holepunchEndpoints { @@ -504,8 +577,10 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { maxAttempts := pm.holepunchMaxAttempts pm.mutex.Unlock() + anyStatusChanged := false + for siteID, endpoint := range endpoints { - // logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) + logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint) result := pm.holepunchTester.TestEndpoint(endpoint, timeout) pm.mutex.Lock() @@ -529,7 +604,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { pm.mutex.Unlock() // Log status changes - if !exists || previousStatus != result.Success { + statusChanged := !exists || previousStatus != result.Success + if statusChanged { + anyStatusChanged = true if result.Success { logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT) } else { @@ -562,7 +639,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { pm.mutex.Unlock() if !stillRunning { - return // Stop processing if shutdown is in progress + return anyStatusChanged // Stop processing if shutdown is in progress } if !result.Success && !isRelayed && failureCount >= maxAttempts { @@ -579,6 +656,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() { } } } + + return anyStatusChanged } // GetHolepunchStatus returns the current holepunch status for all endpoints @@ -650,55 +729,55 @@ func (pm *PeerMonitor) Close() { logger.Debug("PeerMonitor: Cleanup complete") } -// TestPeer tests connectivity to a specific peer -func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { - pm.mutex.Lock() - client, exists := pm.monitors[siteID] - pm.mutex.Unlock() +// // TestPeer tests connectivity to a specific peer +// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { +// pm.mutex.Lock() +// client, exists := pm.monitors[siteID] +// pm.mutex.Unlock() - if !exists { - return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) - } +// if !exists { +// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) +// } - ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - defer cancel() +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// defer cancel() - connected, rtt := client.TestConnection(ctx) - return connected, rtt, nil -} +// connected, rtt := client.TestPeerConnection(ctx) +// return connected, rtt, nil +// } -// TestAllPeers tests connectivity to all peers -func (pm *PeerMonitor) TestAllPeers() map[int]struct { - Connected bool - RTT time.Duration -} { - pm.mutex.Lock() - peers := make(map[int]*Client, len(pm.monitors)) - for siteID, client := range pm.monitors { - peers[siteID] = client - } - pm.mutex.Unlock() +// // TestAllPeers tests connectivity to all peers +// func (pm *PeerMonitor) TestAllPeers() map[int]struct { +// Connected bool +// RTT time.Duration +// } { +// pm.mutex.Lock() +// peers := make(map[int]*Client, len(pm.monitors)) +// for siteID, client := range pm.monitors { +// peers[siteID] = client +// } +// pm.mutex.Unlock() - results := make(map[int]struct { - Connected bool - RTT time.Duration - }) - for siteID, client := range peers { - ctx, cancel := context.WithTimeout(context.Background(), 
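The monitor loop above grows its wait while holepunch results stay stable and snaps back to the minimum on any status change. A minimal standalone sketch of that schedule, using the defaults introduced in this patch (2s floor, 30s cap, 1.5x multiplier); the helper name is illustrative, not part of the code:

```go
package main

import (
	"fmt"
	"time"
)

// nextInterval mirrors the backoff rule: grow the current wait by the
// multiplier while the holepunch status stays stable, never past the cap.
// Any status change resets the caller back to the minimum interval.
func nextInterval(current, ceiling time.Duration, multiplier float64) time.Duration {
	next := time.Duration(float64(current) * multiplier)
	if next > ceiling {
		next = ceiling
	}
	return next
}

func main() {
	interval, ceiling := 2*time.Second, 30*time.Second // patch defaults
	for i := 0; i < 8; i++ {
		fmt.Println(interval) // 2s, 3s, 4.5s, 6.75s, ... then pinned at 30s
		interval = nextInterval(interval, ceiling, 1.5)
	}
}
```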
pm.timeout*time.Duration(pm.maxAttempts)) - connected, rtt := client.TestConnection(ctx) - cancel() +// results := make(map[int]struct { +// Connected bool +// RTT time.Duration +// }) +// for siteID, client := range peers { +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// connected, rtt := client.TestPeerConnection(ctx) +// cancel() - results[siteID] = struct { - Connected bool - RTT time.Duration - }{ - Connected: connected, - RTT: rtt, - } - } +// results[siteID] = struct { +// Connected bool +// RTT time.Duration +// }{ +// Connected: connected, +// RTT: rtt, +// } +// } - return results -} +// return results +// } // initNetstack initializes the gvisor netstack func (pm *PeerMonitor) initNetstack() error { @@ -770,9 +849,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool { } // Check if we are listening on this port - pm.portsLock.Lock() + pm.portsLock.RLock() active := pm.activePorts[uint16(port)] - pm.portsLock.Unlock() + pm.portsLock.RUnlock() if !active { return false @@ -803,13 +882,12 @@ func (pm *PeerMonitor) runPacketSender() { defer pm.nsWg.Done() logger.Debug("PeerMonitor: Packet sender goroutine started") - // Use a ticker to periodically check for packets without blocking indefinitely - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-pm.nsCtx.Done(): + // Use blocking ReadContext instead of polling - much more CPU efficient + // This will block until a packet is available or context is cancelled + pkt := pm.ep.ReadContext(pm.nsCtx) + if pkt == nil { + // Context was cancelled or endpoint closed logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets") // Drain any remaining packets before exiting for { @@ -821,36 +899,28 @@ func (pm *PeerMonitor) runPacketSender() { } logger.Debug("PeerMonitor: Packet sender goroutine exiting") return - case <-ticker.C: - // Try to read packets in batches - for i := 0; i < 10; i++ { - pkt := pm.ep.Read() - if pkt == nil { - break - } - - // Extract packet data - slices := pkt.AsSlices() - if len(slices) > 0 { - var totalSize int - for _, slice := range slices { - totalSize += len(slice) - } - - buf := make([]byte, totalSize) - pos := 0 - for _, slice := range slices { - copy(buf[pos:], slice) - pos += len(slice) - } - - // Inject into MiddleDevice (outbound to WG) - pm.middleDev.InjectOutbound(buf) - } - - pkt.DecRef() - } } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + pm.middleDev.InjectOutbound(buf) + } + + pkt.DecRef() } } diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go index dac2008..f06759a 100644 --- a/peers/monitor/wgtester.go +++ b/peers/monitor/wgtester.go @@ -32,10 +32,19 @@ type Client struct { monitorLock sync.Mutex connLock sync.Mutex // Protects connection operations shutdownCh chan struct{} + updateCh chan struct{} packetInterval time.Duration timeout time.Duration maxAttempts int dialer Dialer + + // Exponential backoff fields + defaultMinInterval time.Duration // Default minimum interval (initial) + defaultMaxInterval time.Duration // Default maximum interval (cap for backoff) + minInterval time.Duration // Minimum interval (initial) + maxInterval time.Duration // Maximum 
interval (cap for backoff) + backoffMultiplier float64 // Multiplier for each stable check + stableCountToBackoff int // Number of stable checks before backing off } // Dialer is a function that creates a connection @@ -50,28 +59,59 @@ type ConnectionStatus struct { // NewClient creates a new connection test client func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ - serverAddr: serverAddr, - shutdownCh: make(chan struct{}), - packetInterval: 2 * time.Second, - timeout: 500 * time.Millisecond, // Timeout for individual packets - maxAttempts: 3, // Default max attempts - dialer: dialer, + serverAddr: serverAddr, + shutdownCh: make(chan struct{}), + updateCh: make(chan struct{}, 1), + packetInterval: 2 * time.Second, + defaultMinInterval: 2 * time.Second, + defaultMaxInterval: 30 * time.Second, + minInterval: 2 * time.Second, + maxInterval: 30 * time.Second, + backoffMultiplier: 1.5, + stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off + timeout: 500 * time.Millisecond, // Timeout for individual packets + maxAttempts: 3, // Default max attempts + dialer: dialer, }, nil } // SetPacketInterval changes how frequently packets are sent in monitor mode -func (c *Client) SetPacketInterval(interval time.Duration) { - c.packetInterval = interval +func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) { + c.monitorLock.Lock() + c.packetInterval = minInterval + c.minInterval = minInterval + c.maxInterval = maxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() + + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } -// SetTimeout changes the timeout for waiting for responses -func (c *Client) SetTimeout(timeout time.Duration) { - c.timeout = timeout -} +func (c *Client) ResetPacketInterval() { + c.monitorLock.Lock() + c.packetInterval = c.defaultMinInterval + c.minInterval = c.defaultMinInterval + c.maxInterval = c.defaultMaxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (c *Client) SetMaxAttempts(attempts int) { - c.maxAttempts = attempts + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } // UpdateServerAddr updates the server address and resets the connection @@ -125,9 +165,10 @@ func (c *Client) ensureConnection() error { return nil } -// TestConnection checks if the connection to the server is working +// TestPeerConnection checks if the connection to the server is working // Returns true if connected, false otherwise -func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { +func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) { + logger.Debug("wgtester: testing connection to peer %s", c.serverAddr) if err := c.ensureConnection(); err != nil { logger.Warn("Failed to ensure connection: %v", err) return false, 0 @@ -138,6 +179,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { binary.BigEndian.PutUint32(packet[0:4], magicHeader) packet[4] = packetTypeRequest + // Reusable response buffer + responseBuffer := make([]byte, packetSize) + // Send multiple attempts as specified 
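SetPacketInterval and ResetPacketInterval above nudge the running monitor through a one-slot buffered channel with a non-blocking send, so callers never block and an already pending signal is simply reused. A generic sketch of that pattern (the names and timings here are illustrative, not from the patch):

```go
package main

import (
	"fmt"
	"time"
)

// notify performs the non-blocking send used throughout this patch: if the
// worker has not consumed the previous signal yet, this update is dropped
// because the pending one already covers it.
func notify(ch chan struct{}) {
	select {
	case ch <- struct{}{}:
	default:
	}
}

func main() {
	update := make(chan struct{}, 1)
	done := make(chan struct{})

	go func() {
		timer := time.NewTimer(2 * time.Second)
		defer timer.Stop()
		for {
			select {
			case <-update:
				// Settings changed: restart the wait with the new interval.
				timer.Reset(500 * time.Millisecond)
				fmt.Println("interval updated")
			case <-timer.C:
				fmt.Println("tick")
				close(done)
				return
			}
		}
	}()

	notify(update) // apply a shorter interval right away, without blocking
	<-done
}
```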
for attempt := 0; attempt < c.maxAttempts; attempt++ { select { @@ -157,20 +201,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { return false, 0 } - // logger.Debug("Attempting to send monitor packet to %s", c.serverAddr) _, err := c.conn.Write(packet) if err != nil { c.connLock.Unlock() logger.Info("Error sending packet: %v", err) continue } - // logger.Debug("Successfully sent monitor packet") // Set read deadline c.conn.SetReadDeadline(time.Now().Add(c.timeout)) // Wait for response - responseBuffer := make([]byte, packetSize) n, err := c.conn.Read(responseBuffer) c.connLock.Unlock() @@ -211,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - return c.TestConnection(ctx) + return c.TestPeerConnection(ctx) } // MonitorCallback is the function type for connection status change callbacks @@ -238,28 +279,61 @@ func (c *Client) StartMonitor(callback MonitorCallback) error { go func() { var lastConnected bool firstRun := true + stableCount := 0 + currentInterval := c.minInterval - ticker := time.NewTicker(c.packetInterval) - defer ticker.Stop() + timer := time.NewTimer(currentInterval) + defer timer.Stop() for { select { case <-c.shutdownCh: return - case <-ticker.C: + case <-c.updateCh: + // Interval settings changed, reset to minimum + c.monitorLock.Lock() + currentInterval = c.minInterval + c.monitorLock.Unlock() + + // Reset backoff state + stableCount = 0 + + timer.Reset(currentInterval) + logger.Debug("Packet interval updated, reset to %v", currentInterval) + case <-timer.C: ctx, cancel := context.WithTimeout(context.Background(), c.timeout) - connected, rtt := c.TestConnection(ctx) + connected, rtt := c.TestPeerConnection(ctx) cancel() + statusChanged := connected != lastConnected + // Callback if status changed or it's the first check - if connected != lastConnected || firstRun { + if statusChanged || firstRun { callback(ConnectionStatus{ Connected: connected, RTT: rtt, }) lastConnected = connected firstRun = false + // Reset backoff on status change + stableCount = 0 + currentInterval = c.minInterval + } else { + // Status is stable, increment counter + stableCount++ + + // Apply exponential backoff after stable threshold + if stableCount >= c.stableCountToBackoff { + newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier) + if newInterval > c.maxInterval { + newInterval = c.maxInterval + } + currentInterval = newInterval + } } + + // Reset timer with current interval + timer.Reset(currentInterval) } } }() diff --git a/peers/peer.go b/peers/peer.go index 9370b9d..8211fa4 100644 --- a/peers/peer.go +++ b/peers/peer.go @@ -11,7 +11,7 @@ import ( ) // ConfigurePeer sets up or updates a peer within the WireGuard device -func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error { +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error { var endpoint string if relay && siteConfig.RelayEndpoint != "" { endpoint = formatEndpoint(siteConfig.RelayEndpoint) @@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes } configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) - configBuilder.WriteString("persistent_keepalive_interval=5\n") + 
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive)) config := configBuilder.String() logger.Debug("Configuring peer with config: %s", config) @@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [ return nil } +// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it +func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error { + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("update_only=true\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval)) + + config := configBuilder.String() + logger.Debug("Updating persistent keepalive for peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err) + } + + return nil +} + func formatEndpoint(endpoint string) string { if strings.Contains(endpoint, ":") { return endpoint diff --git a/websocket/client.go b/websocket/client.go index 74b0401..ba70494 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -79,6 +79,7 @@ type Client struct { handlersMux sync.RWMutex reconnectInterval time.Duration isConnected bool + isDisconnected bool // Flag to track if client is intentionally disconnected reconnectMux sync.RWMutex pingInterval time.Duration pingTimeout time.Duration @@ -91,6 +92,10 @@ type Client struct { configNeedsSave bool // Flag to track if config needs to be saved configVersion int // Latest config version received from server configVersionMux sync.RWMutex + token string // Cached authentication token + exitNodes []ExitNode // Cached exit nodes from token response + tokenMux sync.RWMutex // Protects token and exitNodes + forceNewToken bool // Flag to force fetching a new token on next connection } type ClientOption func(*Client) @@ -177,6 +182,9 @@ func (c *Client) GetConfig() *Config { // Connect establishes the WebSocket connection func (c *Client) Connect() error { + if c.isDisconnected { + c.isDisconnected = false + } go c.connectWithRetry() return nil } @@ -209,9 +217,25 @@ func (c *Client) Close() error { return nil } +// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later. 
+func (c *Client) Disconnect() error { + c.isDisconnected = true + c.setConnected(false) + + if c.conn != nil { + c.writeMux.Lock() + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + c.writeMux.Unlock() + err := c.conn.Close() + c.conn = nil + return err + } + return nil +} + // SendMessage sends a message through the WebSocket connection func (c *Client) SendMessage(messageType string, data interface{}) error { - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return fmt.Errorf("not connected") } @@ -220,14 +244,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { Data: data, } - logger.Debug("Sending message: %s, data: %+v", messageType, data) + logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data) c.writeMux.Lock() defer c.writeMux.Unlock() return c.conn.WriteJSON(msg) } -func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) { +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) { stopChan := make(chan struct{}) updateChan := make(chan interface{}) var dataMux sync.Mutex @@ -235,30 +259,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter go func() { count := 0 - maxAttempts := 10 - err := c.SendMessage(messageType, currentData) // Send immediately - if err != nil { - logger.Error("Failed to send initial message: %v", err) + send := func() { + if c.isDisconnected || c.conn == nil { + return + } + err := c.SendMessage(messageType, currentData) + if err != nil { + logger.Error("websocket: Failed to send message: %v", err) + } + count++ } - count++ + + send() // Send immediately ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: - if count >= maxAttempts { - logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) + if maxAttempts != -1 && count >= maxAttempts { + logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } dataMux.Lock() - err = c.SendMessage(messageType, currentData) + send() dataMux.Unlock() - if err != nil { - logger.Error("Failed to send message: %v", err) - } - count++ case newData := <-updateChan: dataMux.Lock() // Merge newData into currentData if both are maps @@ -281,6 +307,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter case <-stopChan: return } + // Suspend sending if disconnected + for c.isDisconnected { + select { + case <-stopChan: + return + case <-time.After(500 * time.Millisecond): + } + } } }() return func() { @@ -327,7 +361,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { tlsConfig = &tls.Config{} } tlsConfig.InsecureSkipVerify = true - logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } tokenData := map[string]interface{}{ @@ -356,7 +390,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { req.Header.Set("X-CSRF-Token", "x-csrf-protection") // print out the request for debugging - logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) + logger.Debug("websocket: Requesting token from %s with 
body: %s", req.URL.String(), string(jsonData)) // Make the request client := &http.Client{} @@ -373,7 +407,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) // Return AuthError for 401/403 status codes if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { @@ -389,7 +423,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - logger.Error("Failed to decode token response.") + logger.Error("websocket: Failed to decode token response.") return "", nil, fmt.Errorf("failed to decode token response: %w", err) } @@ -401,7 +435,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { return "", nil, fmt.Errorf("received empty token from server") } - logger.Debug("Received token: %s", tokenResp.Data.Token) + logger.Debug("websocket: Received token: %s", tokenResp.Data.Token) return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil } @@ -427,7 +461,7 @@ func (c *Client) connectWithRetry() { continue } // For other errors (5xx, network issues), continue retrying - logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) + logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) time.Sleep(c.reconnectInterval) continue } @@ -437,15 +471,25 @@ func (c *Client) connectWithRetry() { } func (c *Client) establishConnection() error { - // Get token for authentication - token, exitNodes, err := c.getToken() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } + // Get token for authentication - reuse cached token unless forced to get new one + c.tokenMux.Lock() + needNewToken := c.token == "" || c.forceNewToken + if needNewToken { + token, exitNodes, err := c.getToken() + if err != nil { + c.tokenMux.Unlock() + return fmt.Errorf("failed to get token: %w", err) + } + c.token = token + c.exitNodes = exitNodes + c.forceNewToken = false - if c.onTokenUpdate != nil { - c.onTokenUpdate(token, exitNodes) + if c.onTokenUpdate != nil { + c.onTokenUpdate(token, exitNodes) + } } + token := c.token + c.tokenMux.Unlock() // Parse the base URL to determine protocol and hostname baseURL, err := url.Parse(c.baseURL) @@ -480,7 +524,7 @@ func (c *Client) establishConnection() error { // Use new TLS configuration method if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { - logger.Info("Setting up TLS configuration for WebSocket connection") + logger.Info("websocket: Setting up TLS configuration for WebSocket connection") tlsConfig, err := c.setupTLS() if err != nil { return fmt.Errorf("failed to setup TLS configuration: %w", err) @@ -494,11 +538,23 @@ func (c *Client) establishConnection() error { dialer.TLSClientConfig = &tls.Config{} } dialer.TLSClientConfig.InsecureSkipVerify = true - logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } - conn, _, err := dialer.Dial(u.String(), nil) + conn, resp, err := dialer.Dial(u.String(), nil) if err != nil { + 
// Check if this is an unauthorized error (401) + if resp != nil && resp.StatusCode == http.StatusUnauthorized { + logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized") + // Force getting a new token on next reconnect attempt + c.tokenMux.Lock() + c.forceNewToken = true + c.tokenMux.Unlock() + return &AuthError{ + StatusCode: http.StatusUnauthorized, + Message: "WebSocket connection unauthorized", + } + } return fmt.Errorf("failed to connect to WebSocket: %w", err) } @@ -512,7 +568,7 @@ func (c *Client) establishConnection() error { if c.onConnect != nil { if err := c.onConnect(); err != nil { - logger.Error("OnConnect callback failed: %v", err) + logger.Error("websocket: OnConnect callback failed: %v", err) } } @@ -525,9 +581,9 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Handle new separate certificate configuration if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" { - logger.Info("Loading separate certificate files for mTLS") - logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile) - logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile) + logger.Info("websocket: Loading separate certificate files for mTLS") + logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile) + logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile) // Load client certificate and key cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile) @@ -538,7 +594,7 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Load CA certificates for remote validation if specified if len(c.tlsConfig.CAFiles) > 0 { - logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles) + logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles) caCertPool := x509.NewCertPool() for _, caFile := range c.tlsConfig.CAFiles { caCert, err := os.ReadFile(caFile) @@ -564,13 +620,13 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Fallback to existing PKCS12 implementation for backward compatibility if c.tlsConfig.PKCS12File != "" { - logger.Info("Loading PKCS12 certificate for mTLS (deprecated)") + logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)") return c.setupPKCS12TLS() } // Legacy fallback using config.TlsClientCert if c.config.TlsClientCert != "" { - logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)") + logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)") return loadClientCertificate(c.config.TlsClientCert) } @@ -592,7 +648,7 @@ func (c *Client) pingMonitor() { case <-c.done: return case <-ticker.C: - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return } // Send application-level ping with config version @@ -616,7 +672,7 @@ func (c *Client) pingMonitor() { // Expected during shutdown return default: - logger.Error("Ping failed: %v", err) + logger.Error("websocket: Ping failed: %v", err) c.reconnect() return } @@ -665,18 +721,24 @@ func (c *Client) readPumpWithDisconnectDetection() { var msg WSMessage err := c.conn.ReadJSON(&msg) if err != nil { - // Check if we're shutting down before logging error + // Check if we're shutting down or explicitly disconnected before logging error select { case <-c.done: // Expected during shutdown, don't log as error - logger.Debug("WebSocket connection closed during shutdown") + logger.Debug("websocket: connection closed during shutdown") return default: + // Check if explicitly disconnected + if c.isDisconnected { + logger.Debug("websocket: connection closed: client 
was explicitly disconnected") + return + } + + // Unexpected error during normal operation if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { - logger.Error("WebSocket read error: %v", err) + logger.Error("websocket: read error: %v", err) } else { - logger.Debug("WebSocket connection closed: %v", err) + logger.Debug("websocket: connection closed: %v", err) } return // triggers reconnect via defer } @@ -703,6 +765,12 @@ func (c *Client) reconnect() { c.conn = nil } + // Don't reconnect if explicitly disconnected + if c.isDisconnected { + logger.Debug("websocket: Not reconnecting: client was explicitly disconnected") + return + } + // Only reconnect if we're not shutting down select { case <-c.done: @@ -720,7 +788,7 @@ func (c *Client) setConnected(status bool) { // LoadClientCertificate Helper method to load client certificates (PKCS12 format) func loadClientCertificate(p12Path string) (*tls.Config, error) { - logger.Info("Loading tls-client-cert %s", p12Path) + logger.Info("websocket: Loading tls-client-cert %s", p12Path) // Read the PKCS12 file p12Data, err := os.ReadFile(p12Path) if err != nil {
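The websocket changes in this section cache the authentication token across reconnects and only request a new one after the handshake is rejected with 401, at which point forceNewToken is set and an AuthError is returned. A condensed, hypothetical sketch of that decision flow with the HTTP and TLS details stripped out:

```go
package example

import "sync"

// tokenCache is a simplified stand-in for the fields added to the client:
// a cached token, a mutex, and a flag that forces a refetch after a 401.
type tokenCache struct {
	mu       sync.Mutex
	token    string
	forceNew bool
}

// ensureToken returns the cached token unless it is empty or a previous
// handshake was rejected, in which case fetch is called again.
func (c *tokenCache) ensureToken(fetch func() (string, error)) (string, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.token == "" || c.forceNew {
		t, err := fetch()
		if err != nil {
			return "", err
		}
		c.token = t
		c.forceNew = false
	}
	return c.token, nil
}

// markUnauthorized is what a 401 on the WebSocket dial triggers: the next
// connection attempt will request a brand-new token.
func (c *tokenCache) markUnauthorized() {
	c.mu.Lock()
	c.forceNew = true
	c.mu.Unlock()
}
```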