mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
Merge branch 'power-state' into msg-delivery
Former-commit-id: bda6606098
This commit is contained in:
6
.github/workflows/cicd.yml
vendored
6
.github/workflows/cicd.yml
vendored
@@ -11,7 +11,9 @@ permissions:
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- "*"
|
- "[0-9]+.[0-9]+.[0-9]+"
|
||||||
|
- "[0-9]+.[0-9]+.[0-9]+-rc.[0-9]+"
|
||||||
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
version:
|
version:
|
||||||
@@ -273,7 +275,7 @@ jobs:
|
|||||||
tags: |
|
tags: |
|
||||||
type=semver,pattern={{version}},value=${{ env.TAG }}
|
type=semver,pattern={{version}},value=${{ env.TAG }}
|
||||||
type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }}
|
type=semver,pattern={{major}}.{{minor}},value=${{ env.TAG }},enable=${{ env.PUBLISH_MINOR == 'true' && env.IS_RC != 'true' }}
|
||||||
type=raw,value=latest,enable=${{ env.PUBLISH_LATEST == 'true' && env.IS_RC != 'true' }}
|
type=raw,value=latest,enable=${{ env.IS_RC != 'true' }}
|
||||||
flavor: |
|
flavor: |
|
||||||
latest=false
|
latest=false
|
||||||
labels: |
|
labels: |
|
||||||
|
|||||||
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -22,4 +22,4 @@ jobs:
|
|||||||
run: make go-build-release
|
run: make go-build-release
|
||||||
|
|
||||||
- name: Build Docker image
|
- name: Build Docker image
|
||||||
run: make docker-build-release
|
run: make docker-build
|
||||||
|
|||||||
3
Makefile
3
Makefile
@@ -5,6 +5,9 @@ all: local
|
|||||||
local:
|
local:
|
||||||
CGO_ENABLED=0 go build -o ./bin/olm
|
CGO_ENABLED=0 go build -o ./bin/olm
|
||||||
|
|
||||||
|
docker-build:
|
||||||
|
docker build -t fosrl/olm:latest .
|
||||||
|
|
||||||
docker-build-release:
|
docker-build-release:
|
||||||
@if [ -z "$(tag)" ]; then \
|
@if [ -z "$(tag)" ]; then \
|
||||||
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
|
echo "Error: tag is required. Usage: make docker-build-release tag=<tag>"; \
|
||||||
|
|||||||
@@ -20,13 +20,7 @@ When Olm receives WireGuard control messages, it will use the information encode
|
|||||||
|
|
||||||
## Hole Punching
|
## Hole Punching
|
||||||
|
|
||||||
In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to newt. If you want to disable hole punching, use the `--disable-holepunch` flag. Hole punching attempts to orchestrate a NAT hole punch between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil.
|
In the default mode, olm uses both relaying through Gerbil and NAT hole punching to connect to Newt. Hole punching attempts to orchestrate a NAT traversal between the two sites so that traffic flows directly, which can save data costs and improve speed. If hole punching fails, traffic will fall back to relaying through Gerbil.
|
||||||
|
|
||||||
Right now, basic NAT hole punching is supported. We plan to add:
|
|
||||||
|
|
||||||
- [ ] Birthday paradox
|
|
||||||
- [ ] UPnP
|
|
||||||
- [ ] LAN detection
|
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
|
|||||||
55
api/api.go
55
api/api.go
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -62,23 +63,26 @@ type StatusResponse struct {
|
|||||||
|
|
||||||
// API represents the HTTP server and its state
|
// API represents the HTTP server and its state
|
||||||
type API struct {
|
type API struct {
|
||||||
addr string
|
addr string
|
||||||
socketPath string
|
socketPath string
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
server *http.Server
|
server *http.Server
|
||||||
|
|
||||||
onConnect func(ConnectionRequest) error
|
onConnect func(ConnectionRequest) error
|
||||||
onSwitchOrg func(SwitchOrgRequest) error
|
onSwitchOrg func(SwitchOrgRequest) error
|
||||||
onDisconnect func() error
|
onDisconnect func() error
|
||||||
onExit func() error
|
onExit func() error
|
||||||
|
|
||||||
statusMu sync.RWMutex
|
statusMu sync.RWMutex
|
||||||
peerStatuses map[int]*PeerStatus
|
peerStatuses map[int]*PeerStatus
|
||||||
connectedAt time.Time
|
connectedAt time.Time
|
||||||
isConnected bool
|
isConnected bool
|
||||||
isRegistered bool
|
isRegistered bool
|
||||||
isTerminated bool
|
isTerminated bool
|
||||||
version string
|
|
||||||
agent string
|
version string
|
||||||
orgID string
|
agent string
|
||||||
|
orgID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||||
@@ -101,6 +105,14 @@ func NewAPISocket(socketPath string) *API {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewAPIStub() *API {
|
||||||
|
s := &API{
|
||||||
|
peerStatuses: make(map[int]*PeerStatus),
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// SetHandlers sets the callback functions for handling API requests
|
// SetHandlers sets the callback functions for handling API requests
|
||||||
func (s *API) SetHandlers(
|
func (s *API) SetHandlers(
|
||||||
onConnect func(ConnectionRequest) error,
|
onConnect func(ConnectionRequest) error,
|
||||||
@@ -116,6 +128,10 @@ func (s *API) SetHandlers(
|
|||||||
|
|
||||||
// Start starts the HTTP server
|
// Start starts the HTTP server
|
||||||
func (s *API) Start() error {
|
func (s *API) Start() error {
|
||||||
|
if s.socketPath == "" && s.addr == "" {
|
||||||
|
return fmt.Errorf("either socketPath or addr must be provided to start the API server")
|
||||||
|
}
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/connect", s.handleConnect)
|
mux.HandleFunc("/connect", s.handleConnect)
|
||||||
mux.HandleFunc("/status", s.handleStatus)
|
mux.HandleFunc("/status", s.handleStatus)
|
||||||
@@ -160,7 +176,7 @@ func (s *API) Stop() error {
|
|||||||
|
|
||||||
// Close the server first, which will also close the listener gracefully
|
// Close the server first, which will also close the listener gracefully
|
||||||
if s.server != nil {
|
if s.server != nil {
|
||||||
s.server.Close()
|
_ = s.server.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up socket file if using Unix socket
|
// Clean up socket file if using Unix socket
|
||||||
@@ -345,7 +361,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Return a success response
|
// Return a success response
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "connection request accepted",
|
"status": "connection request accepted",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -358,7 +374,6 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.statusMu.RLock()
|
s.statusMu.RLock()
|
||||||
defer s.statusMu.RUnlock()
|
|
||||||
|
|
||||||
resp := StatusResponse{
|
resp := StatusResponse{
|
||||||
Connected: s.isConnected,
|
Connected: s.isConnected,
|
||||||
@@ -371,8 +386,18 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
NetworkSettings: network.GetSettings(),
|
NetworkSettings: network.GetSettings(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.statusMu.RUnlock()
|
||||||
|
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(resp)
|
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleHealth handles the /health endpoint
|
// handleHealth handles the /health endpoint
|
||||||
@@ -384,7 +409,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -401,7 +426,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Return a success response first
|
// Return a success response first
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "shutdown initiated",
|
"status": "shutdown initiated",
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -450,7 +475,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Return a success response
|
// Return a success response
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "org switch request accepted",
|
"status": "org switch request accepted",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -484,7 +509,7 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Return a success response
|
// Return a success response
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "disconnect initiated",
|
"status": "disconnect initiated",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@@ -18,14 +21,68 @@ type FilterRule struct {
|
|||||||
Handler PacketHandler
|
Handler PacketHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
// closeAwareDevice wraps a tun.Device along with a flag
|
||||||
type MiddleDevice struct {
|
// indicating whether its Close method was called.
|
||||||
|
type closeAwareDevice struct {
|
||||||
|
isClosed atomic.Bool
|
||||||
tun.Device
|
tun.Device
|
||||||
rules []FilterRule
|
closeEventCh chan struct{}
|
||||||
mutex sync.RWMutex
|
wg sync.WaitGroup
|
||||||
readCh chan readResult
|
closeOnce sync.Once
|
||||||
injectCh chan []byte
|
}
|
||||||
closed chan struct{}
|
|
||||||
|
func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice {
|
||||||
|
return &closeAwareDevice{
|
||||||
|
Device: tunDevice,
|
||||||
|
isClosed: atomic.Bool{},
|
||||||
|
closeEventCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirectEvents redirects the Events() method of the underlying tun.Device
|
||||||
|
// to the given channel.
|
||||||
|
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
|
||||||
|
c.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer c.wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-c.Device.Events():
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ev == tun.EventDown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case out <- ev:
|
||||||
|
case <-c.closeEventCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-c.closeEventCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close calls the underlying Device's Close method
|
||||||
|
// after setting isClosed to true.
|
||||||
|
func (c *closeAwareDevice) Close() (err error) {
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.isClosed.Store(true)
|
||||||
|
close(c.closeEventCh)
|
||||||
|
err = c.Device.Close()
|
||||||
|
c.wg.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeAwareDevice) IsClosed() bool {
|
||||||
|
return c.isClosed.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
type readResult struct {
|
type readResult struct {
|
||||||
@@ -36,58 +93,136 @@ type readResult struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
||||||
|
// and supports swapping the underlying device.
|
||||||
|
type MiddleDevice struct {
|
||||||
|
devices []*closeAwareDevice
|
||||||
|
mu sync.Mutex
|
||||||
|
cond *sync.Cond
|
||||||
|
rules []FilterRule
|
||||||
|
rulesMutex sync.RWMutex
|
||||||
|
readCh chan readResult
|
||||||
|
injectCh chan []byte
|
||||||
|
closed atomic.Bool
|
||||||
|
events chan tun.Event
|
||||||
|
}
|
||||||
|
|
||||||
// NewMiddleDevice creates a new filtered TUN device wrapper
|
// NewMiddleDevice creates a new filtered TUN device wrapper
|
||||||
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||||
d := &MiddleDevice{
|
d := &MiddleDevice{
|
||||||
Device: device,
|
devices: make([]*closeAwareDevice, 0),
|
||||||
rules: make([]FilterRule, 0),
|
rules: make([]FilterRule, 0),
|
||||||
readCh: make(chan readResult),
|
readCh: make(chan readResult, 16),
|
||||||
injectCh: make(chan []byte, 100),
|
injectCh: make(chan []byte, 100),
|
||||||
closed: make(chan struct{}),
|
events: make(chan tun.Event, 16),
|
||||||
}
|
}
|
||||||
go d.pump()
|
d.cond = sync.NewCond(&d.mu)
|
||||||
|
|
||||||
|
if device != nil {
|
||||||
|
d.AddDevice(device)
|
||||||
|
}
|
||||||
|
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *MiddleDevice) pump() {
|
// AddDevice adds a new underlying TUN device, closing any previous one
|
||||||
|
func (d *MiddleDevice) AddDevice(device tun.Device) {
|
||||||
|
d.mu.Lock()
|
||||||
|
if d.closed.Load() {
|
||||||
|
d.mu.Unlock()
|
||||||
|
_ = device.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var toClose *closeAwareDevice
|
||||||
|
if len(d.devices) > 0 {
|
||||||
|
toClose = d.devices[len(d.devices)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
cad := newCloseAwareDevice(device)
|
||||||
|
cad.redirectEvents(d.events)
|
||||||
|
|
||||||
|
d.devices = []*closeAwareDevice{cad}
|
||||||
|
|
||||||
|
// Start pump for the new device
|
||||||
|
go d.pump(cad)
|
||||||
|
|
||||||
|
d.cond.Broadcast()
|
||||||
|
d.mu.Unlock()
|
||||||
|
|
||||||
|
if toClose != nil {
|
||||||
|
logger.Debug("MiddleDevice: Closing previous device")
|
||||||
|
if err := toClose.Close(); err != nil {
|
||||||
|
logger.Debug("MiddleDevice: Error closing previous device: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MiddleDevice) pump(dev *closeAwareDevice) {
|
||||||
const defaultOffset = 16
|
const defaultOffset = 16
|
||||||
batchSize := d.Device.BatchSize()
|
batchSize := dev.BatchSize()
|
||||||
logger.Debug("MiddleDevice: pump started")
|
logger.Debug("MiddleDevice: pump started for device")
|
||||||
|
|
||||||
|
// Recover from panic if readCh is closed while we're trying to send
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Debug("MiddleDevice: pump recovered from panic (channel closed)")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Check closed first with priority
|
// Check if this device is closed
|
||||||
select {
|
if dev.IsClosed() {
|
||||||
case <-d.closed:
|
logger.Debug("MiddleDevice: pump exiting, device is closed")
|
||||||
logger.Debug("MiddleDevice: pump exiting due to closed channel")
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if MiddleDevice itself is closed
|
||||||
|
if d.closed.Load() {
|
||||||
|
logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed")
|
||||||
return
|
return
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate buffers for reading
|
// Allocate buffers for reading
|
||||||
// We allocate new buffers for each read to avoid race conditions
|
|
||||||
// since we pass them to the channel
|
|
||||||
bufs := make([][]byte, batchSize)
|
bufs := make([][]byte, batchSize)
|
||||||
sizes := make([]int, batchSize)
|
sizes := make([]int, batchSize)
|
||||||
for i := range bufs {
|
for i := range bufs {
|
||||||
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
|
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := d.Device.Read(bufs, sizes, defaultOffset)
|
n, err := dev.Read(bufs, sizes, defaultOffset)
|
||||||
|
|
||||||
// Check closed again after read returns
|
// Check if device was closed during read
|
||||||
select {
|
if dev.IsClosed() {
|
||||||
case <-d.closed:
|
logger.Debug("MiddleDevice: pump exiting, device closed during read")
|
||||||
logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if MiddleDevice was closed during read
|
||||||
|
if d.closed.Load() {
|
||||||
|
logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to send the result - check closed state first to avoid sending on closed channel
|
||||||
|
if d.closed.Load() {
|
||||||
|
logger.Debug("MiddleDevice: pump exiting, device closed before send")
|
||||||
return
|
return
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now try to send the result
|
|
||||||
select {
|
select {
|
||||||
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||||
case <-d.closed:
|
default:
|
||||||
logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
|
// Channel full, check if we should exit
|
||||||
return
|
if dev.IsClosed() || d.closed.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Try again with blocking
|
||||||
|
select {
|
||||||
|
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||||
|
case <-dev.closeEventCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -99,16 +234,28 @@ func (d *MiddleDevice) pump() {
|
|||||||
|
|
||||||
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
|
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
|
||||||
func (d *MiddleDevice) InjectOutbound(packet []byte) {
|
func (d *MiddleDevice) InjectOutbound(packet []byte) {
|
||||||
|
if d.closed.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Use defer/recover to handle panic from sending on closed channel
|
||||||
|
// This can happen during shutdown race conditions
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Debug("MiddleDevice: InjectOutbound recovered from panic (channel closed)")
|
||||||
|
}
|
||||||
|
}()
|
||||||
select {
|
select {
|
||||||
case d.injectCh <- packet:
|
case d.injectCh <- packet:
|
||||||
case <-d.closed:
|
default:
|
||||||
|
// Channel full, drop packet
|
||||||
|
logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRule adds a packet filtering rule
|
// AddRule adds a packet filtering rule
|
||||||
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||||
d.mutex.Lock()
|
d.rulesMutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.rulesMutex.Unlock()
|
||||||
d.rules = append(d.rules, FilterRule{
|
d.rules = append(d.rules, FilterRule{
|
||||||
DestIP: destIP,
|
DestIP: destIP,
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
@@ -117,8 +264,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
|||||||
|
|
||||||
// RemoveRule removes all rules for a given destination IP
|
// RemoveRule removes all rules for a given destination IP
|
||||||
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||||
d.mutex.Lock()
|
d.rulesMutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.rulesMutex.Unlock()
|
||||||
newRules := make([]FilterRule, 0, len(d.rules))
|
newRules := make([]FilterRule, 0, len(d.rules))
|
||||||
for _, rule := range d.rules {
|
for _, rule := range d.rules {
|
||||||
if rule.DestIP != destIP {
|
if rule.DestIP != destIP {
|
||||||
@@ -130,18 +277,120 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
|||||||
|
|
||||||
// Close stops the device
|
// Close stops the device
|
||||||
func (d *MiddleDevice) Close() error {
|
func (d *MiddleDevice) Close() error {
|
||||||
select {
|
if !d.closed.CompareAndSwap(false, true) {
|
||||||
case <-d.closed:
|
return nil // already closed
|
||||||
// Already closed
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
logger.Debug("MiddleDevice: Closing, signaling closed channel")
|
|
||||||
close(d.closed)
|
|
||||||
}
|
}
|
||||||
logger.Debug("MiddleDevice: Closing underlying TUN device")
|
|
||||||
err := d.Device.Close()
|
d.mu.Lock()
|
||||||
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
|
devices := d.devices
|
||||||
return err
|
d.devices = nil
|
||||||
|
d.cond.Broadcast()
|
||||||
|
d.mu.Unlock()
|
||||||
|
|
||||||
|
// Close underlying devices first - this causes the pump goroutines to exit
|
||||||
|
// when their read operations return errors
|
||||||
|
var lastErr error
|
||||||
|
logger.Debug("MiddleDevice: Closing %d devices", len(devices))
|
||||||
|
for _, device := range devices {
|
||||||
|
if err := device.Close(); err != nil {
|
||||||
|
logger.Debug("MiddleDevice: Error closing device: %v", err)
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now close channels to unblock any remaining readers
|
||||||
|
// The pump should have exited by now, but close channels to be safe
|
||||||
|
close(d.readCh)
|
||||||
|
close(d.injectCh)
|
||||||
|
close(d.events)
|
||||||
|
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events returns the events channel
|
||||||
|
func (d *MiddleDevice) Events() <-chan tun.Event {
|
||||||
|
return d.events
|
||||||
|
}
|
||||||
|
|
||||||
|
// File returns the underlying file descriptor
|
||||||
|
func (d *MiddleDevice) File() *os.File {
|
||||||
|
for {
|
||||||
|
dev := d.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !d.waitForDevice() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
file := dev.File()
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MTU returns the MTU of the underlying device
|
||||||
|
func (d *MiddleDevice) MTU() (int, error) {
|
||||||
|
for {
|
||||||
|
dev := d.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !d.waitForDevice() {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
mtu, err := dev.MTU()
|
||||||
|
if err == nil {
|
||||||
|
return mtu, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the name of the underlying device
|
||||||
|
func (d *MiddleDevice) Name() (string, error) {
|
||||||
|
for {
|
||||||
|
dev := d.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !d.waitForDevice() {
|
||||||
|
return "", io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
name, err := dev.Name()
|
||||||
|
if err == nil {
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the batch size
|
||||||
|
func (d *MiddleDevice) BatchSize() int {
|
||||||
|
dev := d.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return dev.BatchSize()
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractDestIP extracts destination IP from packet (fast path)
|
// extractDestIP extracts destination IP from packet (fast path)
|
||||||
@@ -176,156 +425,239 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
|
|||||||
|
|
||||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
// Check if already closed first (non-blocking)
|
for {
|
||||||
select {
|
if d.closed.Load() {
|
||||||
case <-d.closed:
|
logger.Debug("MiddleDevice: Read returning io.EOF, device closed")
|
||||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
|
return 0, io.EOF
|
||||||
return 0, os.ErrClosed
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now block waiting for data
|
|
||||||
select {
|
|
||||||
case res := <-d.readCh:
|
|
||||||
if res.err != nil {
|
|
||||||
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
|
|
||||||
return 0, res.err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy packets from result to provided buffers
|
// Wait for a device to be available
|
||||||
count := 0
|
dev := d.peekLast()
|
||||||
for i := 0; i < res.n && i < len(bufs); i++ {
|
if dev == nil {
|
||||||
// Handle offset mismatch if necessary
|
if !d.waitForDevice() {
|
||||||
// We assume the pump used defaultOffset (16)
|
return 0, io.EOF
|
||||||
// If caller asks for different offset, we need to shift
|
|
||||||
src := res.bufs[i]
|
|
||||||
srcOffset := res.offset
|
|
||||||
srcSize := res.sizes[i]
|
|
||||||
|
|
||||||
// Calculate where the packet data starts and ends in src
|
|
||||||
pktData := src[srcOffset : srcOffset+srcSize]
|
|
||||||
|
|
||||||
// Ensure dest buffer is large enough
|
|
||||||
if len(bufs[i]) < offset+len(pktData) {
|
|
||||||
continue // Skip if buffer too small
|
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(bufs[i][offset:], pktData)
|
|
||||||
sizes[i] = len(pktData)
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
n = count
|
|
||||||
|
|
||||||
case pkt := <-d.injectCh:
|
|
||||||
if len(bufs) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if len(bufs[0]) < offset+len(pkt) {
|
|
||||||
return 0, nil // Buffer too small
|
|
||||||
}
|
|
||||||
copy(bufs[0][offset:], pkt)
|
|
||||||
sizes[0] = len(pkt)
|
|
||||||
n = 1
|
|
||||||
|
|
||||||
case <-d.closed:
|
|
||||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
|
|
||||||
return 0, os.ErrClosed // Signal that device is closed
|
|
||||||
}
|
|
||||||
|
|
||||||
d.mutex.RLock()
|
|
||||||
rules := d.rules
|
|
||||||
d.mutex.RUnlock()
|
|
||||||
|
|
||||||
if len(rules) == 0 {
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process packets and filter out handled ones
|
|
||||||
writeIdx := 0
|
|
||||||
for readIdx := 0; readIdx < n; readIdx++ {
|
|
||||||
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
|
|
||||||
|
|
||||||
destIP, ok := extractDestIP(packet)
|
|
||||||
if !ok {
|
|
||||||
// Can't parse, keep packet
|
|
||||||
if writeIdx != readIdx {
|
|
||||||
bufs[writeIdx] = bufs[readIdx]
|
|
||||||
sizes[writeIdx] = sizes[readIdx]
|
|
||||||
}
|
|
||||||
writeIdx++
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if packet matches any rule
|
// Now block waiting for data from readCh or injectCh
|
||||||
handled := false
|
select {
|
||||||
for _, rule := range rules {
|
case res, ok := <-d.readCh:
|
||||||
if rule.DestIP == destIP {
|
if !ok {
|
||||||
if rule.Handler(packet) {
|
// Channel closed, device is shutting down
|
||||||
// Packet was handled and should be dropped
|
return 0, io.EOF
|
||||||
handled = true
|
}
|
||||||
break
|
if res.err != nil {
|
||||||
|
// Check if device was swapped
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
|
||||||
|
return 0, res.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy packets from result to provided buffers
|
||||||
|
count := 0
|
||||||
|
for i := 0; i < res.n && i < len(bufs); i++ {
|
||||||
|
src := res.bufs[i]
|
||||||
|
srcOffset := res.offset
|
||||||
|
srcSize := res.sizes[i]
|
||||||
|
|
||||||
|
pktData := src[srcOffset : srcOffset+srcSize]
|
||||||
|
|
||||||
|
if len(bufs[i]) < offset+len(pktData) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(bufs[i][offset:], pktData)
|
||||||
|
sizes[i] = len(pktData)
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
n = count
|
||||||
|
|
||||||
|
case pkt, ok := <-d.injectCh:
|
||||||
|
if !ok {
|
||||||
|
// Channel closed, device is shutting down
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if len(bufs) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if len(bufs[0]) < offset+len(pkt) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
copy(bufs[0][offset:], pkt)
|
||||||
|
sizes[0] = len(pkt)
|
||||||
|
n = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply filtering rules
|
||||||
|
d.rulesMutex.RLock()
|
||||||
|
rules := d.rules
|
||||||
|
d.rulesMutex.RUnlock()
|
||||||
|
|
||||||
|
if len(rules) == 0 {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process packets and filter out handled ones
|
||||||
|
writeIdx := 0
|
||||||
|
for readIdx := 0; readIdx < n; readIdx++ {
|
||||||
|
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
|
||||||
|
|
||||||
|
destIP, ok := extractDestIP(packet)
|
||||||
|
if !ok {
|
||||||
|
if writeIdx != readIdx {
|
||||||
|
bufs[writeIdx] = bufs[readIdx]
|
||||||
|
sizes[writeIdx] = sizes[readIdx]
|
||||||
|
}
|
||||||
|
writeIdx++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
handled := false
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.DestIP == destIP {
|
||||||
|
if rule.Handler(packet) {
|
||||||
|
handled = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if !handled {
|
if !handled {
|
||||||
// Keep packet
|
if writeIdx != readIdx {
|
||||||
if writeIdx != readIdx {
|
bufs[writeIdx] = bufs[readIdx]
|
||||||
bufs[writeIdx] = bufs[readIdx]
|
sizes[writeIdx] = sizes[readIdx]
|
||||||
sizes[writeIdx] = sizes[readIdx]
|
}
|
||||||
|
writeIdx++
|
||||||
}
|
}
|
||||||
writeIdx++
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return writeIdx, err
|
return writeIdx, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
||||||
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
|
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||||
d.mutex.RLock()
|
for {
|
||||||
rules := d.rules
|
if d.closed.Load() {
|
||||||
d.mutex.RUnlock()
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
if len(rules) == 0 {
|
dev := d.peekLast()
|
||||||
return d.Device.Write(bufs, offset)
|
if dev == nil {
|
||||||
}
|
if !d.waitForDevice() {
|
||||||
|
return 0, io.EOF
|
||||||
// Filter packets going down
|
}
|
||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
|
||||||
for _, buf := range bufs {
|
|
||||||
if len(buf) <= offset {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
packet := buf[offset:]
|
d.rulesMutex.RLock()
|
||||||
destIP, ok := extractDestIP(packet)
|
rules := d.rules
|
||||||
if !ok {
|
d.rulesMutex.RUnlock()
|
||||||
// Can't parse, keep packet
|
|
||||||
filteredBufs = append(filteredBufs, buf)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if packet matches any rule
|
var filteredBufs [][]byte
|
||||||
handled := false
|
if len(rules) == 0 {
|
||||||
for _, rule := range rules {
|
filteredBufs = bufs
|
||||||
if rule.DestIP == destIP {
|
} else {
|
||||||
if rule.Handler(packet) {
|
filteredBufs = make([][]byte, 0, len(bufs))
|
||||||
// Packet was handled and should be dropped
|
for _, buf := range bufs {
|
||||||
handled = true
|
if len(buf) <= offset {
|
||||||
break
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
packet := buf[offset:]
|
||||||
|
destIP, ok := extractDestIP(packet)
|
||||||
|
if !ok {
|
||||||
|
filteredBufs = append(filteredBufs, buf)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
handled := false
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.DestIP == destIP {
|
||||||
|
if rule.Handler(packet) {
|
||||||
|
handled = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !handled {
|
||||||
|
filteredBufs = append(filteredBufs, buf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !handled {
|
if len(filteredBufs) == 0 {
|
||||||
filteredBufs = append(filteredBufs, buf)
|
return len(bufs), nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if len(filteredBufs) == 0 {
|
n, err := dev.Write(filteredBufs, offset)
|
||||||
return len(bufs), nil // All packets were handled
|
if err == nil {
|
||||||
}
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
return d.Device.Write(filteredBufs, offset)
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *MiddleDevice) waitForDevice() bool {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
for len(d.devices) == 0 && !d.closed.Load() {
|
||||||
|
d.cond.Wait()
|
||||||
|
}
|
||||||
|
return !d.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MiddleDevice) peekLast() *closeAwareDevice {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
if len(d.devices) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.devices[len(d.devices)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteToTun writes packets directly to the underlying TUN device,
|
||||||
|
// bypassing WireGuard. This is useful for sending packets that should
|
||||||
|
// appear to come from the TUN interface (e.g., DNS responses from a proxy).
|
||||||
|
// Unlike Write(), this does not go through packet filtering rules.
|
||||||
|
func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for {
|
||||||
|
if d.closed.Load() {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
dev := d.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !d.waitForDevice() {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := dev.Write(bufs, offset)
|
||||||
|
if err == nil {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !windows
|
//go:build darwin
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
|
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
|
||||||
device, err := tun.CreateTUNFromFile(file, mtuInt)
|
device, err := tun.CreateTUNFromFile(file, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
file.Close()
|
file.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
50
device/tun_linux.go
Normal file
50
device/tun_linux.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||||
|
if runtime.GOOS == "android" { // otherwise we get a permission denied
|
||||||
|
theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
|
||||||
|
return theTun, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dupTunFd, err := unix.Dup(int(tunFd))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Unable to dup tun fd: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = unix.SetNonblock(dupTunFd, true)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(dupTunFd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
|
||||||
|
device, err := tun.CreateTUNFromFile(file, mtuInt)
|
||||||
|
if err != nil {
|
||||||
|
file.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return device, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||||
|
return ipc.UAPIOpen(interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||||
|
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||||
|
}
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/fosrl/newt/util"
|
"github.com/fosrl/newt/util"
|
||||||
"github.com/fosrl/olm/device"
|
"github.com/fosrl/olm/device"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
@@ -34,18 +33,17 @@ type DNSProxy struct {
|
|||||||
ep *channel.Endpoint
|
ep *channel.Endpoint
|
||||||
proxyIP netip.Addr
|
proxyIP netip.Addr
|
||||||
upstreamDNS []string
|
upstreamDNS []string
|
||||||
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
|
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
|
||||||
mtu int
|
mtu int
|
||||||
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
|
||||||
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
|
|
||||||
recordStore *DNSRecordStore // Local DNS records
|
recordStore *DNSRecordStore // Local DNS records
|
||||||
|
|
||||||
// Tunnel DNS fields - for sending queries over WireGuard
|
// Tunnel DNS fields - for sending queries over WireGuard
|
||||||
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
|
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
|
||||||
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
|
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
|
||||||
tunnelEp *channel.Endpoint
|
tunnelEp *channel.Endpoint
|
||||||
tunnelActivePorts map[uint16]bool
|
tunnelActivePorts map[uint16]bool
|
||||||
tunnelPortsLock sync.Mutex
|
tunnelPortsLock sync.Mutex
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
@@ -53,7 +51,7 @@ type DNSProxy struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSProxy creates a new DNS proxy
|
// NewDNSProxy creates a new DNS proxy
|
||||||
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
|
func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
|
||||||
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
||||||
@@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
|
|||||||
proxy := &DNSProxy{
|
proxy := &DNSProxy{
|
||||||
proxyIP: proxyIP,
|
proxyIP: proxyIP,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
tunDevice: tunDevice,
|
|
||||||
middleDevice: middleDevice,
|
middleDevice: middleDevice,
|
||||||
upstreamDNS: upstreamDns,
|
upstreamDNS: upstreamDns,
|
||||||
tunnelDNS: tunnelDns,
|
tunnelDNS: tunnelDns,
|
||||||
@@ -602,12 +599,12 @@ func (p *DNSProxy) runTunnelPacketSender() {
|
|||||||
defer p.wg.Done()
|
defer p.wg.Done()
|
||||||
logger.Debug("DNS tunnel packet sender goroutine started")
|
logger.Debug("DNS tunnel packet sender goroutine started")
|
||||||
|
|
||||||
ticker := time.NewTicker(1 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
// Use blocking ReadContext instead of polling - much more CPU efficient
|
||||||
case <-p.ctx.Done():
|
// This will block until a packet is available or context is cancelled
|
||||||
|
pkt := p.tunnelEp.ReadContext(p.ctx)
|
||||||
|
if pkt == nil {
|
||||||
|
// Context was cancelled or endpoint closed
|
||||||
logger.Debug("DNS tunnel packet sender exiting")
|
logger.Debug("DNS tunnel packet sender exiting")
|
||||||
// Drain any remaining packets
|
// Drain any remaining packets
|
||||||
for {
|
for {
|
||||||
@@ -618,36 +615,28 @@ func (p *DNSProxy) runTunnelPacketSender() {
|
|||||||
pkt.DecRef()
|
pkt.DecRef()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
|
||||||
// Try to read packets
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
pkt := p.tunnelEp.Read()
|
|
||||||
if pkt == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract packet data
|
|
||||||
slices := pkt.AsSlices()
|
|
||||||
if len(slices) > 0 {
|
|
||||||
var totalSize int
|
|
||||||
for _, slice := range slices {
|
|
||||||
totalSize += len(slice)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, totalSize)
|
|
||||||
pos := 0
|
|
||||||
for _, slice := range slices {
|
|
||||||
copy(buf[pos:], slice)
|
|
||||||
pos += len(slice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inject into MiddleDevice (outbound to WG)
|
|
||||||
p.middleDevice.InjectOutbound(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
pkt.DecRef()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract packet data
|
||||||
|
slices := pkt.AsSlices()
|
||||||
|
if len(slices) > 0 {
|
||||||
|
var totalSize int
|
||||||
|
for _, slice := range slices {
|
||||||
|
totalSize += len(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, totalSize)
|
||||||
|
pos := 0
|
||||||
|
for _, slice := range slices {
|
||||||
|
copy(buf[pos:], slice)
|
||||||
|
pos += len(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject into MiddleDevice (outbound to WG)
|
||||||
|
p.middleDevice.InjectOutbound(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt.DecRef()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -660,18 +649,12 @@ func (p *DNSProxy) runPacketSender() {
|
|||||||
const offset = 16
|
const offset = 16
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
// Use blocking ReadContext instead of polling - much more CPU efficient
|
||||||
case <-p.ctx.Done():
|
// This will block until a packet is available or context is cancelled
|
||||||
return
|
pkt := p.ep.ReadContext(p.ctx)
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read packets from netstack endpoint
|
|
||||||
pkt := p.ep.Read()
|
|
||||||
if pkt == nil {
|
if pkt == nil {
|
||||||
// No packet available, small sleep to avoid busy loop
|
// Context was cancelled or endpoint closed
|
||||||
time.Sleep(1 * time.Millisecond)
|
return
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract packet data as slices
|
// Extract packet data as slices
|
||||||
@@ -694,9 +677,9 @@ func (p *DNSProxy) runPacketSender() {
|
|||||||
pos += len(slice)
|
pos += len(slice)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write packet to TUN device
|
// Write packet to TUN device via MiddleDevice
|
||||||
// offset=16 indicates packet data starts at position 16 in the buffer
|
// offset=16 indicates packet data starts at position 16 in the buffer
|
||||||
_, err := p.tunDevice.Write([][]byte{buf}, offset)
|
_, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to write DNS response to TUN: %v", err)
|
logger.Error("Failed to write DNS response to TUN: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -322,4 +322,4 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return matchWildcardInternal(pattern, domain, pi+1, di+1)
|
return matchWildcardInternal(pattern, domain, pi+1, di+1)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ func TestWildcardMatching(t *testing.T) {
|
|||||||
domain: "autoco.internal.",
|
domain: "autoco.internal.",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Question mark wildcard tests
|
// Question mark wildcard tests
|
||||||
{
|
{
|
||||||
name: "host-0?.autoco.internal matches host-01.autoco.internal",
|
name: "host-0?.autoco.internal matches host-01.autoco.internal",
|
||||||
@@ -63,7 +63,7 @@ func TestWildcardMatching(t *testing.T) {
|
|||||||
domain: "host-012.autoco.internal.",
|
domain: "host-012.autoco.internal.",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Combined wildcard tests
|
// Combined wildcard tests
|
||||||
{
|
{
|
||||||
name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal",
|
name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal",
|
||||||
@@ -83,7 +83,7 @@ func TestWildcardMatching(t *testing.T) {
|
|||||||
domain: "host-01.autoco.internal.",
|
domain: "host-01.autoco.internal.",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Multiple asterisks
|
// Multiple asterisks
|
||||||
{
|
{
|
||||||
name: "*.*. autoco.internal matches any.thing.autoco.internal",
|
name: "*.*. autoco.internal matches any.thing.autoco.internal",
|
||||||
@@ -97,7 +97,7 @@ func TestWildcardMatching(t *testing.T) {
|
|||||||
domain: "single.autoco.internal.",
|
domain: "single.autoco.internal.",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Asterisk in middle
|
// Asterisk in middle
|
||||||
{
|
{
|
||||||
name: "host-*.autoco.internal matches host-anything.autoco.internal",
|
name: "host-*.autoco.internal matches host-anything.autoco.internal",
|
||||||
@@ -111,7 +111,7 @@ func TestWildcardMatching(t *testing.T) {
|
|||||||
domain: "host-.autoco.internal.",
|
domain: "host-.autoco.internal.",
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Multiple question marks
|
// Multiple question marks
|
||||||
{
|
{
|
||||||
name: "host-??.autoco.internal matches host-01.autoco.internal",
|
name: "host-??.autoco.internal matches host-01.autoco.internal",
|
||||||
@@ -125,7 +125,7 @@ func TestWildcardMatching(t *testing.T) {
|
|||||||
domain: "host-1.autoco.internal.",
|
domain: "host-1.autoco.internal.",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Exact match (no wildcards)
|
// Exact match (no wildcards)
|
||||||
{
|
{
|
||||||
name: "exact.autoco.internal matches exact.autoco.internal",
|
name: "exact.autoco.internal matches exact.autoco.internal",
|
||||||
@@ -139,7 +139,7 @@ func TestWildcardMatching(t *testing.T) {
|
|||||||
domain: "other.autoco.internal.",
|
domain: "other.autoco.internal.",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Edge cases
|
// Edge cases
|
||||||
{
|
{
|
||||||
name: "* matches anything",
|
name: "* matches anything",
|
||||||
@@ -154,7 +154,7 @@ func TestWildcardMatching(t *testing.T) {
|
|||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := matchWildcard(tt.pattern, tt.domain)
|
result := matchWildcard(tt.pattern, tt.domain)
|
||||||
@@ -167,21 +167,21 @@ func TestWildcardMatching(t *testing.T) {
|
|||||||
|
|
||||||
func TestDNSRecordStoreWildcard(t *testing.T) {
|
func TestDNSRecordStoreWildcard(t *testing.T) {
|
||||||
store := NewDNSRecordStore()
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
// Add wildcard records
|
// Add wildcard records
|
||||||
wildcardIP := net.ParseIP("10.0.0.1")
|
wildcardIP := net.ParseIP("10.0.0.1")
|
||||||
err := store.AddRecord("*.autoco.internal", wildcardIP)
|
err := store.AddRecord("*.autoco.internal", wildcardIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add exact record
|
// Add exact record
|
||||||
exactIP := net.ParseIP("10.0.0.2")
|
exactIP := net.ParseIP("10.0.0.2")
|
||||||
err = store.AddRecord("exact.autoco.internal", exactIP)
|
err = store.AddRecord("exact.autoco.internal", exactIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add exact record: %v", err)
|
t.Fatalf("Failed to add exact record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test exact match takes precedence
|
// Test exact match takes precedence
|
||||||
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
|
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
@@ -190,7 +190,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
|
|||||||
if !ips[0].Equal(exactIP) {
|
if !ips[0].Equal(exactIP) {
|
||||||
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
|
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test wildcard match
|
// Test wildcard match
|
||||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
@@ -199,7 +199,7 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
|
|||||||
if !ips[0].Equal(wildcardIP) {
|
if !ips[0].Equal(wildcardIP) {
|
||||||
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
|
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test non-match (base domain)
|
// Test non-match (base domain)
|
||||||
ips = store.GetRecords("autoco.internal.", RecordTypeA)
|
ips = store.GetRecords("autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
@@ -209,14 +209,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
|
|||||||
|
|
||||||
func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
||||||
store := NewDNSRecordStore()
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
// Add complex wildcard pattern
|
// Add complex wildcard pattern
|
||||||
ip1 := net.ParseIP("10.0.0.1")
|
ip1 := net.ParseIP("10.0.0.1")
|
||||||
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
|
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test matching domain
|
// Test matching domain
|
||||||
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
|
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
@@ -225,13 +225,13 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
|||||||
if len(ips) > 0 && !ips[0].Equal(ip1) {
|
if len(ips) > 0 && !ips[0].Equal(ip1) {
|
||||||
t.Errorf("Expected IP %v, got %v", ip1, ips[0])
|
t.Errorf("Expected IP %v, got %v", ip1, ips[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test non-matching domain (missing prefix)
|
// Test non-matching domain (missing prefix)
|
||||||
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
|
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
|
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test non-matching domain (wrong ? position)
|
// Test non-matching domain (wrong ? position)
|
||||||
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
|
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
@@ -241,23 +241,23 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
|||||||
|
|
||||||
func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
||||||
store := NewDNSRecordStore()
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
// Add wildcard record
|
// Add wildcard record
|
||||||
ip := net.ParseIP("10.0.0.1")
|
ip := net.ParseIP("10.0.0.1")
|
||||||
err := store.AddRecord("*.autoco.internal", ip)
|
err := store.AddRecord("*.autoco.internal", ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify it exists
|
// Verify it exists
|
||||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
|
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
|
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove wildcard record
|
// Remove wildcard record
|
||||||
store.RemoveRecord("*.autoco.internal", nil)
|
store.RemoveRecord("*.autoco.internal", nil)
|
||||||
|
|
||||||
// Verify it's gone
|
// Verify it's gone
|
||||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
@@ -267,40 +267,40 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
|||||||
|
|
||||||
func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
|
func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
|
||||||
store := NewDNSRecordStore()
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
// Add multiple wildcard patterns that don't overlap
|
// Add multiple wildcard patterns that don't overlap
|
||||||
ip1 := net.ParseIP("10.0.0.1")
|
ip1 := net.ParseIP("10.0.0.1")
|
||||||
ip2 := net.ParseIP("10.0.0.2")
|
ip2 := net.ParseIP("10.0.0.2")
|
||||||
ip3 := net.ParseIP("10.0.0.3")
|
ip3 := net.ParseIP("10.0.0.3")
|
||||||
|
|
||||||
err := store.AddRecord("*.prod.autoco.internal", ip1)
|
err := store.AddRecord("*.prod.autoco.internal", ip1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add first wildcard: %v", err)
|
t.Fatalf("Failed to add first wildcard: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = store.AddRecord("*.dev.autoco.internal", ip2)
|
err = store.AddRecord("*.dev.autoco.internal", ip2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add second wildcard: %v", err)
|
t.Fatalf("Failed to add second wildcard: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a broader wildcard that matches both
|
// Add a broader wildcard that matches both
|
||||||
err = store.AddRecord("*.autoco.internal", ip3)
|
err = store.AddRecord("*.autoco.internal", ip3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add third wildcard: %v", err)
|
t.Fatalf("Failed to add third wildcard: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test domain matching only the prod pattern and the broad pattern
|
// Test domain matching only the prod pattern and the broad pattern
|
||||||
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
|
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 2 {
|
if len(ips) != 2 {
|
||||||
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
|
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test domain matching only the dev pattern and the broad pattern
|
// Test domain matching only the dev pattern and the broad pattern
|
||||||
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
|
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 2 {
|
if len(ips) != 2 {
|
||||||
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
|
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test domain matching only the broad pattern
|
// Test domain matching only the broad pattern
|
||||||
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
|
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
@@ -310,14 +310,14 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
|
|||||||
|
|
||||||
func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
|
func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
|
||||||
store := NewDNSRecordStore()
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
// Add IPv6 wildcard record
|
// Add IPv6 wildcard record
|
||||||
ip := net.ParseIP("2001:db8::1")
|
ip := net.ParseIP("2001:db8::1")
|
||||||
err := store.AddRecord("*.autoco.internal", ip)
|
err := store.AddRecord("*.autoco.internal", ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
|
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test wildcard match for IPv6
|
// Test wildcard match for IPv6
|
||||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
|
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
@@ -330,21 +330,21 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
|
|||||||
|
|
||||||
func TestHasRecordWildcard(t *testing.T) {
|
func TestHasRecordWildcard(t *testing.T) {
|
||||||
store := NewDNSRecordStore()
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
// Add wildcard record
|
// Add wildcard record
|
||||||
ip := net.ParseIP("10.0.0.1")
|
ip := net.ParseIP("10.0.0.1")
|
||||||
err := store.AddRecord("*.autoco.internal", ip)
|
err := store.AddRecord("*.autoco.internal", ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test HasRecord with wildcard match
|
// Test HasRecord with wildcard match
|
||||||
if !store.HasRecord("host.autoco.internal.", RecordTypeA) {
|
if !store.HasRecord("host.autoco.internal.", RecordTypeA) {
|
||||||
t.Error("Expected HasRecord to return true for wildcard match")
|
t.Error("Expected HasRecord to return true for wildcard match")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test HasRecord with non-match
|
// Test HasRecord with non-match
|
||||||
if store.HasRecord("autoco.internal.", RecordTypeA) {
|
if store.HasRecord("autoco.internal.", RecordTypeA) {
|
||||||
t.Error("Expected HasRecord to return false for base domain")
|
t.Error("Expected HasRecord to return false for base domain")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
16
dns/override/dns_override_android.go
Normal file
16
dns/override/dns_override_android.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package olm
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
|
// SetupDNSOverride is a no-op on Android
|
||||||
|
// Android handles DNS through the VpnService API at the Java/Kotlin layer
|
||||||
|
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreDNSOverride is a no-op on Android
|
||||||
|
func RestoreDNSOverride() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/olm/dns"
|
|
||||||
platform "github.com/fosrl/olm/dns/platform"
|
platform "github.com/fosrl/olm/dns/platform"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
|
|||||||
|
|
||||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
|
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
|
||||||
// Uses scutil for DNS configuration
|
// Uses scutil for DNS configuration
|
||||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||||
if dnsProxy == nil {
|
|
||||||
return fmt.Errorf("DNS proxy is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
configurator, err = platform.NewDarwinDNSConfigurator()
|
configurator, err = platform.NewDarwinDNSConfigurator()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
|||||||
|
|
||||||
// Set new DNS servers to point to our proxy
|
// Set new DNS servers to point to our proxy
|
||||||
newDNS := []netip.Addr{
|
newDNS := []netip.Addr{
|
||||||
dnsProxy.GetProxyIP(),
|
proxyIp,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||||
|
|||||||
15
dns/override/dns_override_ios.go
Normal file
15
dns/override/dns_override_ios.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package olm
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
|
// SetupDNSOverride is a no-op on iOS as DNS configuration is handled by the system
|
||||||
|
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreDNSOverride is a no-op on iOS as DNS configuration is handled by the system
|
||||||
|
func RestoreDNSOverride() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/olm/dns"
|
|
||||||
platform "github.com/fosrl/olm/dns/platform"
|
platform "github.com/fosrl/olm/dns/platform"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
|
|||||||
|
|
||||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
|
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
|
||||||
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
|
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
|
||||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||||
if dnsProxy == nil {
|
|
||||||
return fmt.Errorf("DNS proxy is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
|
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
|
||||||
@@ -32,7 +27,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
|||||||
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
|
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logger.Info("Using systemd-resolved DNS configurator")
|
logger.Info("Using systemd-resolved DNS configurator")
|
||||||
return setDNS(dnsProxy, configurator)
|
return setDNS(proxyIp, configurator)
|
||||||
}
|
}
|
||||||
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
|
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
|
||||||
|
|
||||||
@@ -40,7 +35,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
|||||||
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
|
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logger.Info("Using NetworkManager DNS configurator")
|
logger.Info("Using NetworkManager DNS configurator")
|
||||||
return setDNS(dnsProxy, configurator)
|
return setDNS(proxyIp, configurator)
|
||||||
}
|
}
|
||||||
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
|
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
|
||||||
|
|
||||||
@@ -48,7 +43,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
|||||||
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
|
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logger.Info("Using resolvconf DNS configurator")
|
logger.Info("Using resolvconf DNS configurator")
|
||||||
return setDNS(dnsProxy, configurator)
|
return setDNS(proxyIp, configurator)
|
||||||
}
|
}
|
||||||
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
|
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
|
||||||
}
|
}
|
||||||
@@ -60,11 +55,11 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Using file-based DNS configurator")
|
logger.Info("Using file-based DNS configurator")
|
||||||
return setDNS(dnsProxy, configurator)
|
return setDNS(proxyIp, configurator)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setDNS is a helper function to set DNS and log the results
|
// setDNS is a helper function to set DNS and log the results
|
||||||
func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
|
func setDNS(proxyIp netip.Addr, conf platform.DNSConfigurator) error {
|
||||||
// Get current DNS servers before changing
|
// Get current DNS servers before changing
|
||||||
currentDNS, err := conf.GetCurrentDNS()
|
currentDNS, err := conf.GetCurrentDNS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -75,7 +70,7 @@ func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
|
|||||||
|
|
||||||
// Set new DNS servers to point to our proxy
|
// Set new DNS servers to point to our proxy
|
||||||
newDNS := []netip.Addr{
|
newDNS := []netip.Addr{
|
||||||
dnsProxy.GetProxyIP(),
|
proxyIp,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/olm/dns"
|
|
||||||
platform "github.com/fosrl/olm/dns/platform"
|
platform "github.com/fosrl/olm/dns/platform"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,11 +14,7 @@ var configurator platform.DNSConfigurator
|
|||||||
|
|
||||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
||||||
// Uses registry-based configuration (automatically extracts interface GUID)
|
// Uses registry-based configuration (automatically extracts interface GUID)
|
||||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
func SetupDNSOverride(interfaceName string, proxyIp netip.Addr) error {
|
||||||
if dnsProxy == nil {
|
|
||||||
return fmt.Errorf("DNS proxy is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
|
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -38,7 +33,7 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
|||||||
|
|
||||||
// Set new DNS servers to point to our proxy
|
// Set new DNS servers to point to our proxy
|
||||||
newDNS := []netip.Addr{
|
newDNS := []netip.Addr{
|
||||||
dnsProxy.GetProxyIP(),
|
proxyIp,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Setting DNS servers to: %v", newDNS)
|
logger.Info("Setting DNS servers to: %v", newDNS)
|
||||||
|
|||||||
@@ -416,4 +416,4 @@ func (d *DarwinDNSConfigurator) clearState() error {
|
|||||||
|
|
||||||
logger.Debug("Cleared DNS state file")
|
logger.Debug("Cleared DNS state file")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -4,7 +4,7 @@ go 1.25
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Microsoft/go-winio v0.6.2
|
github.com/Microsoft/go-winio v0.6.2
|
||||||
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01
|
github.com/fosrl/newt v1.8.0
|
||||||
github.com/godbus/dbus/v5 v5.2.0
|
github.com/godbus/dbus/v5 v5.2.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/miekg/dns v1.1.68
|
github.com/miekg/dns v1.1.68
|
||||||
@@ -30,3 +30,5 @@ require (
|
|||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
|
replace github.com/fosrl/newt => ../newt
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -1,7 +1,5 @@
|
|||||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 h1:VpuI42l4enih//6IFFQDln/B7WukfMePxIRIpXsNe/0=
|
|
||||||
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
|
|
||||||
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
|
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
|
||||||
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||||
|
|||||||
13
main.go
13
main.go
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/updates"
|
"github.com/fosrl/newt/updates"
|
||||||
"github.com/fosrl/olm/olm"
|
olmpkg "github.com/fosrl/olm/olm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new olm.Config struct and copy values from the main config
|
// Create a new olm.Config struct and copy values from the main config
|
||||||
olmConfig := olm.GlobalConfig{
|
olmConfig := olmpkg.OlmConfig{
|
||||||
LogLevel: config.LogLevel,
|
LogLevel: config.LogLevel,
|
||||||
EnableAPI: config.EnableAPI,
|
EnableAPI: config.EnableAPI,
|
||||||
HTTPAddr: config.HTTPAddr,
|
HTTPAddr: config.HTTPAddr,
|
||||||
@@ -219,15 +219,20 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
|
|||||||
Agent: "Olm CLI",
|
Agent: "Olm CLI",
|
||||||
OnExit: cancel, // Pass cancel function directly to trigger shutdown
|
OnExit: cancel, // Pass cancel function directly to trigger shutdown
|
||||||
OnTerminated: cancel,
|
OnTerminated: cancel,
|
||||||
|
PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
|
||||||
|
}
|
||||||
|
|
||||||
|
olm, err := olmpkg.Init(ctx, olmConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to initialize olm: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
olm.Init(ctx, olmConfig)
|
|
||||||
if err := olm.StartApi(); err != nil {
|
if err := olm.StartApi(); err != nil {
|
||||||
logger.Fatal("Failed to start API server: %v", err)
|
logger.Fatal("Failed to start API server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
|
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
|
||||||
tunnelConfig := olm.TunnelConfig{
|
tunnelConfig := olmpkg.TunnelConfig{
|
||||||
Endpoint: config.Endpoint,
|
Endpoint: config.Endpoint,
|
||||||
ID: config.ID,
|
ID: config.ID,
|
||||||
Secret: config.Secret,
|
Secret: config.Secret,
|
||||||
|
|||||||
223
olm/connect.go
Normal file
223
olm/connect.go
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/network"
|
||||||
|
olmDevice "github.com/fosrl/olm/device"
|
||||||
|
"github.com/fosrl/olm/dns"
|
||||||
|
dnsOverride "github.com/fosrl/olm/dns/override"
|
||||||
|
"github.com/fosrl/olm/peers"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
var wgData WgData
|
||||||
|
|
||||||
|
if o.connected {
|
||||||
|
logger.Info("Already connected. Ignoring new connection request.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.stopRegister != nil {
|
||||||
|
o.stopRegister()
|
||||||
|
o.stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.updateRegister != nil {
|
||||||
|
o.updateRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there is an existing tunnel then close it
|
||||||
|
if o.dev != nil {
|
||||||
|
logger.Info("Got new message. Closing existing tunnel!")
|
||||||
|
o.dev.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
o.tdev, err = func() (tun.Device, error) {
|
||||||
|
if o.tunnelConfig.FileDescriptorTun != 0 {
|
||||||
|
return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU)
|
||||||
|
}
|
||||||
|
ifName := o.tunnelConfig.InterfaceName
|
||||||
|
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
|
||||||
|
ifName, err = network.FindUnusedUTUN()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tun.CreateTUN(ifName, o.tunnelConfig.MTU)
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create TUN device: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if config.FileDescriptorTun == 0 {
|
||||||
|
if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
|
||||||
|
o.tunnelConfig.InterfaceName = realInterfaceName
|
||||||
|
}
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Wrap TUN device with packet filter for DNS proxy
|
||||||
|
o.middleDev = olmDevice.NewMiddleDevice(o.tdev)
|
||||||
|
|
||||||
|
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
||||||
|
// Use filtered device instead of raw TUN device
|
||||||
|
o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger))
|
||||||
|
|
||||||
|
if o.tunnelConfig.EnableUAPI {
|
||||||
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
|
if o.tunnelConfig.FileDescriptorUAPI != 0 {
|
||||||
|
fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
|
||||||
|
}
|
||||||
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
|
}
|
||||||
|
return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName)
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("UAPI listen error: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := o.uapiListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go o.dev.IpcHandle(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
logger.Info("UAPI listener started")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = o.dev.Up(); err != nil {
|
||||||
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract interface IP (strip CIDR notation if present)
|
||||||
|
interfaceIP := wgData.TunnelIP
|
||||||
|
if strings.Contains(interfaceIP, "/") {
|
||||||
|
interfaceIP = strings.Split(interfaceIP, "/")[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and start DNS proxy
|
||||||
|
o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create DNS proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil {
|
||||||
|
logger.Error("Failed to o.tunnelConfigure interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet
|
||||||
|
logger.Error("Failed to add route for utility subnet: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create peer manager with integrated peer monitoring
|
||||||
|
o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
|
||||||
|
Device: o.dev,
|
||||||
|
DNSProxy: o.dnsProxy,
|
||||||
|
InterfaceName: o.tunnelConfig.InterfaceName,
|
||||||
|
PrivateKey: o.privateKey,
|
||||||
|
MiddleDev: o.middleDev,
|
||||||
|
LocalIP: interfaceIP,
|
||||||
|
SharedBind: o.sharedBind,
|
||||||
|
WSClient: o.websocket,
|
||||||
|
APIServer: o.apiServer,
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := range wgData.Sites {
|
||||||
|
site := wgData.Sites[i]
|
||||||
|
var siteEndpoint string
|
||||||
|
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
|
||||||
|
if site.RelayEndpoint != "" {
|
||||||
|
siteEndpoint = site.RelayEndpoint
|
||||||
|
} else {
|
||||||
|
siteEndpoint = site.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
|
||||||
|
|
||||||
|
if err := o.peerManager.AddPeer(site); err != nil {
|
||||||
|
logger.Error("Failed to add peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.peerManager.Start()
|
||||||
|
|
||||||
|
if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime
|
||||||
|
logger.Error("Failed to start DNS proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.tunnelConfig.OverrideDNS {
|
||||||
|
// Set up DNS override to use our DNS proxy
|
||||||
|
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
|
||||||
|
logger.Error("Failed to setup DNS override: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()})
|
||||||
|
}
|
||||||
|
|
||||||
|
o.apiServer.SetRegistered(true)
|
||||||
|
|
||||||
|
o.connected = true
|
||||||
|
|
||||||
|
// Invoke onConnected callback if configured
|
||||||
|
if o.olmConfig.OnConnected != nil {
|
||||||
|
go o.olmConfig.OnConnected()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("WireGuard device created.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received terminate message")
|
||||||
|
o.apiServer.SetTerminated(true)
|
||||||
|
o.apiServer.SetConnectionStatus(false)
|
||||||
|
o.apiServer.SetRegistered(false)
|
||||||
|
o.apiServer.ClearPeerStatuses()
|
||||||
|
|
||||||
|
network.ClearNetworkSettings()
|
||||||
|
|
||||||
|
o.Close()
|
||||||
|
|
||||||
|
if o.olmConfig.OnTerminated != nil {
|
||||||
|
go o.olmConfig.OnTerminated()
|
||||||
|
}
|
||||||
|
}
|
||||||
344
olm/data.go
Normal file
344
olm/data.go
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/holepunch"
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/peers"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var addSubnetsData peers.PeerAdd
|
||||||
|
if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling add-remote-subnets data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists {
|
||||||
|
logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new subnets
|
||||||
|
for _, subnet := range addSubnetsData.RemoteSubnets {
|
||||||
|
if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil {
|
||||||
|
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new aliases
|
||||||
|
for _, alias := range addSubnetsData.Aliases {
|
||||||
|
if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil {
|
||||||
|
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var removeSubnetsData peers.RemovePeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling remove-remote-subnets data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists {
|
||||||
|
logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove subnets
|
||||||
|
for _, subnet := range removeSubnetsData.RemoteSubnets {
|
||||||
|
if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove aliases
|
||||||
|
for _, alias := range removeSubnetsData.Aliases {
|
||||||
|
if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil {
|
||||||
|
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var updateSubnetsData peers.UpdatePeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling update-remote-subnets data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists {
|
||||||
|
logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new subnets BEFORE removing old ones to preserve shared subnets
|
||||||
|
// This ensures that if an old and new subnet are the same on different peers,
|
||||||
|
// the route won't be temporarily removed
|
||||||
|
for _, subnet := range updateSubnetsData.NewRemoteSubnets {
|
||||||
|
if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
|
||||||
|
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old subnets after new ones are added
|
||||||
|
for _, subnet := range updateSubnetsData.OldRemoteSubnets {
|
||||||
|
if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new aliases BEFORE removing old ones to preserve shared IP addresses
|
||||||
|
// This ensures that if an old and new alias share the same IP, the IP won't be
|
||||||
|
// temporarily removed from the allowed IPs list
|
||||||
|
for _, alias := range updateSubnetsData.NewAliases {
|
||||||
|
if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil {
|
||||||
|
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old aliases after new ones are added
|
||||||
|
for _, alias := range updateSubnetsData.OldAliases {
|
||||||
|
if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil {
|
||||||
|
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling handshake data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var handshakeData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
ExitNode struct {
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
RelayPort uint16 `json:"relayPort"`
|
||||||
|
} `json:"exitNode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling handshake data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing peer from PeerManager
|
||||||
|
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
||||||
|
if exists {
|
||||||
|
logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayPort := handshakeData.ExitNode.RelayPort
|
||||||
|
if relayPort == 0 {
|
||||||
|
relayPort = 21820 // default relay port
|
||||||
|
}
|
||||||
|
|
||||||
|
siteId := handshakeData.SiteId
|
||||||
|
exitNode := holepunch.ExitNode{
|
||||||
|
Endpoint: handshakeData.ExitNode.Endpoint,
|
||||||
|
RelayPort: relayPort,
|
||||||
|
PublicKey: handshakeData.ExitNode.PublicKey,
|
||||||
|
SiteIds: []int{siteId},
|
||||||
|
}
|
||||||
|
|
||||||
|
added := o.holePunchManager.AddExitNode(exitNode)
|
||||||
|
if added {
|
||||||
|
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||||
|
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||||
|
|
||||||
|
// Send handshake acknowledgment back to server with retry
|
||||||
|
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
|
"siteId": handshakeData.SiteId,
|
||||||
|
}, 1*time.Second, 10)
|
||||||
|
|
||||||
|
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler for syncing peer configuration - reconciles expected state with actual state
|
||||||
|
func (o *Olm) handleSync(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received sync message: %v", msg.Data)
|
||||||
|
|
||||||
|
if !o.connected {
|
||||||
|
logger.Warn("Not connected, ignoring sync request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.peerManager == nil {
|
||||||
|
logger.Warn("Peer manager not initialized, ignoring sync request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling sync data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wgData WgData
|
||||||
|
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling sync data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a map of expected peers from the incoming data
|
||||||
|
expectedPeers := make(map[int]peers.SiteConfig)
|
||||||
|
for _, site := range wgData.Sites {
|
||||||
|
expectedPeers[site.SiteId] = site
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all current peers
|
||||||
|
currentPeers := o.peerManager.GetAllPeers()
|
||||||
|
currentPeerMap := make(map[int]peers.SiteConfig)
|
||||||
|
for _, peer := range currentPeers {
|
||||||
|
currentPeerMap[peer.SiteId] = peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find peers to remove (in current but not in expected)
|
||||||
|
for siteId := range currentPeerMap {
|
||||||
|
if _, exists := expectedPeers[siteId]; !exists {
|
||||||
|
logger.Info("Sync: Removing peer for site %d (no longer in expected config)", siteId)
|
||||||
|
if err := o.peerManager.RemovePeer(siteId); err != nil {
|
||||||
|
logger.Error("Sync: Failed to remove peer %d: %v", siteId, err)
|
||||||
|
} else {
|
||||||
|
// Remove any exit nodes associated with this peer from hole punching
|
||||||
|
if o.holePunchManager != nil {
|
||||||
|
removed := o.holePunchManager.RemoveExitNodesByPeer(siteId)
|
||||||
|
if removed > 0 {
|
||||||
|
logger.Info("Sync: Removed %d exit nodes associated with peer %d from hole punch rotation", removed, siteId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find peers to add (in expected but not in current) and peers to update
|
||||||
|
for siteId, expectedSite := range expectedPeers {
|
||||||
|
if _, exists := currentPeerMap[siteId]; !exists {
|
||||||
|
// New peer - add it using the add flow (with holepunch)
|
||||||
|
logger.Info("Sync: Adding new peer for site %d", siteId)
|
||||||
|
|
||||||
|
// Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||||
|
o.holePunchManager.TriggerHolePunch()
|
||||||
|
|
||||||
|
// TODO: do we need to send the message to the cloud to add the peer that way?
|
||||||
|
if err := o.peerManager.AddPeer(expectedSite); err != nil {
|
||||||
|
logger.Error("Sync: Failed to add peer %d: %v", siteId, err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Sync: Successfully added peer for site %d", siteId)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Existing peer - check if update is needed
|
||||||
|
currentSite := currentPeerMap[siteId]
|
||||||
|
needsUpdate := false
|
||||||
|
|
||||||
|
// Check if any fields have changed
|
||||||
|
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
|
||||||
|
needsUpdate = true
|
||||||
|
}
|
||||||
|
if expectedSite.RelayEndpoint != "" && expectedSite.RelayEndpoint != currentSite.RelayEndpoint {
|
||||||
|
needsUpdate = true
|
||||||
|
}
|
||||||
|
if expectedSite.PublicKey != "" && expectedSite.PublicKey != currentSite.PublicKey {
|
||||||
|
needsUpdate = true
|
||||||
|
}
|
||||||
|
if expectedSite.ServerIP != "" && expectedSite.ServerIP != currentSite.ServerIP {
|
||||||
|
needsUpdate = true
|
||||||
|
}
|
||||||
|
if expectedSite.ServerPort != 0 && expectedSite.ServerPort != currentSite.ServerPort {
|
||||||
|
needsUpdate = true
|
||||||
|
}
|
||||||
|
// Check remote subnets
|
||||||
|
if expectedSite.RemoteSubnets != nil && !slicesEqual(expectedSite.RemoteSubnets, currentSite.RemoteSubnets) {
|
||||||
|
needsUpdate = true
|
||||||
|
}
|
||||||
|
// Check aliases
|
||||||
|
if expectedSite.Aliases != nil && !aliasesEqual(expectedSite.Aliases, currentSite.Aliases) {
|
||||||
|
needsUpdate = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsUpdate {
|
||||||
|
logger.Info("Sync: Updating peer for site %d", siteId)
|
||||||
|
|
||||||
|
// Merge expected data with current data
|
||||||
|
siteConfig := currentSite
|
||||||
|
if expectedSite.Endpoint != "" {
|
||||||
|
siteConfig.Endpoint = expectedSite.Endpoint
|
||||||
|
}
|
||||||
|
if expectedSite.RelayEndpoint != "" {
|
||||||
|
siteConfig.RelayEndpoint = expectedSite.RelayEndpoint
|
||||||
|
}
|
||||||
|
if expectedSite.PublicKey != "" {
|
||||||
|
siteConfig.PublicKey = expectedSite.PublicKey
|
||||||
|
}
|
||||||
|
if expectedSite.ServerIP != "" {
|
||||||
|
siteConfig.ServerIP = expectedSite.ServerIP
|
||||||
|
}
|
||||||
|
if expectedSite.ServerPort != 0 {
|
||||||
|
siteConfig.ServerPort = expectedSite.ServerPort
|
||||||
|
}
|
||||||
|
if expectedSite.RemoteSubnets != nil {
|
||||||
|
siteConfig.RemoteSubnets = expectedSite.RemoteSubnets
|
||||||
|
}
|
||||||
|
if expectedSite.Aliases != nil {
|
||||||
|
siteConfig.Aliases = expectedSite.Aliases
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
|
||||||
|
logger.Error("Sync: Failed to update peer %d: %v", siteId, err)
|
||||||
|
} else {
|
||||||
|
// If the endpoint changed, trigger holepunch to refresh NAT mappings
|
||||||
|
if expectedSite.Endpoint != "" && expectedSite.Endpoint != currentSite.Endpoint {
|
||||||
|
logger.Info("Sync: Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", siteId)
|
||||||
|
o.holePunchManager.TriggerHolePunch()
|
||||||
|
o.holePunchManager.ResetServerHolepunchInterval()
|
||||||
|
}
|
||||||
|
logger.Info("Sync: Successfully updated peer for site %d", siteId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers))
|
||||||
|
}
|
||||||
1296
olm/olm.go
1296
olm/olm.go
File diff suppressed because it is too large
Load Diff
195
olm/peer.go
Normal file
195
olm/peer.go
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/util"
|
||||||
|
"github.com/fosrl/olm/peers"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
if o.stopPeerSend != nil {
|
||||||
|
o.stopPeerSend()
|
||||||
|
o.stopPeerSend = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var siteConfig peers.SiteConfig
|
||||||
|
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
|
||||||
|
logger.Error("Error unmarshaling add data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||||
|
|
||||||
|
if err := o.peerManager.AddPeer(siteConfig); err != nil {
|
||||||
|
logger.Error("Failed to add peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received remove-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var removeData peers.PeerRemove
|
||||||
|
if err := json.Unmarshal(jsonData, &removeData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling remove data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil {
|
||||||
|
logger.Error("Failed to remove peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove any exit nodes associated with this peer from hole punching
|
||||||
|
if o.holePunchManager != nil {
|
||||||
|
removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId)
|
||||||
|
if removed > 0 {
|
||||||
|
logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received update-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var updateData peers.SiteConfig
|
||||||
|
if err := json.Unmarshal(jsonData, &updateData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling update data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing peer from PeerManager
|
||||||
|
existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId)
|
||||||
|
if !exists {
|
||||||
|
logger.Warn("Peer with site ID %d not found", updateData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create updated site config by merging with existing data
|
||||||
|
siteConfig := existingPeer
|
||||||
|
|
||||||
|
if updateData.Endpoint != "" {
|
||||||
|
siteConfig.Endpoint = updateData.Endpoint
|
||||||
|
}
|
||||||
|
if updateData.RelayEndpoint != "" {
|
||||||
|
siteConfig.RelayEndpoint = updateData.RelayEndpoint
|
||||||
|
}
|
||||||
|
if updateData.PublicKey != "" {
|
||||||
|
siteConfig.PublicKey = updateData.PublicKey
|
||||||
|
}
|
||||||
|
if updateData.ServerIP != "" {
|
||||||
|
siteConfig.ServerIP = updateData.ServerIP
|
||||||
|
}
|
||||||
|
if updateData.ServerPort != 0 {
|
||||||
|
siteConfig.ServerPort = updateData.ServerPort
|
||||||
|
}
|
||||||
|
if updateData.RemoteSubnets != nil {
|
||||||
|
siteConfig.RemoteSubnets = updateData.RemoteSubnets
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
|
||||||
|
logger.Error("Failed to update peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the endpoint changed, trigger holepunch to refresh NAT mappings
|
||||||
|
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
|
||||||
|
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
|
||||||
|
_ = o.holePunchManager.TriggerHolePunch()
|
||||||
|
o.holePunchManager.ResetServerHolepunchInterval()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received relay-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if peerManager is still valid (may be nil during shutdown)
|
||||||
|
if o.peerManager == nil {
|
||||||
|
logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var relayData peers.RelayPeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling relay data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve primary relay endpoint: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update HTTP server to mark this peer as using relay
|
||||||
|
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
|
||||||
|
|
||||||
|
o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received unrelay-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if peerManager is still valid (may be nil during shutdown)
|
||||||
|
if o.peerManager == nil {
|
||||||
|
logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var relayData peers.UnRelayPeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling relay data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update HTTP server to mark this peer as using relay
|
||||||
|
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
|
||||||
|
|
||||||
|
o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
|
||||||
|
}
|
||||||
10
olm/types.go
10
olm/types.go
@@ -12,9 +12,10 @@ type WgData struct {
|
|||||||
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
|
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
|
||||||
}
|
}
|
||||||
|
|
||||||
type GlobalConfig struct {
|
type OlmConfig struct {
|
||||||
// Logging
|
// Logging
|
||||||
LogLevel string
|
LogLevel string
|
||||||
|
LogFilePath string
|
||||||
|
|
||||||
// HTTP server
|
// HTTP server
|
||||||
EnableAPI bool
|
EnableAPI bool
|
||||||
@@ -22,6 +23,11 @@ type GlobalConfig struct {
|
|||||||
SocketPath string
|
SocketPath string
|
||||||
Version string
|
Version string
|
||||||
Agent string
|
Agent string
|
||||||
|
|
||||||
|
WakeUpDebounce time.Duration
|
||||||
|
|
||||||
|
// Debugging
|
||||||
|
PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")
|
||||||
|
|
||||||
// Callbacks
|
// Callbacks
|
||||||
OnRegistered func()
|
OnRegistered func()
|
||||||
|
|||||||
51
olm/util.go
51
olm/util.go
@@ -1,60 +1,9 @@
|
|||||||
package olm
|
package olm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
|
||||||
"github.com/fosrl/newt/network"
|
|
||||||
"github.com/fosrl/olm/peers"
|
"github.com/fosrl/olm/peers"
|
||||||
"github.com/fosrl/olm/websocket"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func sendPing(olm *websocket.Client) error {
|
|
||||||
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
|
||||||
"timestamp": time.Now().Unix(),
|
|
||||||
"userToken": olm.GetConfig().UserToken,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send ping message: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logger.Debug("Sent ping message")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func keepSendingPing(olm *websocket.Client) {
|
|
||||||
// Send ping immediately on startup
|
|
||||||
if err := sendPing(olm); err != nil {
|
|
||||||
logger.Error("Failed to send initial ping: %v", err)
|
|
||||||
} else {
|
|
||||||
logger.Info("Sent initial ping message")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up ticker for one minute intervals
|
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stopPing:
|
|
||||||
logger.Info("Stopping ping messages")
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := sendPing(olm); err != nil {
|
|
||||||
logger.Error("Failed to send periodic ping: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetNetworkSettingsJSON() (string, error) {
|
|
||||||
return network.GetJSON()
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetNetworkSettingsIncrementor() int {
|
|
||||||
return network.GetIncrementor()
|
|
||||||
}
|
|
||||||
|
|
||||||
// slicesEqual compares two string slices for equality (order-independent)
|
// slicesEqual compares two string slices for equality (order-independent)
|
||||||
func slicesEqual(a, b []string) bool {
|
func slicesEqual(a, b []string) bool {
|
||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
|
|||||||
@@ -50,6 +50,8 @@ type PeerManager struct {
|
|||||||
// key is the CIDR string, value is a set of siteIds that want this IP
|
// key is the CIDR string, value is a set of siteIds that want this IP
|
||||||
allowedIPClaims map[string]map[int]bool
|
allowedIPClaims map[string]map[int]bool
|
||||||
APIServer *api.API
|
APIServer *api.API
|
||||||
|
|
||||||
|
PersistentKeepalive int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
|
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
|
||||||
@@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
|
|||||||
return peer, ok
|
return peer, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerMonitor returns the internal peer monitor instance
|
||||||
|
func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor {
|
||||||
|
pm.mu.RLock()
|
||||||
|
defer pm.mu.RUnlock()
|
||||||
|
return pm.peerMonitor
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *PeerManager) GetAllPeers() []SiteConfig {
|
func (pm *PeerManager) GetAllPeers() []SiteConfig {
|
||||||
pm.mu.RLock()
|
pm.mu.RLock()
|
||||||
defer pm.mu.RUnlock()
|
defer pm.mu.RUnlock()
|
||||||
@@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
|||||||
wgConfig := siteConfig
|
wgConfig := siteConfig
|
||||||
wgConfig.AllowedIps = ownedIPs
|
wgConfig.AllowedIps = ownedIPs
|
||||||
|
|
||||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
|
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once
|
||||||
|
// without recreating them. Returns a map of siteId to error for any peers that failed to update.
|
||||||
|
func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
|
||||||
|
pm.mu.RLock()
|
||||||
|
defer pm.mu.RUnlock()
|
||||||
|
|
||||||
|
pm.PersistentKeepalive = interval
|
||||||
|
|
||||||
|
errors := make(map[int]error)
|
||||||
|
|
||||||
|
for siteId, peer := range pm.peers {
|
||||||
|
err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
|
||||||
|
if err != nil {
|
||||||
|
errors[siteId] = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *PeerManager) RemovePeer(siteId int) error {
|
func (pm *PeerManager) RemovePeer(siteId int) error {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
@@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
|
|||||||
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||||
wgConfig := promotedPeer
|
wgConfig := promotedPeer
|
||||||
wgConfig.AllowedIps = ownedIPs
|
wgConfig.AllowedIps = ownedIPs
|
||||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
|
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
|
||||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
|||||||
wgConfig := siteConfig
|
wgConfig := siteConfig
|
||||||
wgConfig.AllowedIps = ownedIPs
|
wgConfig.AllowedIps = ownedIPs
|
||||||
|
|
||||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
|
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
|||||||
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||||
promotedWgConfig := promotedPeer
|
promotedWgConfig := promotedPeer
|
||||||
promotedWgConfig.AllowedIps = promotedOwnedIPs
|
promotedWgConfig.AllowedIps = promotedOwnedIPs
|
||||||
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
|
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
|
||||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,8 +31,7 @@ type PeerMonitor struct {
|
|||||||
monitors map[int]*Client
|
monitors map[int]*Client
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
running bool
|
running bool
|
||||||
interval time.Duration
|
timeout time.Duration
|
||||||
timeout time.Duration
|
|
||||||
maxAttempts int
|
maxAttempts int
|
||||||
wsClient *websocket.Client
|
wsClient *websocket.Client
|
||||||
|
|
||||||
@@ -42,7 +41,7 @@ type PeerMonitor struct {
|
|||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
ep *channel.Endpoint
|
ep *channel.Endpoint
|
||||||
activePorts map[uint16]bool
|
activePorts map[uint16]bool
|
||||||
portsLock sync.Mutex
|
portsLock sync.RWMutex
|
||||||
nsCtx context.Context
|
nsCtx context.Context
|
||||||
nsCancel context.CancelFunc
|
nsCancel context.CancelFunc
|
||||||
nsWg sync.WaitGroup
|
nsWg sync.WaitGroup
|
||||||
@@ -50,17 +49,26 @@ type PeerMonitor struct {
|
|||||||
// Holepunch testing fields
|
// Holepunch testing fields
|
||||||
sharedBind *bind.SharedBind
|
sharedBind *bind.SharedBind
|
||||||
holepunchTester *holepunch.HolepunchTester
|
holepunchTester *holepunch.HolepunchTester
|
||||||
holepunchInterval time.Duration
|
|
||||||
holepunchTimeout time.Duration
|
holepunchTimeout time.Duration
|
||||||
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
|
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
|
||||||
holepunchStatus map[int]bool // siteID -> connected status
|
holepunchStatus map[int]bool // siteID -> connected status
|
||||||
holepunchStopChan chan struct{}
|
holepunchStopChan chan struct{}
|
||||||
|
holepunchUpdateChan chan struct{}
|
||||||
|
|
||||||
// Relay tracking fields
|
// Relay tracking fields
|
||||||
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
|
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
|
||||||
holepunchMaxAttempts int // max consecutive failures before triggering relay
|
holepunchMaxAttempts int // max consecutive failures before triggering relay
|
||||||
holepunchFailures map[int]int // siteID -> consecutive failure count
|
holepunchFailures map[int]int // siteID -> consecutive failure count
|
||||||
|
|
||||||
|
// Exponential backoff fields for holepunch monitor
|
||||||
|
defaultHolepunchMinInterval time.Duration // Minimum interval (initial)
|
||||||
|
defaultHolepunchMaxInterval time.Duration
|
||||||
|
holepunchMinInterval time.Duration // Minimum interval (initial)
|
||||||
|
holepunchMaxInterval time.Duration // Maximum interval (cap for backoff)
|
||||||
|
holepunchBackoffMultiplier float64 // Multiplier for each stable check
|
||||||
|
holepunchStableCount map[int]int // siteID -> consecutive stable status count
|
||||||
|
holepunchCurrentInterval time.Duration // Current interval with backoff applied
|
||||||
|
|
||||||
// Rapid initial test fields
|
// Rapid initial test fields
|
||||||
rapidTestInterval time.Duration // interval between rapid test attempts
|
rapidTestInterval time.Duration // interval between rapid test attempts
|
||||||
rapidTestTimeout time.Duration // timeout for each rapid test attempt
|
rapidTestTimeout time.Duration // timeout for each rapid test attempt
|
||||||
@@ -78,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
pm := &PeerMonitor{
|
pm := &PeerMonitor{
|
||||||
monitors: make(map[int]*Client),
|
monitors: make(map[int]*Client),
|
||||||
interval: 2 * time.Second, // Default check interval (faster)
|
|
||||||
timeout: 3 * time.Second,
|
timeout: 3 * time.Second,
|
||||||
maxAttempts: 3,
|
maxAttempts: 3,
|
||||||
wsClient: wsClient,
|
wsClient: wsClient,
|
||||||
@@ -88,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
|||||||
nsCtx: ctx,
|
nsCtx: ctx,
|
||||||
nsCancel: cancel,
|
nsCancel: cancel,
|
||||||
sharedBind: sharedBind,
|
sharedBind: sharedBind,
|
||||||
holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
|
|
||||||
holepunchTimeout: 2 * time.Second, // Faster timeout
|
holepunchTimeout: 2 * time.Second, // Faster timeout
|
||||||
holepunchEndpoints: make(map[int]string),
|
holepunchEndpoints: make(map[int]string),
|
||||||
holepunchStatus: make(map[int]bool),
|
holepunchStatus: make(map[int]bool),
|
||||||
@@ -101,6 +107,15 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
|||||||
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
|
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
|
||||||
apiServer: apiServer,
|
apiServer: apiServer,
|
||||||
wgConnectionStatus: make(map[int]bool),
|
wgConnectionStatus: make(map[int]bool),
|
||||||
|
// Exponential backoff settings for holepunch monitor
|
||||||
|
defaultHolepunchMinInterval: 2 * time.Second,
|
||||||
|
defaultHolepunchMaxInterval: 30 * time.Second,
|
||||||
|
holepunchMinInterval: 2 * time.Second,
|
||||||
|
holepunchMaxInterval: 30 * time.Second,
|
||||||
|
holepunchBackoffMultiplier: 1.5,
|
||||||
|
holepunchStableCount: make(map[int]int),
|
||||||
|
holepunchCurrentInterval: 2 * time.Second,
|
||||||
|
holepunchUpdateChan: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pm.initNetstack(); err != nil {
|
if err := pm.initNetstack(); err != nil {
|
||||||
@@ -116,41 +131,75 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetInterval changes how frequently peers are checked
|
// SetInterval changes how frequently peers are checked
|
||||||
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
|
func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
|
||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
defer pm.mutex.Unlock()
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
pm.interval = interval
|
|
||||||
|
|
||||||
// Update interval for all existing monitors
|
// Update interval for all existing monitors
|
||||||
for _, client := range pm.monitors {
|
for _, client := range pm.monitors {
|
||||||
client.SetPacketInterval(interval)
|
client.SetPacketInterval(minInterval, maxInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTimeout changes the timeout for waiting for responses
|
func (pm *PeerMonitor) ResetPeerInterval() {
|
||||||
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
|
|
||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
defer pm.mutex.Unlock()
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
pm.timeout = timeout
|
// Update interval for all existing monitors
|
||||||
|
|
||||||
// Update timeout for all existing monitors
|
|
||||||
for _, client := range pm.monitors {
|
for _, client := range pm.monitors {
|
||||||
client.SetTimeout(timeout)
|
client.ResetPacketInterval()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
|
||||||
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
|
func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
pm.holepunchMinInterval = minInterval
|
||||||
|
pm.holepunchMaxInterval = maxInterval
|
||||||
|
// Reset current interval to the new minimum
|
||||||
|
pm.holepunchCurrentInterval = minInterval
|
||||||
|
updateChan := pm.holepunchUpdateChan
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
|
||||||
|
|
||||||
|
// Signal the goroutine to apply the new interval if running
|
||||||
|
if updateChan != nil {
|
||||||
|
select {
|
||||||
|
case updateChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// Channel full or closed, skip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
|
||||||
|
func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
|
||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
defer pm.mutex.Unlock()
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
pm.maxAttempts = attempts
|
return pm.holepunchMinInterval, pm.holepunchMaxInterval
|
||||||
|
}
|
||||||
|
|
||||||
// Update max attempts for all existing monitors
|
func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
|
||||||
for _, client := range pm.monitors {
|
pm.mutex.Lock()
|
||||||
client.SetMaxAttempts(attempts)
|
pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
|
||||||
|
pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
|
||||||
|
pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
|
||||||
|
updateChan := pm.holepunchUpdateChan
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
|
||||||
|
|
||||||
|
// Signal the goroutine to apply the new interval if running
|
||||||
|
if updateChan != nil {
|
||||||
|
select {
|
||||||
|
case updateChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// Channel full or closed, skip
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,10 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client.SetPacketInterval(pm.interval)
|
|
||||||
client.SetTimeout(pm.timeout)
|
|
||||||
client.SetMaxAttempts(pm.maxAttempts)
|
|
||||||
|
|
||||||
pm.monitors[siteID] = client
|
pm.monitors[siteID] = client
|
||||||
|
|
||||||
pm.holepunchEndpoints[siteID] = holepunchEndpoint
|
pm.holepunchEndpoints[siteID] = holepunchEndpoint
|
||||||
@@ -470,31 +515,59 @@ func (pm *PeerMonitor) stopHolepunchMonitor() {
|
|||||||
logger.Info("Stopped holepunch connection monitor")
|
logger.Info("Stopped holepunch connection monitor")
|
||||||
}
|
}
|
||||||
|
|
||||||
// runHolepunchMonitor runs the holepunch monitoring loop
|
// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff
|
||||||
func (pm *PeerMonitor) runHolepunchMonitor() {
|
func (pm *PeerMonitor) runHolepunchMonitor() {
|
||||||
ticker := time.NewTicker(pm.holepunchInterval)
|
pm.mutex.Lock()
|
||||||
defer ticker.Stop()
|
pm.holepunchCurrentInterval = pm.holepunchMinInterval
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
// Do initial check immediately
|
timer := time.NewTimer(0) // Fire immediately for initial check
|
||||||
pm.checkHolepunchEndpoints()
|
defer timer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-pm.holepunchStopChan:
|
case <-pm.holepunchStopChan:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-pm.holepunchUpdateChan:
|
||||||
pm.checkHolepunchEndpoints()
|
// Interval settings changed, reset to minimum
|
||||||
|
pm.mutex.Lock()
|
||||||
|
pm.holepunchCurrentInterval = pm.holepunchMinInterval
|
||||||
|
currentInterval := pm.holepunchCurrentInterval
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
timer.Reset(currentInterval)
|
||||||
|
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
|
||||||
|
case <-timer.C:
|
||||||
|
anyStatusChanged := pm.checkHolepunchEndpoints()
|
||||||
|
|
||||||
|
pm.mutex.Lock()
|
||||||
|
if anyStatusChanged {
|
||||||
|
// Reset to minimum interval on any status change
|
||||||
|
pm.holepunchCurrentInterval = pm.holepunchMinInterval
|
||||||
|
} else {
|
||||||
|
// Apply exponential backoff when stable
|
||||||
|
newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier)
|
||||||
|
if newInterval > pm.holepunchMaxInterval {
|
||||||
|
newInterval = pm.holepunchMaxInterval
|
||||||
|
}
|
||||||
|
pm.holepunchCurrentInterval = newInterval
|
||||||
|
}
|
||||||
|
currentInterval := pm.holepunchCurrentInterval
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
timer.Reset(currentInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkHolepunchEndpoints tests all holepunch endpoints
|
// checkHolepunchEndpoints tests all holepunch endpoints
|
||||||
func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
// Returns true if any endpoint's status changed
|
||||||
|
func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
|
||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
// Check if we're still running before doing any work
|
// Check if we're still running before doing any work
|
||||||
if !pm.running {
|
if !pm.running {
|
||||||
pm.mutex.Unlock()
|
pm.mutex.Unlock()
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
|
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
|
||||||
for siteID, endpoint := range pm.holepunchEndpoints {
|
for siteID, endpoint := range pm.holepunchEndpoints {
|
||||||
@@ -504,8 +577,10 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
|||||||
maxAttempts := pm.holepunchMaxAttempts
|
maxAttempts := pm.holepunchMaxAttempts
|
||||||
pm.mutex.Unlock()
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
anyStatusChanged := false
|
||||||
|
|
||||||
for siteID, endpoint := range endpoints {
|
for siteID, endpoint := range endpoints {
|
||||||
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
|
logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
|
||||||
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
|
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
|
||||||
|
|
||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
@@ -529,7 +604,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
|||||||
pm.mutex.Unlock()
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
// Log status changes
|
// Log status changes
|
||||||
if !exists || previousStatus != result.Success {
|
statusChanged := !exists || previousStatus != result.Success
|
||||||
|
if statusChanged {
|
||||||
|
anyStatusChanged = true
|
||||||
if result.Success {
|
if result.Success {
|
||||||
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
|
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
|
||||||
} else {
|
} else {
|
||||||
@@ -562,7 +639,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
|||||||
pm.mutex.Unlock()
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
if !stillRunning {
|
if !stillRunning {
|
||||||
return // Stop processing if shutdown is in progress
|
return anyStatusChanged // Stop processing if shutdown is in progress
|
||||||
}
|
}
|
||||||
|
|
||||||
if !result.Success && !isRelayed && failureCount >= maxAttempts {
|
if !result.Success && !isRelayed && failureCount >= maxAttempts {
|
||||||
@@ -579,6 +656,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return anyStatusChanged
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHolepunchStatus returns the current holepunch status for all endpoints
|
// GetHolepunchStatus returns the current holepunch status for all endpoints
|
||||||
@@ -650,55 +729,55 @@ func (pm *PeerMonitor) Close() {
|
|||||||
logger.Debug("PeerMonitor: Cleanup complete")
|
logger.Debug("PeerMonitor: Cleanup complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestPeer tests connectivity to a specific peer
|
// // TestPeer tests connectivity to a specific peer
|
||||||
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||||
pm.mutex.Lock()
|
// pm.mutex.Lock()
|
||||||
client, exists := pm.monitors[siteID]
|
// client, exists := pm.monitors[siteID]
|
||||||
pm.mutex.Unlock()
|
// pm.mutex.Unlock()
|
||||||
|
|
||||||
if !exists {
|
// if !exists {
|
||||||
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||||
}
|
// }
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||||
defer cancel()
|
// defer cancel()
|
||||||
|
|
||||||
connected, rtt := client.TestConnection(ctx)
|
// connected, rtt := client.TestPeerConnection(ctx)
|
||||||
return connected, rtt, nil
|
// return connected, rtt, nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
// TestAllPeers tests connectivity to all peers
|
// // TestAllPeers tests connectivity to all peers
|
||||||
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
// func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||||
Connected bool
|
// Connected bool
|
||||||
RTT time.Duration
|
// RTT time.Duration
|
||||||
} {
|
// } {
|
||||||
pm.mutex.Lock()
|
// pm.mutex.Lock()
|
||||||
peers := make(map[int]*Client, len(pm.monitors))
|
// peers := make(map[int]*Client, len(pm.monitors))
|
||||||
for siteID, client := range pm.monitors {
|
// for siteID, client := range pm.monitors {
|
||||||
peers[siteID] = client
|
// peers[siteID] = client
|
||||||
}
|
// }
|
||||||
pm.mutex.Unlock()
|
// pm.mutex.Unlock()
|
||||||
|
|
||||||
results := make(map[int]struct {
|
// results := make(map[int]struct {
|
||||||
Connected bool
|
// Connected bool
|
||||||
RTT time.Duration
|
// RTT time.Duration
|
||||||
})
|
// })
|
||||||
for siteID, client := range peers {
|
// for siteID, client := range peers {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||||
connected, rtt := client.TestConnection(ctx)
|
// connected, rtt := client.TestPeerConnection(ctx)
|
||||||
cancel()
|
// cancel()
|
||||||
|
|
||||||
results[siteID] = struct {
|
// results[siteID] = struct {
|
||||||
Connected bool
|
// Connected bool
|
||||||
RTT time.Duration
|
// RTT time.Duration
|
||||||
}{
|
// }{
|
||||||
Connected: connected,
|
// Connected: connected,
|
||||||
RTT: rtt,
|
// RTT: rtt,
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
return results
|
// return results
|
||||||
}
|
// }
|
||||||
|
|
||||||
// initNetstack initializes the gvisor netstack
|
// initNetstack initializes the gvisor netstack
|
||||||
func (pm *PeerMonitor) initNetstack() error {
|
func (pm *PeerMonitor) initNetstack() error {
|
||||||
@@ -770,9 +849,9 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if we are listening on this port
|
// Check if we are listening on this port
|
||||||
pm.portsLock.Lock()
|
pm.portsLock.RLock()
|
||||||
active := pm.activePorts[uint16(port)]
|
active := pm.activePorts[uint16(port)]
|
||||||
pm.portsLock.Unlock()
|
pm.portsLock.RUnlock()
|
||||||
|
|
||||||
if !active {
|
if !active {
|
||||||
return false
|
return false
|
||||||
@@ -803,13 +882,12 @@ func (pm *PeerMonitor) runPacketSender() {
|
|||||||
defer pm.nsWg.Done()
|
defer pm.nsWg.Done()
|
||||||
logger.Debug("PeerMonitor: Packet sender goroutine started")
|
logger.Debug("PeerMonitor: Packet sender goroutine started")
|
||||||
|
|
||||||
// Use a ticker to periodically check for packets without blocking indefinitely
|
|
||||||
ticker := time.NewTicker(10 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
// Use blocking ReadContext instead of polling - much more CPU efficient
|
||||||
case <-pm.nsCtx.Done():
|
// This will block until a packet is available or context is cancelled
|
||||||
|
pkt := pm.ep.ReadContext(pm.nsCtx)
|
||||||
|
if pkt == nil {
|
||||||
|
// Context was cancelled or endpoint closed
|
||||||
logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
|
logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
|
||||||
// Drain any remaining packets before exiting
|
// Drain any remaining packets before exiting
|
||||||
for {
|
for {
|
||||||
@@ -821,36 +899,28 @@ func (pm *PeerMonitor) runPacketSender() {
|
|||||||
}
|
}
|
||||||
logger.Debug("PeerMonitor: Packet sender goroutine exiting")
|
logger.Debug("PeerMonitor: Packet sender goroutine exiting")
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
|
||||||
// Try to read packets in batches
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
pkt := pm.ep.Read()
|
|
||||||
if pkt == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract packet data
|
|
||||||
slices := pkt.AsSlices()
|
|
||||||
if len(slices) > 0 {
|
|
||||||
var totalSize int
|
|
||||||
for _, slice := range slices {
|
|
||||||
totalSize += len(slice)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, totalSize)
|
|
||||||
pos := 0
|
|
||||||
for _, slice := range slices {
|
|
||||||
copy(buf[pos:], slice)
|
|
||||||
pos += len(slice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inject into MiddleDevice (outbound to WG)
|
|
||||||
pm.middleDev.InjectOutbound(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
pkt.DecRef()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract packet data
|
||||||
|
slices := pkt.AsSlices()
|
||||||
|
if len(slices) > 0 {
|
||||||
|
var totalSize int
|
||||||
|
for _, slice := range slices {
|
||||||
|
totalSize += len(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, totalSize)
|
||||||
|
pos := 0
|
||||||
|
for _, slice := range slices {
|
||||||
|
copy(buf[pos:], slice)
|
||||||
|
pos += len(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject into MiddleDevice (outbound to WG)
|
||||||
|
pm.middleDev.InjectOutbound(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt.DecRef()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,10 +32,19 @@ type Client struct {
|
|||||||
monitorLock sync.Mutex
|
monitorLock sync.Mutex
|
||||||
connLock sync.Mutex // Protects connection operations
|
connLock sync.Mutex // Protects connection operations
|
||||||
shutdownCh chan struct{}
|
shutdownCh chan struct{}
|
||||||
|
updateCh chan struct{}
|
||||||
packetInterval time.Duration
|
packetInterval time.Duration
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
maxAttempts int
|
maxAttempts int
|
||||||
dialer Dialer
|
dialer Dialer
|
||||||
|
|
||||||
|
// Exponential backoff fields
|
||||||
|
defaultMinInterval time.Duration // Default minimum interval (initial)
|
||||||
|
defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
|
||||||
|
minInterval time.Duration // Minimum interval (initial)
|
||||||
|
maxInterval time.Duration // Maximum interval (cap for backoff)
|
||||||
|
backoffMultiplier float64 // Multiplier for each stable check
|
||||||
|
stableCountToBackoff int // Number of stable checks before backing off
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dialer is a function that creates a connection
|
// Dialer is a function that creates a connection
|
||||||
@@ -50,28 +59,59 @@ type ConnectionStatus struct {
|
|||||||
// NewClient creates a new connection test client
|
// NewClient creates a new connection test client
|
||||||
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
||||||
return &Client{
|
return &Client{
|
||||||
serverAddr: serverAddr,
|
serverAddr: serverAddr,
|
||||||
shutdownCh: make(chan struct{}),
|
shutdownCh: make(chan struct{}),
|
||||||
packetInterval: 2 * time.Second,
|
updateCh: make(chan struct{}, 1),
|
||||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
packetInterval: 2 * time.Second,
|
||||||
maxAttempts: 3, // Default max attempts
|
defaultMinInterval: 2 * time.Second,
|
||||||
dialer: dialer,
|
defaultMaxInterval: 30 * time.Second,
|
||||||
|
minInterval: 2 * time.Second,
|
||||||
|
maxInterval: 30 * time.Second,
|
||||||
|
backoffMultiplier: 1.5,
|
||||||
|
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
|
||||||
|
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||||
|
maxAttempts: 3, // Default max attempts
|
||||||
|
dialer: dialer,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPacketInterval changes how frequently packets are sent in monitor mode
|
// SetPacketInterval changes how frequently packets are sent in monitor mode
|
||||||
func (c *Client) SetPacketInterval(interval time.Duration) {
|
func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
|
||||||
c.packetInterval = interval
|
c.monitorLock.Lock()
|
||||||
|
c.packetInterval = minInterval
|
||||||
|
c.minInterval = minInterval
|
||||||
|
c.maxInterval = maxInterval
|
||||||
|
updateCh := c.updateCh
|
||||||
|
monitorRunning := c.monitorRunning
|
||||||
|
c.monitorLock.Unlock()
|
||||||
|
|
||||||
|
// Signal the goroutine to apply the new interval if running
|
||||||
|
if monitorRunning && updateCh != nil {
|
||||||
|
select {
|
||||||
|
case updateCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// Channel full or closed, skip
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTimeout changes the timeout for waiting for responses
|
func (c *Client) ResetPacketInterval() {
|
||||||
func (c *Client) SetTimeout(timeout time.Duration) {
|
c.monitorLock.Lock()
|
||||||
c.timeout = timeout
|
c.packetInterval = c.defaultMinInterval
|
||||||
}
|
c.minInterval = c.defaultMinInterval
|
||||||
|
c.maxInterval = c.defaultMaxInterval
|
||||||
|
updateCh := c.updateCh
|
||||||
|
monitorRunning := c.monitorRunning
|
||||||
|
c.monitorLock.Unlock()
|
||||||
|
|
||||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
// Signal the goroutine to apply the new interval if running
|
||||||
func (c *Client) SetMaxAttempts(attempts int) {
|
if monitorRunning && updateCh != nil {
|
||||||
c.maxAttempts = attempts
|
select {
|
||||||
|
case updateCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// Channel full or closed, skip
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateServerAddr updates the server address and resets the connection
|
// UpdateServerAddr updates the server address and resets the connection
|
||||||
@@ -125,9 +165,10 @@ func (c *Client) ensureConnection() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestConnection checks if the connection to the server is working
|
// TestPeerConnection checks if the connection to the server is working
|
||||||
// Returns true if connected, false otherwise
|
// Returns true if connected, false otherwise
|
||||||
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
|
||||||
|
logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
|
||||||
if err := c.ensureConnection(); err != nil {
|
if err := c.ensureConnection(); err != nil {
|
||||||
logger.Warn("Failed to ensure connection: %v", err)
|
logger.Warn("Failed to ensure connection: %v", err)
|
||||||
return false, 0
|
return false, 0
|
||||||
@@ -138,6 +179,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
|||||||
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
|
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
|
||||||
packet[4] = packetTypeRequest
|
packet[4] = packetTypeRequest
|
||||||
|
|
||||||
|
// Reusable response buffer
|
||||||
|
responseBuffer := make([]byte, packetSize)
|
||||||
|
|
||||||
// Send multiple attempts as specified
|
// Send multiple attempts as specified
|
||||||
for attempt := 0; attempt < c.maxAttempts; attempt++ {
|
for attempt := 0; attempt < c.maxAttempts; attempt++ {
|
||||||
select {
|
select {
|
||||||
@@ -157,20 +201,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
|||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
|
|
||||||
_, err := c.conn.Write(packet)
|
_, err := c.conn.Write(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.connLock.Unlock()
|
c.connLock.Unlock()
|
||||||
logger.Info("Error sending packet: %v", err)
|
logger.Info("Error sending packet: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// logger.Debug("Successfully sent monitor packet")
|
|
||||||
|
|
||||||
// Set read deadline
|
// Set read deadline
|
||||||
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||||
|
|
||||||
// Wait for response
|
// Wait for response
|
||||||
responseBuffer := make([]byte, packetSize)
|
|
||||||
n, err := c.conn.Read(responseBuffer)
|
n, err := c.conn.Read(responseBuffer)
|
||||||
c.connLock.Unlock()
|
c.connLock.Unlock()
|
||||||
|
|
||||||
@@ -211,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
|||||||
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
|
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
return c.TestConnection(ctx)
|
return c.TestPeerConnection(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MonitorCallback is the function type for connection status change callbacks
|
// MonitorCallback is the function type for connection status change callbacks
|
||||||
@@ -238,28 +279,61 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
|
|||||||
go func() {
|
go func() {
|
||||||
var lastConnected bool
|
var lastConnected bool
|
||||||
firstRun := true
|
firstRun := true
|
||||||
|
stableCount := 0
|
||||||
|
currentInterval := c.minInterval
|
||||||
|
|
||||||
ticker := time.NewTicker(c.packetInterval)
|
timer := time.NewTimer(currentInterval)
|
||||||
defer ticker.Stop()
|
defer timer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.shutdownCh:
|
case <-c.shutdownCh:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-c.updateCh:
|
||||||
|
// Interval settings changed, reset to minimum
|
||||||
|
c.monitorLock.Lock()
|
||||||
|
currentInterval = c.minInterval
|
||||||
|
c.monitorLock.Unlock()
|
||||||
|
|
||||||
|
// Reset backoff state
|
||||||
|
stableCount = 0
|
||||||
|
|
||||||
|
timer.Reset(currentInterval)
|
||||||
|
logger.Debug("Packet interval updated, reset to %v", currentInterval)
|
||||||
|
case <-timer.C:
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||||
connected, rtt := c.TestConnection(ctx)
|
connected, rtt := c.TestPeerConnection(ctx)
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
|
statusChanged := connected != lastConnected
|
||||||
|
|
||||||
// Callback if status changed or it's the first check
|
// Callback if status changed or it's the first check
|
||||||
if connected != lastConnected || firstRun {
|
if statusChanged || firstRun {
|
||||||
callback(ConnectionStatus{
|
callback(ConnectionStatus{
|
||||||
Connected: connected,
|
Connected: connected,
|
||||||
RTT: rtt,
|
RTT: rtt,
|
||||||
})
|
})
|
||||||
lastConnected = connected
|
lastConnected = connected
|
||||||
firstRun = false
|
firstRun = false
|
||||||
|
// Reset backoff on status change
|
||||||
|
stableCount = 0
|
||||||
|
currentInterval = c.minInterval
|
||||||
|
} else {
|
||||||
|
// Status is stable, increment counter
|
||||||
|
stableCount++
|
||||||
|
|
||||||
|
// Apply exponential backoff after stable threshold
|
||||||
|
if stableCount >= c.stableCountToBackoff {
|
||||||
|
newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier)
|
||||||
|
if newInterval > c.maxInterval {
|
||||||
|
newInterval = c.maxInterval
|
||||||
|
}
|
||||||
|
currentInterval = newInterval
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset timer with current interval
|
||||||
|
timer.Reset(currentInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
|
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
|
||||||
var endpoint string
|
var endpoint string
|
||||||
if relay && siteConfig.RelayEndpoint != "" {
|
if relay && siteConfig.RelayEndpoint != "" {
|
||||||
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
|
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
|
||||||
@@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
|
|||||||
}
|
}
|
||||||
|
|
||||||
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||||
configBuilder.WriteString("persistent_keepalive_interval=5\n")
|
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive))
|
||||||
|
|
||||||
config := configBuilder.String()
|
config := configBuilder.String()
|
||||||
logger.Debug("Configuring peer with config: %s", config)
|
logger.Debug("Configuring peer with config: %s", config)
|
||||||
@@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it
|
||||||
|
func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error {
|
||||||
|
var configBuilder strings.Builder
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||||
|
configBuilder.WriteString("update_only=true\n")
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval))
|
||||||
|
|
||||||
|
config := configBuilder.String()
|
||||||
|
logger.Debug("Updating persistent keepalive for peer with config: %s", config)
|
||||||
|
|
||||||
|
err := dev.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func formatEndpoint(endpoint string) string {
|
func formatEndpoint(endpoint string) string {
|
||||||
if strings.Contains(endpoint, ":") {
|
if strings.Contains(endpoint, ":") {
|
||||||
return endpoint
|
return endpoint
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ type Client struct {
|
|||||||
handlersMux sync.RWMutex
|
handlersMux sync.RWMutex
|
||||||
reconnectInterval time.Duration
|
reconnectInterval time.Duration
|
||||||
isConnected bool
|
isConnected bool
|
||||||
|
isDisconnected bool // Flag to track if client is intentionally disconnected
|
||||||
reconnectMux sync.RWMutex
|
reconnectMux sync.RWMutex
|
||||||
pingInterval time.Duration
|
pingInterval time.Duration
|
||||||
pingTimeout time.Duration
|
pingTimeout time.Duration
|
||||||
@@ -91,6 +92,10 @@ type Client struct {
|
|||||||
configNeedsSave bool // Flag to track if config needs to be saved
|
configNeedsSave bool // Flag to track if config needs to be saved
|
||||||
configVersion int // Latest config version received from server
|
configVersion int // Latest config version received from server
|
||||||
configVersionMux sync.RWMutex
|
configVersionMux sync.RWMutex
|
||||||
|
token string // Cached authentication token
|
||||||
|
exitNodes []ExitNode // Cached exit nodes from token response
|
||||||
|
tokenMux sync.RWMutex // Protects token and exitNodes
|
||||||
|
forceNewToken bool // Flag to force fetching a new token on next connection
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOption func(*Client)
|
type ClientOption func(*Client)
|
||||||
@@ -177,6 +182,9 @@ func (c *Client) GetConfig() *Config {
|
|||||||
|
|
||||||
// Connect establishes the WebSocket connection
|
// Connect establishes the WebSocket connection
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
|
if c.isDisconnected {
|
||||||
|
c.isDisconnected = false
|
||||||
|
}
|
||||||
go c.connectWithRetry()
|
go c.connectWithRetry()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -209,9 +217,25 @@ func (c *Client) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
|
||||||
|
func (c *Client) Disconnect() error {
|
||||||
|
c.isDisconnected = true
|
||||||
|
c.setConnected(false)
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
c.writeMux.Lock()
|
||||||
|
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||||
|
c.writeMux.Unlock()
|
||||||
|
err := c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SendMessage sends a message through the WebSocket connection
|
// SendMessage sends a message through the WebSocket connection
|
||||||
func (c *Client) SendMessage(messageType string, data interface{}) error {
|
func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||||
if c.conn == nil {
|
if c.isDisconnected || c.conn == nil {
|
||||||
return fmt.Errorf("not connected")
|
return fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,14 +244,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
|||||||
Data: data,
|
Data: data,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Sending message: %s, data: %+v", messageType, data)
|
logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
|
||||||
|
|
||||||
c.writeMux.Lock()
|
c.writeMux.Lock()
|
||||||
defer c.writeMux.Unlock()
|
defer c.writeMux.Unlock()
|
||||||
return c.conn.WriteJSON(msg)
|
return c.conn.WriteJSON(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
|
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
|
||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
updateChan := make(chan interface{})
|
updateChan := make(chan interface{})
|
||||||
var dataMux sync.Mutex
|
var dataMux sync.Mutex
|
||||||
@@ -235,30 +259,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
count := 0
|
count := 0
|
||||||
maxAttempts := 10
|
|
||||||
|
|
||||||
err := c.SendMessage(messageType, currentData) // Send immediately
|
send := func() {
|
||||||
if err != nil {
|
if c.isDisconnected || c.conn == nil {
|
||||||
logger.Error("Failed to send initial message: %v", err)
|
return
|
||||||
|
}
|
||||||
|
err := c.SendMessage(messageType, currentData)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("websocket: Failed to send message: %v", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
}
|
}
|
||||||
count++
|
|
||||||
|
send() // Send immediately
|
||||||
|
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if count >= maxAttempts {
|
if maxAttempts != -1 && count >= maxAttempts {
|
||||||
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataMux.Lock()
|
dataMux.Lock()
|
||||||
err = c.SendMessage(messageType, currentData)
|
send()
|
||||||
dataMux.Unlock()
|
dataMux.Unlock()
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send message: %v", err)
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
case newData := <-updateChan:
|
case newData := <-updateChan:
|
||||||
dataMux.Lock()
|
dataMux.Lock()
|
||||||
// Merge newData into currentData if both are maps
|
// Merge newData into currentData if both are maps
|
||||||
@@ -281,6 +307,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
|||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Suspend sending if disconnected
|
||||||
|
for c.isDisconnected {
|
||||||
|
select {
|
||||||
|
case <-stopChan:
|
||||||
|
return
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return func() {
|
return func() {
|
||||||
@@ -327,7 +361,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
|||||||
tlsConfig = &tls.Config{}
|
tlsConfig = &tls.Config{}
|
||||||
}
|
}
|
||||||
tlsConfig.InsecureSkipVerify = true
|
tlsConfig.InsecureSkipVerify = true
|
||||||
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenData := map[string]interface{}{
|
tokenData := map[string]interface{}{
|
||||||
@@ -356,7 +390,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
|||||||
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
||||||
|
|
||||||
// print out the request for debugging
|
// print out the request for debugging
|
||||||
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
|
logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
|
||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
@@ -373,7 +407,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
|||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
|
|
||||||
// Return AuthError for 401/403 status codes
|
// Return AuthError for 401/403 status codes
|
||||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
@@ -389,7 +423,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
|||||||
|
|
||||||
var tokenResp TokenResponse
|
var tokenResp TokenResponse
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||||
logger.Error("Failed to decode token response.")
|
logger.Error("websocket: Failed to decode token response.")
|
||||||
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
|
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -401,7 +435,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
|||||||
return "", nil, fmt.Errorf("received empty token from server")
|
return "", nil, fmt.Errorf("received empty token from server")
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
|
||||||
|
|
||||||
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
|
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
|
||||||
}
|
}
|
||||||
@@ -427,7 +461,7 @@ func (c *Client) connectWithRetry() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// For other errors (5xx, network issues), continue retrying
|
// For other errors (5xx, network issues), continue retrying
|
||||||
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||||
time.Sleep(c.reconnectInterval)
|
time.Sleep(c.reconnectInterval)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -437,15 +471,25 @@ func (c *Client) connectWithRetry() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) establishConnection() error {
|
func (c *Client) establishConnection() error {
|
||||||
// Get token for authentication
|
// Get token for authentication - reuse cached token unless forced to get new one
|
||||||
token, exitNodes, err := c.getToken()
|
c.tokenMux.Lock()
|
||||||
if err != nil {
|
needNewToken := c.token == "" || c.forceNewToken
|
||||||
return fmt.Errorf("failed to get token: %w", err)
|
if needNewToken {
|
||||||
}
|
token, exitNodes, err := c.getToken()
|
||||||
|
if err != nil {
|
||||||
|
c.tokenMux.Unlock()
|
||||||
|
return fmt.Errorf("failed to get token: %w", err)
|
||||||
|
}
|
||||||
|
c.token = token
|
||||||
|
c.exitNodes = exitNodes
|
||||||
|
c.forceNewToken = false
|
||||||
|
|
||||||
if c.onTokenUpdate != nil {
|
if c.onTokenUpdate != nil {
|
||||||
c.onTokenUpdate(token, exitNodes)
|
c.onTokenUpdate(token, exitNodes)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
token := c.token
|
||||||
|
c.tokenMux.Unlock()
|
||||||
|
|
||||||
// Parse the base URL to determine protocol and hostname
|
// Parse the base URL to determine protocol and hostname
|
||||||
baseURL, err := url.Parse(c.baseURL)
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
@@ -480,7 +524,7 @@ func (c *Client) establishConnection() error {
|
|||||||
|
|
||||||
// Use new TLS configuration method
|
// Use new TLS configuration method
|
||||||
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||||
logger.Info("Setting up TLS configuration for WebSocket connection")
|
logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
|
||||||
tlsConfig, err := c.setupTLS()
|
tlsConfig, err := c.setupTLS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup TLS configuration: %w", err)
|
return fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||||
@@ -494,11 +538,23 @@ func (c *Client) establishConnection() error {
|
|||||||
dialer.TLSClientConfig = &tls.Config{}
|
dialer.TLSClientConfig = &tls.Config{}
|
||||||
}
|
}
|
||||||
dialer.TLSClientConfig.InsecureSkipVerify = true
|
dialer.TLSClientConfig.InsecureSkipVerify = true
|
||||||
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, _, err := dialer.Dial(u.String(), nil)
|
conn, resp, err := dialer.Dial(u.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Check if this is an unauthorized error (401)
|
||||||
|
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
|
||||||
|
logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
|
||||||
|
// Force getting a new token on next reconnect attempt
|
||||||
|
c.tokenMux.Lock()
|
||||||
|
c.forceNewToken = true
|
||||||
|
c.tokenMux.Unlock()
|
||||||
|
return &AuthError{
|
||||||
|
StatusCode: http.StatusUnauthorized,
|
||||||
|
Message: "WebSocket connection unauthorized",
|
||||||
|
}
|
||||||
|
}
|
||||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -512,7 +568,7 @@ func (c *Client) establishConnection() error {
|
|||||||
|
|
||||||
if c.onConnect != nil {
|
if c.onConnect != nil {
|
||||||
if err := c.onConnect(); err != nil {
|
if err := c.onConnect(); err != nil {
|
||||||
logger.Error("OnConnect callback failed: %v", err)
|
logger.Error("websocket: OnConnect callback failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -525,9 +581,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
|
|||||||
|
|
||||||
// Handle new separate certificate configuration
|
// Handle new separate certificate configuration
|
||||||
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
|
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
|
||||||
logger.Info("Loading separate certificate files for mTLS")
|
logger.Info("websocket: Loading separate certificate files for mTLS")
|
||||||
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
|
logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
|
||||||
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
|
logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
|
||||||
|
|
||||||
// Load client certificate and key
|
// Load client certificate and key
|
||||||
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
|
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
|
||||||
@@ -538,7 +594,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
|
|||||||
|
|
||||||
// Load CA certificates for remote validation if specified
|
// Load CA certificates for remote validation if specified
|
||||||
if len(c.tlsConfig.CAFiles) > 0 {
|
if len(c.tlsConfig.CAFiles) > 0 {
|
||||||
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
||||||
caCertPool := x509.NewCertPool()
|
caCertPool := x509.NewCertPool()
|
||||||
for _, caFile := range c.tlsConfig.CAFiles {
|
for _, caFile := range c.tlsConfig.CAFiles {
|
||||||
caCert, err := os.ReadFile(caFile)
|
caCert, err := os.ReadFile(caFile)
|
||||||
@@ -564,13 +620,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
|
|||||||
|
|
||||||
// Fallback to existing PKCS12 implementation for backward compatibility
|
// Fallback to existing PKCS12 implementation for backward compatibility
|
||||||
if c.tlsConfig.PKCS12File != "" {
|
if c.tlsConfig.PKCS12File != "" {
|
||||||
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
|
logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
|
||||||
return c.setupPKCS12TLS()
|
return c.setupPKCS12TLS()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Legacy fallback using config.TlsClientCert
|
// Legacy fallback using config.TlsClientCert
|
||||||
if c.config.TlsClientCert != "" {
|
if c.config.TlsClientCert != "" {
|
||||||
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
||||||
return loadClientCertificate(c.config.TlsClientCert)
|
return loadClientCertificate(c.config.TlsClientCert)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -592,7 +648,7 @@ func (c *Client) pingMonitor() {
|
|||||||
case <-c.done:
|
case <-c.done:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if c.conn == nil {
|
if c.isDisconnected || c.conn == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Send application-level ping with config version
|
// Send application-level ping with config version
|
||||||
@@ -616,7 +672,7 @@ func (c *Client) pingMonitor() {
|
|||||||
// Expected during shutdown
|
// Expected during shutdown
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
logger.Error("Ping failed: %v", err)
|
logger.Error("websocket: Ping failed: %v", err)
|
||||||
c.reconnect()
|
c.reconnect()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -665,18 +721,24 @@ func (c *Client) readPumpWithDisconnectDetection() {
|
|||||||
var msg WSMessage
|
var msg WSMessage
|
||||||
err := c.conn.ReadJSON(&msg)
|
err := c.conn.ReadJSON(&msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Check if we're shutting down before logging error
|
// Check if we're shutting down or explicitly disconnected before logging error
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
// Expected during shutdown, don't log as error
|
// Expected during shutdown, don't log as error
|
||||||
logger.Debug("WebSocket connection closed during shutdown")
|
logger.Debug("websocket: connection closed during shutdown")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
|
// Check if explicitly disconnected
|
||||||
|
if c.isDisconnected {
|
||||||
|
logger.Debug("websocket: connection closed: client was explicitly disconnected")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Unexpected error during normal operation
|
// Unexpected error during normal operation
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
||||||
logger.Error("WebSocket read error: %v", err)
|
logger.Error("websocket: read error: %v", err)
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("WebSocket connection closed: %v", err)
|
logger.Debug("websocket: connection closed: %v", err)
|
||||||
}
|
}
|
||||||
return // triggers reconnect via defer
|
return // triggers reconnect via defer
|
||||||
}
|
}
|
||||||
@@ -703,6 +765,12 @@ func (c *Client) reconnect() {
|
|||||||
c.conn = nil
|
c.conn = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Don't reconnect if explicitly disconnected
|
||||||
|
if c.isDisconnected {
|
||||||
|
logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Only reconnect if we're not shutting down
|
// Only reconnect if we're not shutting down
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
@@ -720,7 +788,7 @@ func (c *Client) setConnected(status bool) {
|
|||||||
|
|
||||||
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
|
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
|
||||||
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||||
logger.Info("Loading tls-client-cert %s", p12Path)
|
logger.Info("websocket: Loading tls-client-cert %s", p12Path)
|
||||||
// Read the PKCS12 file
|
// Read the PKCS12 file
|
||||||
p12Data, err := os.ReadFile(p12Path)
|
p12Data, err := os.ReadFile(p12Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user