Migrate peer monitor into peer manager

This commit is contained in:
Owen
2025-12-01 21:28:14 -05:00
parent 23e7b173c9
commit 29f0babf07
6 changed files with 154 additions and 126 deletions

View File

@@ -20,7 +20,6 @@ import (
olmDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
dnsOverride "github.com/fosrl/olm/dns/override"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/peers"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
@@ -32,7 +31,6 @@ var (
privateKey wgtypes.Key
connected bool
dev *device.Device
wgData WgData
uapiListener net.Listener
tdev tun.Device
middleDev *olmDevice.MiddleDevice
@@ -43,7 +41,6 @@ var (
tunnelRunning bool
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
peerMonitor *peermonitor.PeerMonitor
globalConfig GlobalConfig
tunnelConfig TunnelConfig
globalCtx context.Context
@@ -269,6 +266,8 @@ func StartTunnel(config TunnelConfig) {
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
var wgData WgData
if connected {
logger.Info("Already connected. Ignoring new connection request.")
return
@@ -398,17 +397,28 @@ func StartTunnel(config TunnelConfig) {
wsClientForMonitor = olm
}
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
// Create peer manager with integrated peer monitoring
peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
Device: dev,
DNSProxy: dnsProxy,
InterfaceName: interfaceName,
PrivateKey: privateKey,
MiddleDev: middleDev,
LocalIP: interfaceIP,
SharedBind: sharedBind,
WSClient: wsClientForMonitor,
StatusCallback: func(siteID int, connected bool, rtt time.Duration) {
// Find the site config to get endpoint information
var endpoint string
var isRelay bool
for _, site := range wgData.Sites {
if site.SiteId == siteID {
endpoint = site.Endpoint
// TODO: We'll need to track relay status separately
// For now, assume not using relay unless we get relay data
isRelay = !config.Holepunch
if site.RelayEndpoint != "" {
endpoint = site.RelayEndpoint
} else {
endpoint = site.Endpoint
}
isRelay = site.RelayEndpoint != ""
break
}
}
@@ -419,43 +429,41 @@ func StartTunnel(config TunnelConfig) {
logger.Warn("Peer %d is disconnected", siteID)
}
},
wsClientForMonitor,
middleDev,
interfaceIP,
sharedBind, // Pass sharedBind for holepunch testing
)
peerManager = peers.NewPeerManager(dev, peerMonitor, dnsProxy, interfaceName, privateKey)
})
for i := range wgData.Sites {
site := wgData.Sites[i]
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
apiServer.UpdatePeerStatus(site.SiteId, false, 0, siteEndpoint, false)
if err := peerManager.AddPeer(site, endpoint); err != nil {
if err := peerManager.AddPeer(site, siteEndpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
// Add holepunch monitoring for this endpoint if holepunching is enabled
if config.Holepunch {
peerMonitor.AddHolepunchEndpoint(site.SiteId, site.Endpoint)
}
logger.Info("Configured peer %s", site.PublicKey)
}
peerMonitor.SetHolepunchStatusCallback(func(siteID int, endpoint string, connected bool, rtt time.Duration) {
peerManager.SetHolepunchStatusCallback(func(siteID int, endpoint string, connected bool, rtt time.Duration) {
// This callback is for additional handling if needed
// The PeerMonitor already logs status changes
logger.Info("+++++++++++++++++++++++++ holepunch monitor callback for site %d, endpoint %s, connected: %v, rtt: %v", siteID, endpoint, connected, rtt)
})
peerMonitor.Start()
peerManager.Start()
// Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil {
logger.Error("Failed to setup DNS override: %v", err)
return
if config.OverrideDNS {
// Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil {
logger.Error("Failed to setup DNS override: %v", err)
return
}
}
if err := dnsProxy.Start(); err != nil {
@@ -906,12 +914,8 @@ func Close() {
updateRegister = nil
}
if peerMonitor != nil {
peerMonitor.Close() // Close() also calls Stop() internally
peerMonitor = nil
}
if peerManager != nil {
peerManager.Close() // Close() also calls Stop() internally
peerManager = nil
}

View File

@@ -3,22 +3,50 @@ package peers
import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/fosrl/newt/bind"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
olmDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/peers/monitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// PeerStatusCallback is called when a peer's connection status changes.
// Its arguments are the site ID of the peer, the connected flag, and the
// round-trip time that accompanies the status report. The PeerManager
// forwards these values unchanged from its internal peer monitor.
type PeerStatusCallback func(siteID int, connected bool, rtt time.Duration)
// HolepunchStatusCallback is called when holepunch connection status changes.
// This is a type alias (declared with "=") for monitor.HolepunchStatusCallback,
// not a distinct type, so values of either name are interchangeable without
// conversion.
type HolepunchStatusCallback = monitor.HolepunchStatusCallback
// PeerManagerConfig contains the configuration for creating a PeerManager.
// NewPeerManager copies Device, DNSProxy, InterfaceName, PrivateKey and
// StatusCallback onto the manager, and passes WSClient, MiddleDev, LocalIP
// and SharedBind to the internal peer monitor it constructs.
type PeerManagerConfig struct {
// Device is the WireGuard device on which peers are configured and removed.
Device *device.Device
// DNSProxy receives the DNS records registered for peer aliases.
DNSProxy *dns.DNSProxy
// InterfaceName is stored on the manager as its interface name.
InterfaceName string
// PrivateKey is handed to ConfigurePeer whenever a peer is (re)configured.
PrivateKey wgtypes.Key
// For peer monitoring: the next three fields are forwarded as-is to the
// peer monitor constructor.
MiddleDev *olmDevice.MiddleDevice
LocalIP string
SharedBind *bind.SharedBind
// WSClient is optional - if nil, relay messages won't be sent
WSClient *websocket.Client
// StatusCallback is called when peer connection status changes.
// It may be nil, in which case status changes from the monitor are dropped.
StatusCallback PeerStatusCallback
}
type PeerManager struct {
mu sync.RWMutex
device *device.Device
peers map[int]SiteConfig
peerMonitor *peermonitor.PeerMonitor
peerMonitor *monitor.PeerMonitor
dnsProxy *dns.DNSProxy
interfaceName string
privateKey wgtypes.Key
@@ -28,19 +56,38 @@ type PeerManager struct {
// allowedIPClaims tracks all peers that claim each allowed IP
// key is the CIDR string, value is a set of siteIds that want this IP
allowedIPClaims map[string]map[int]bool
// statusCallback is called when peer connection status changes
statusCallback PeerStatusCallback
}
func NewPeerManager(dev *device.Device, monitor *peermonitor.PeerMonitor, dnsProxy *dns.DNSProxy, interfaceName string, privateKey wgtypes.Key) *PeerManager {
return &PeerManager{
device: dev,
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
func NewPeerManager(config PeerManagerConfig) *PeerManager {
pm := &PeerManager{
device: config.Device,
peers: make(map[int]SiteConfig),
peerMonitor: monitor,
dnsProxy: dnsProxy,
interfaceName: interfaceName,
privateKey: privateKey,
dnsProxy: config.DNSProxy,
interfaceName: config.InterfaceName,
privateKey: config.PrivateKey,
allowedIPOwners: make(map[string]int),
allowedIPClaims: make(map[string]map[int]bool),
statusCallback: config.StatusCallback,
}
// Create the peer monitor
pm.peerMonitor = monitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
// Call the external status callback if set
if pm.statusCallback != nil {
pm.statusCallback(siteID, connected, rtt)
}
},
config.WSClient,
config.MiddleDev,
config.LocalIP,
config.SharedBind,
)
return pm
}
func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
@@ -86,7 +133,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig, endpoint string) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint, pm.peerMonitor); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint); err != nil {
return err
}
@@ -104,6 +151,16 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig, endpoint string) error {
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
}
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
err := pm.peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer)
if err != nil {
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
} else {
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
}
pm.peers[siteConfig.SiteId] = siteConfig
return nil
}
@@ -117,7 +174,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
return fmt.Errorf("peer with site ID %d not found", siteId)
}
if err := RemovePeer(pm.device, siteId, peer.PublicKey, pm.peerMonitor); err != nil {
if err := RemovePeer(pm.device, siteId, peer.PublicKey); err != nil {
return err
}
@@ -167,12 +224,16 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
wgConfig := promotedPeer
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, promotedPeer.Endpoint, pm.peerMonitor); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, promotedPeer.Endpoint); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
}
// Stop monitoring this peer
pm.peerMonitor.RemovePeer(siteId)
logger.Info("Stopped monitoring for site %d", siteId)
delete(pm.peers, siteId)
return nil
}
@@ -188,7 +249,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
// If public key changed, remove old peer first
if siteConfig.PublicKey != oldPeer.PublicKey {
if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey, pm.peerMonitor); err != nil {
if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
}
}
@@ -237,7 +298,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint, pm.peerMonitor); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint); err != nil {
return err
}
@@ -247,7 +308,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
promotedWgConfig := promotedPeer
promotedWgConfig.AllowedIps = promotedOwnedIPs
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, promotedPeer.Endpoint, pm.peerMonitor); err != nil {
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, promotedPeer.Endpoint); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
@@ -399,7 +460,7 @@ func (pm *PeerManager) addAllowedIp(siteId int, ip string) error {
// Only update WireGuard if we own this IP
if pm.allowedIPOwners[ip] == siteId {
if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint, pm.peerMonitor); err != nil {
if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint); err != nil {
return err
}
}
@@ -439,14 +500,14 @@ func (pm *PeerManager) removeAllowedIp(siteId int, cidr string) error {
newOwner, promoted := pm.releaseAllowedIP(siteId, cidr)
// Update WireGuard for this peer (to remove the IP from its config)
if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint, pm.peerMonitor); err != nil {
if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint); err != nil {
return err
}
// If another peer was promoted to owner, update their WireGuard config
if promoted && newOwner >= 0 {
if newOwnerPeer, exists := pm.peers[newOwner]; exists {
if err := ConfigurePeer(pm.device, newOwnerPeer, pm.privateKey, newOwnerPeer.Endpoint, pm.peerMonitor); err != nil {
if err := ConfigurePeer(pm.device, newOwnerPeer, pm.privateKey, newOwnerPeer.Endpoint); err != nil {
logger.Error("Failed to promote peer %d for IP %s: %v", newOwner, cidr, err)
} else {
logger.Info("Promoted peer %d to owner of IP %s", newOwner, cidr)
@@ -626,3 +687,32 @@ endpoint=%s:21820`, peer.PublicKey, formattedEndpoint)
logger.Info("Adjusted peer %d to point to relay!\n", siteId)
}
// Start launches the manager's embedded peer monitor.
// It does nothing when the manager currently holds no monitor.
func (pm *PeerManager) Start() {
	mon := pm.peerMonitor
	if mon == nil {
		return
	}
	mon.Start()
}
// Stop halts the manager's embedded peer monitor while keeping the monitor
// attached to the manager. It does nothing when no monitor is present.
func (pm *PeerManager) Stop() {
	mon := pm.peerMonitor
	if mon == nil {
		return
	}
	mon.Stop()
}
// Close shuts down the embedded peer monitor and then drops the manager's
// reference to it. Calling Close on a manager that holds no monitor
// (including a second call to Close) does nothing.
func (pm *PeerManager) Close() {
	mon := pm.peerMonitor
	if mon == nil {
		return
	}
	mon.Close()
	// Clear the field only after the monitor has been closed.
	pm.peerMonitor = nil
}
// SetHolepunchStatusCallback registers callback on the embedded peer monitor
// as its holepunch status callback. The call is ignored when the manager
// holds no monitor.
func (pm *PeerManager) SetHolepunchStatusCallback(callback HolepunchStatusCallback) {
	mon := pm.peerMonitor
	if mon == nil {
		return
	}
	mon.SetHolepunchStatusCallback(callback)
}

View File

@@ -1,4 +1,4 @@
package peermonitor
package monitor
import (
"context"
@@ -31,19 +31,9 @@ type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
// HolepunchStatusCallback is called when holepunch connection status changes.
// Going by the parameter names, it receives the site ID, the endpoint the
// status refers to, the connected flag, and a round-trip time.
type HolepunchStatusCallback func(siteID int, endpoint string, connected bool, rtt time.Duration)
// WireGuardConfig holds the WireGuard configuration for a peer.
// The monitor keeps one of these per site ID in its configs map.
type WireGuardConfig struct {
SiteID int // ID of the site this configuration belongs to
PublicKey string // The peer's WireGuard public key
ServerIP string // The peer's server IP (the visible caller strips any "/prefix" suffix before storing it)
Endpoint string // The site's endpoint as given in its site config
PrimaryRelay string // The primary relay endpoint
}
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
type PeerMonitor struct {
monitors map[int]*Client
configs map[int]*WireGuardConfig
callback PeerMonitorCallback
mutex sync.Mutex
running bool
@@ -79,7 +69,6 @@ func NewPeerMonitor(callback PeerMonitorCallback, wsClient *websocket.Client, mi
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*Client),
configs: make(map[int]*WireGuardConfig),
callback: callback,
interval: 1 * time.Second, // Default check interval
timeout: 2500 * time.Millisecond,
@@ -149,7 +138,7 @@ func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
}
// AddPeer adds a new peer to monitor
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string) error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
@@ -168,7 +157,8 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardC
client.SetMaxAttempts(pm.maxAttempts)
pm.monitors[siteID] = client
pm.configs[siteID] = wgConfig
pm.holepunchEndpoints[siteID] = endpoint
pm.holepunchStatus[siteID] = false // Initially unknown/disconnected
if pm.running {
if err := client.StartMonitor(func(status ConnectionStatus) {
@@ -192,7 +182,6 @@ func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
client.StopMonitor()
client.Close()
delete(pm.monitors, siteID)
delete(pm.configs, siteID)
}
// RemovePeer stops monitoring a peer and removes it from the monitor
@@ -289,26 +278,6 @@ func (pm *PeerMonitor) SetHolepunchStatusCallback(callback HolepunchStatusCallba
pm.holepunchStatusCallback = callback
}
// AddHolepunchEndpoint adds an endpoint to monitor via holepunch magic packets.
// While holding pm.mutex it stores endpoint under siteID (replacing any
// existing entry for that site) and resets the site's holepunch status to
// false, then logs the addition.
func (pm *PeerMonitor) AddHolepunchEndpoint(siteID int, endpoint string) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.holepunchEndpoints[siteID] = endpoint
pm.holepunchStatus[siteID] = false // Initially unknown/disconnected
logger.Info("Added holepunch monitoring for site %d at %s", siteID, endpoint)
}
// RemoveHolepunchEndpoint removes an endpoint from holepunch monitoring.
// While holding pm.mutex it deletes both the endpoint and the status entry
// for siteID, then logs the removal. Deleting a site ID that is not present
// is harmless (map delete is a no-op for missing keys); the log line is
// still emitted in that case.
func (pm *PeerMonitor) RemoveHolepunchEndpoint(siteID int) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
delete(pm.holepunchEndpoints, siteID)
delete(pm.holepunchStatus, siteID)
logger.Info("Removed holepunch monitoring for site %d", siteID)
}
// startHolepunchMonitor starts the holepunch connection monitoring
// Note: This function assumes the mutex is already held by the caller (called from Start())
func (pm *PeerMonitor) startHolepunchMonitor() error {

View File

@@ -1,4 +1,4 @@
package peermonitor
package monitor
import (
"context"

View File

@@ -10,6 +10,7 @@ type PeerAction struct {
type SiteConfig struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint,omitempty"`
RelayEndpoint string `json:"relayEndpoint,omitempty"`
PublicKey string `json:"publicKey,omitempty"`
ServerIP string `json:"serverIP,omitempty"`
ServerPort uint16 `json:"serverPort,omitempty"`

View File

@@ -2,19 +2,16 @@ package peers
import (
"fmt"
"net"
"strconv"
"strings"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/peermonitor"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string, peerMonitor *peermonitor.PeerMonitor) error {
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
siteHost, err := util.ResolveDomain(formatEndpoint(siteConfig.Endpoint))
if err != nil {
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
@@ -68,38 +65,11 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
}
// Set up peer monitoring
if peerMonitor != nil {
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
logger.Debug("Resolving primary relay %s for peer", endpoint)
primaryRelay, err := util.ResolveDomain(endpoint) // Using global endpoint variable
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint for peer: %v", err)
}
wgConfig := &peermonitor.WireGuardConfig{
SiteID: siteConfig.SiteId,
PublicKey: util.FixKey(siteConfig.PublicKey),
ServerIP: strings.Split(siteConfig.ServerIP, "/")[0],
Endpoint: siteConfig.Endpoint,
PrimaryRelay: primaryRelay,
}
err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig)
if err != nil {
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
} else {
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
}
}
return nil
}
// RemovePeer removes a peer from the WireGuard device
func RemovePeer(dev *device.Device, siteId int, publicKey string, peerMonitor *peermonitor.PeerMonitor) error {
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
// Construct WireGuard config to remove the peer
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
@@ -113,12 +83,6 @@ func RemovePeer(dev *device.Device, siteId int, publicKey string, peerMonitor *p
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
}
// Stop monitoring this peer
if peerMonitor != nil {
peerMonitor.RemovePeer(siteId)
logger.Info("Stopped monitoring for site %d", siteId)
}
return nil
}