Merge pull request #57 from fosrl/dev

Add robust client connectivity support
This commit is contained in:
Owen Schwartz
2025-12-08 12:05:52 -05:00
committed by GitHub
42 changed files with 7627 additions and 2600 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,2 @@
olm
.DS_Store
bin/

View File

@@ -10,7 +10,7 @@ docker-build-release:
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push .
local:
CGO_ENABLED=0 go build -o olm
CGO_ENABLED=0 go build -o bin/olm
build:
docker build -t fosrl/olm:latest .

503
api/api.go Normal file
View File

@@ -0,0 +1,503 @@
package api
import (
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
)
// ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
MTU int `json:"mtu,omitempty"`
DNS string `json:"dns,omitempty"`
DNSProxyIP string `json:"dnsProxyIP,omitempty"`
UpstreamDNS []string `json:"upstreamDNS,omitempty"`
InterfaceName string `json:"interfaceName,omitempty"`
Holepunch bool `json:"holepunch,omitempty"`
TlsClientCert string `json:"tlsClientCert,omitempty"`
PingInterval string `json:"pingInterval,omitempty"`
PingTimeout string `json:"pingTimeout,omitempty"`
OrgID string `json:"orgId,omitempty"`
}
// SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct {
OrgID string `json:"orgId"`
}
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
Name string `json:"name"`
Connected bool `json:"connected"`
RTT time.Duration `json:"rtt"`
LastSeen time.Time `json:"lastSeen"`
Endpoint string `json:"endpoint,omitempty"`
IsRelay bool `json:"isRelay"`
PeerIP string `json:"peerAddress,omitempty"`
HolepunchConnected bool `json:"holepunchConnected"`
}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Connected bool `json:"connected"`
Registered bool `json:"registered"`
Terminated bool `json:"terminated"`
Version string `json:"version,omitempty"`
Agent string `json:"agent,omitempty"`
OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
}
// API represents the HTTP server and its state
type API struct {
addr string
socketPath string
listener net.Listener
server *http.Server
onConnect func(ConnectionRequest) error
onSwitchOrg func(SwitchOrgRequest) error
onDisconnect func() error
onExit func() error
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
isRegistered bool
isTerminated bool
version string
agent string
orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
func NewAPI(addr string) *API {
s := &API{
addr: addr,
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
func NewAPISocket(socketPath string) *API {
s := &API{
socketPath: socketPath,
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// SetHandlers sets the callback functions for handling API requests
func (s *API) SetHandlers(
onConnect func(ConnectionRequest) error,
onSwitchOrg func(SwitchOrgRequest) error,
onDisconnect func() error,
onExit func() error,
) {
s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg
s.onDisconnect = onDisconnect
s.onExit = onExit
}
// Start starts the HTTP server
func (s *API) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit)
mux.HandleFunc("/health", s.handleHealth)
s.server = &http.Server{
Handler: mux,
}
var err error
if s.socketPath != "" {
// Use platform-specific socket listener
s.listener, err = createSocketListener(s.socketPath)
if err != nil {
return fmt.Errorf("failed to create socket listener: %w", err)
}
logger.Info("Starting HTTP server on socket %s", s.socketPath)
} else {
// Use TCP listener
s.listener, err = net.Listen("tcp", s.addr)
if err != nil {
return fmt.Errorf("failed to create TCP listener: %w", err)
}
logger.Info("Starting HTTP server on %s", s.addr)
}
go func() {
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server error: %v", err)
}
}()
return nil
}
// Stop stops the HTTP server
func (s *API) Stop() error {
logger.Info("Stopping api server")
// Close the server first, which will also close the listener gracefully
if s.server != nil {
s.server.Close()
}
// Clean up socket file if using Unix socket
if s.socketPath != "" {
cleanupSocket(s.socketPath)
}
return nil
}
func (s *API) AddPeerStatus(siteID int, siteName string, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Name = siteName
status.Connected = connected
status.RTT = rtt
status.LastSeen = time.Now()
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Connected = connected
status.RTT = rtt
status.LastSeen = time.Now()
status.Endpoint = endpoint
status.IsRelay = isRelay
}
func (s *API) RemovePeerStatus(siteID int) { // remove the peer from the status map
s.statusMu.Lock()
defer s.statusMu.Unlock()
delete(s.peerStatuses, siteID)
}
// SetConnectionStatus sets the overall connection status
func (s *API) SetConnectionStatus(isConnected bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isConnected = isConnected
if isConnected {
s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
}
}
func (s *API) SetRegistered(registered bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isRegistered = registered
}
func (s *API) SetTerminated(terminated bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isTerminated = terminated
}
// ClearPeerStatuses clears all peer statuses
func (s *API) ClearPeerStatuses() {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.peerStatuses = make(map[int]*PeerStatus)
}
// SetVersion sets the olm version
func (s *API) SetVersion(version string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.version = version
}
// SetAgent sets the olm agent
func (s *API) SetAgent(agent string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.agent = agent
}
// SetOrgID sets the organization ID
func (s *API) SetOrgID(orgID string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.orgID = orgID
}
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// UpdatePeerHolepunchStatus updates the holepunch connection status of a peer
func (s *API) UpdatePeerHolepunchStatus(siteID int, holepunchConnected bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.HolepunchConnected = holepunchConnected
}
// handleConnect handles the /connect endpoint
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// if we are already connected, reject new connection requests
s.statusMu.RLock()
alreadyConnected := s.isConnected
s.statusMu.RUnlock()
if alreadyConnected {
http.Error(w, "Already connected to a server. Disconnect first before connecting again.", http.StatusConflict)
return
}
var req ConnectionRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
return
}
// Call the connect handler if set
if s.onConnect != nil {
if err := s.onConnect(req); err != nil {
http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError)
return
}
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
// handleStatus handles the /status endpoint
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
Terminated: s.isTerminated,
Version: s.version,
Agent: s.agent,
OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
NetworkSettings: network.GetSettings(),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// handleHealth handles the /health endpoint
func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "ok",
})
}
// handleExit handles the /exit endpoint
func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received exit request via API")
// Return a success response first
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "shutdown initiated",
})
// Call the exit handler after responding, in a goroutine with a small delay
// to ensure the response is fully sent before shutdown begins
if s.onExit != nil {
go func() {
time.Sleep(100 * time.Millisecond)
if err := s.onExit(); err != nil {
logger.Error("Exit handler failed: %v", err)
}
}()
}
}
// handleSwitchOrg handles the /switch-org endpoint
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req SwitchOrgRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.OrgID == "" {
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
return
}
logger.Info("Received org switch request to orgId: %s", req.OrgID)
// Call the switch org handler if set
if s.onSwitchOrg != nil {
if err := s.onSwitchOrg(req); err != nil {
http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError)
return
}
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "org switch request accepted",
})
}
// handleDisconnect handles the /disconnect endpoint
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// if we are already disconnected, reject new disconnect requests
s.statusMu.RLock()
alreadyDisconnected := !s.isConnected
s.statusMu.RUnlock()
if alreadyDisconnected {
http.Error(w, "Not currently connected to a server.", http.StatusConflict)
return
}
logger.Info("Received disconnect request via API")
// Call the disconnect handler if set
if s.onDisconnect != nil {
if err := s.onDisconnect(); err != nil {
http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError)
return
}
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}
func (s *API) GetStatus() StatusResponse {
return StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
Terminated: s.isTerminated,
Version: s.version,
Agent: s.agent,
OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
NetworkSettings: network.GetSettings(),
}
}

50
api/api_unix.go Normal file
View File

@@ -0,0 +1,50 @@
//go:build !windows
// +build !windows
package api
import (
"fmt"
"net"
"os"
"path/filepath"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Unix domain socket listener
func createSocketListener(socketPath string) (net.Listener, error) {
// Ensure the directory exists
dir := filepath.Dir(socketPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create socket directory: %w", err)
}
// Remove existing socket file if it exists
if err := os.RemoveAll(socketPath); err != nil {
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
}
// Set socket permissions to allow access
if err := os.Chmod(socketPath, 0666); err != nil {
listener.Close()
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
}
logger.Debug("Created Unix socket at %s", socketPath)
return listener, nil
}
// cleanupSocket removes the Unix socket file
func cleanupSocket(socketPath string) {
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
} else {
logger.Debug("Removed Unix socket at %s", socketPath)
}
}

41
api/api_windows.go Normal file
View File

@@ -0,0 +1,41 @@
//go:build windows
// +build windows
package api
import (
"fmt"
"net"
"github.com/Microsoft/go-winio"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Windows named pipe listener
func createSocketListener(pipePath string) (net.Listener, error) {
// Ensure the pipe path has the correct format
if pipePath[0] != '\\' {
pipePath = `\\.\pipe\` + pipePath
}
// Create a pipe configuration that allows everyone to write
config := &winio.PipeConfig{
// Set security descriptor to allow everyone full access
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
}
// Create a named pipe listener using go-winio with the configuration
listener, err := winio.ListenPipe(pipePath, config)
if err != nil {
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
}
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
return listener, nil
}
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
func cleanupSocket(pipePath string) {
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
}

1145
common.go

File diff suppressed because it is too large Load Diff

268
config.go
View File

@@ -8,35 +8,43 @@ import (
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
// OlmConfig holds all configuration options for the Olm client
type OlmConfig struct {
// Connection settings
Endpoint string `json:"endpoint"`
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
ID string `json:"id"`
Secret string `json:"secret"`
OrgID string `json:"org"`
UserToken string `json:"userToken"`
// Network settings
MTU int `json:"mtu"`
DNS string `json:"dns"`
InterfaceName string `json:"interface"`
MTU int `json:"mtu"`
DNS string `json:"dns"`
UpstreamDNS []string `json:"upstreamDNS"`
InterfaceName string `json:"interface"`
// Logging
LogLevel string `json:"logLevel"`
// HTTP server
EnableHTTP bool `json:"enableHttp"`
EnableAPI bool `json:"enableApi"`
HTTPAddr string `json:"httpAddr"`
SocketPath string `json:"socketPath"`
// Ping settings
PingInterval string `json:"pingInterval"`
PingTimeout string `json:"pingTimeout"`
// Advanced
Holepunch bool `json:"holepunch"`
TlsClientCert string `json:"tlsClientCert"`
DisableHolepunch bool `json:"disableHolepunch"`
TlsClientCert string `json:"tlsClientCert"`
OverrideDNS bool `json:"overrideDNS"`
DisableRelay bool `json:"disableRelay"`
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
// Parsed values (not in JSON)
PingIntervalDuration time.Duration `json:"-"`
@@ -44,6 +52,8 @@ type OlmConfig struct {
// Source tracking (not in JSON)
sources map[string]string `json:"-"`
Version string
}
// ConfigSource tracks where each config value came from
@@ -58,29 +68,45 @@ const (
// DefaultConfig returns a config with default values
func DefaultConfig() *OlmConfig {
// Set OS-specific socket path
var socketPath string
switch runtime.GOOS {
case "windows":
socketPath = "olm"
default: // darwin, linux, and others
socketPath = "/var/run/olm.sock"
}
config := &OlmConfig{
MTU: 1280,
DNS: "8.8.8.8",
LogLevel: "INFO",
InterfaceName: "olm",
EnableHTTP: false,
HTTPAddr: ":9452",
PingInterval: "3s",
PingTimeout: "5s",
Holepunch: false,
sources: make(map[string]string),
MTU: 1280,
DNS: "8.8.8.8",
UpstreamDNS: []string{"8.8.8.8:53"},
LogLevel: "INFO",
InterfaceName: "olm",
EnableAPI: false,
SocketPath: socketPath,
PingInterval: "3s",
PingTimeout: "5s",
DisableHolepunch: false,
// DoNotCreateNewClient: false,
sources: make(map[string]string),
}
// Track default sources
config.sources["mtu"] = string(SourceDefault)
config.sources["dns"] = string(SourceDefault)
config.sources["upstreamDNS"] = string(SourceDefault)
config.sources["logLevel"] = string(SourceDefault)
config.sources["interface"] = string(SourceDefault)
config.sources["enableHttp"] = string(SourceDefault)
config.sources["enableApi"] = string(SourceDefault)
config.sources["httpAddr"] = string(SourceDefault)
config.sources["socketPath"] = string(SourceDefault)
config.sources["pingInterval"] = string(SourceDefault)
config.sources["pingTimeout"] = string(SourceDefault)
config.sources["holepunch"] = string(SourceDefault)
config.sources["disableHolepunch"] = string(SourceDefault)
config.sources["overrideDNS"] = string(SourceDefault)
config.sources["disableRelay"] = string(SourceDefault)
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
return config
}
@@ -175,6 +201,14 @@ func loadConfigFromEnv(config *OlmConfig) {
config.Secret = val
config.sources["secret"] = string(SourceEnv)
}
if val := os.Getenv("ORG"); val != "" {
config.OrgID = val
config.sources["org"] = string(SourceEnv)
}
if val := os.Getenv("USER_TOKEN"); val != "" {
config.UserToken = val
config.sources["userToken"] = string(SourceEnv)
}
if val := os.Getenv("MTU"); val != "" {
if mtu, err := strconv.Atoi(val); err == nil {
config.MTU = mtu
@@ -187,6 +221,10 @@ func loadConfigFromEnv(config *OlmConfig) {
config.DNS = val
config.sources["dns"] = string(SourceEnv)
}
if val := os.Getenv("UPSTREAM_DNS"); val != "" {
config.UpstreamDNS = []string{val}
config.sources["upstreamDNS"] = string(SourceEnv)
}
if val := os.Getenv("LOG_LEVEL"); val != "" {
config.LogLevel = val
config.sources["logLevel"] = string(SourceEnv)
@@ -207,14 +245,30 @@ func loadConfigFromEnv(config *OlmConfig) {
config.PingTimeout = val
config.sources["pingTimeout"] = string(SourceEnv)
}
if val := os.Getenv("ENABLE_HTTP"); val == "true" {
config.EnableHTTP = true
config.sources["enableHttp"] = string(SourceEnv)
if val := os.Getenv("ENABLE_API"); val == "true" {
config.EnableAPI = true
config.sources["enableApi"] = string(SourceEnv)
}
if val := os.Getenv("HOLEPUNCH"); val == "true" {
config.Holepunch = true
config.sources["holepunch"] = string(SourceEnv)
if val := os.Getenv("SOCKET_PATH"); val != "" {
config.SocketPath = val
config.sources["socketPath"] = string(SourceEnv)
}
if val := os.Getenv("DISABLE_HOLEPUNCH"); val == "true" {
config.DisableHolepunch = true
config.sources["disableHolepunch"] = string(SourceEnv)
}
if val := os.Getenv("OVERRIDE_DNS"); val == "true" {
config.OverrideDNS = true
config.sources["overrideDNS"] = string(SourceEnv)
}
if val := os.Getenv("DISABLE_RELAY"); val == "true" {
config.DisableRelay = true
config.sources["disableRelay"] = string(SourceEnv)
}
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
// config.DoNotCreateNewClient = true
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
// }
}
// loadConfigFromCLI loads configuration from command-line arguments
@@ -223,33 +277,48 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
// Store original values to detect changes
origValues := map[string]interface{}{
"endpoint": config.Endpoint,
"id": config.ID,
"secret": config.Secret,
"mtu": config.MTU,
"dns": config.DNS,
"logLevel": config.LogLevel,
"interface": config.InterfaceName,
"httpAddr": config.HTTPAddr,
"pingInterval": config.PingInterval,
"pingTimeout": config.PingTimeout,
"enableHttp": config.EnableHTTP,
"holepunch": config.Holepunch,
"endpoint": config.Endpoint,
"id": config.ID,
"secret": config.Secret,
"org": config.OrgID,
"userToken": config.UserToken,
"mtu": config.MTU,
"dns": config.DNS,
"upstreamDNS": fmt.Sprintf("%v", config.UpstreamDNS),
"logLevel": config.LogLevel,
"interface": config.InterfaceName,
"httpAddr": config.HTTPAddr,
"socketPath": config.SocketPath,
"pingInterval": config.PingInterval,
"pingTimeout": config.PingTimeout,
"enableApi": config.EnableAPI,
"disableHolepunch": config.DisableHolepunch,
"overrideDNS": config.OverrideDNS,
"disableRelay": config.DisableRelay,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}
// Define flags
serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server")
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID")
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
var upstreamDNSFlag string
serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8:53)")
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
serviceFlags.BoolVar(&config.EnableHTTP, "enable-http", config.EnableHTTP, "Enable HTTP server for receiving connection requests")
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
version := serviceFlags.Bool("version", false, "Print the version")
showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit")
@@ -259,6 +328,16 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
return false, false, err
}
// Parse upstream DNS flag if provided
if upstreamDNSFlag != "" {
config.UpstreamDNS = []string{}
for _, dns := range splitComma(upstreamDNSFlag) {
if dns != "" {
config.UpstreamDNS = append(config.UpstreamDNS, dns)
}
}
}
// Track which values were changed by CLI args
if config.Endpoint != origValues["endpoint"].(string) {
config.sources["endpoint"] = string(SourceCLI)
@@ -269,12 +348,21 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.Secret != origValues["secret"].(string) {
config.sources["secret"] = string(SourceCLI)
}
if config.OrgID != origValues["org"].(string) {
config.sources["org"] = string(SourceCLI)
}
if config.UserToken != origValues["userToken"].(string) {
config.sources["userToken"] = string(SourceCLI)
}
if config.MTU != origValues["mtu"].(int) {
config.sources["mtu"] = string(SourceCLI)
}
if config.DNS != origValues["dns"].(string) {
config.sources["dns"] = string(SourceCLI)
}
if fmt.Sprintf("%v", config.UpstreamDNS) != origValues["upstreamDNS"].(string) {
config.sources["upstreamDNS"] = string(SourceCLI)
}
if config.LogLevel != origValues["logLevel"].(string) {
config.sources["logLevel"] = string(SourceCLI)
}
@@ -284,18 +372,30 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.HTTPAddr != origValues["httpAddr"].(string) {
config.sources["httpAddr"] = string(SourceCLI)
}
if config.SocketPath != origValues["socketPath"].(string) {
config.sources["socketPath"] = string(SourceCLI)
}
if config.PingInterval != origValues["pingInterval"].(string) {
config.sources["pingInterval"] = string(SourceCLI)
}
if config.PingTimeout != origValues["pingTimeout"].(string) {
config.sources["pingTimeout"] = string(SourceCLI)
}
if config.EnableHTTP != origValues["enableHttp"].(bool) {
config.sources["enableHttp"] = string(SourceCLI)
if config.EnableAPI != origValues["enableApi"].(bool) {
config.sources["enableApi"] = string(SourceCLI)
}
if config.Holepunch != origValues["holepunch"].(bool) {
config.sources["holepunch"] = string(SourceCLI)
if config.DisableHolepunch != origValues["disableHolepunch"].(bool) {
config.sources["disableHolepunch"] = string(SourceCLI)
}
if config.OverrideDNS != origValues["overrideDNS"].(bool) {
config.sources["overrideDNS"] = string(SourceCLI)
}
if config.DisableRelay != origValues["disableRelay"].(bool) {
config.sources["disableRelay"] = string(SourceCLI)
}
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
// }
return *version, *showConfig, nil
}
@@ -348,6 +448,14 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.Secret = src.Secret
dest.sources["secret"] = string(SourceFile)
}
if src.OrgID != "" {
dest.OrgID = src.OrgID
dest.sources["org"] = string(SourceFile)
}
if src.UserToken != "" {
dest.UserToken = src.UserToken
dest.sources["userToken"] = string(SourceFile)
}
if src.MTU != 0 && src.MTU != 1280 {
dest.MTU = src.MTU
dest.sources["mtu"] = string(SourceFile)
@@ -356,6 +464,10 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.DNS = src.DNS
dest.sources["dns"] = string(SourceFile)
}
if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8:53]" {
dest.UpstreamDNS = src.UpstreamDNS
dest.sources["upstreamDNS"] = string(SourceFile)
}
if src.LogLevel != "" && src.LogLevel != "INFO" {
dest.LogLevel = src.LogLevel
dest.sources["logLevel"] = string(SourceFile)
@@ -368,6 +480,14 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.HTTPAddr = src.HTTPAddr
dest.sources["httpAddr"] = string(SourceFile)
}
if src.SocketPath != "" {
// Check if it's not the default for any OS
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
if !isDefault {
dest.SocketPath = src.SocketPath
dest.sources["socketPath"] = string(SourceFile)
}
}
if src.PingInterval != "" && src.PingInterval != "3s" {
dest.PingInterval = src.PingInterval
dest.sources["pingInterval"] = string(SourceFile)
@@ -381,14 +501,26 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.sources["tlsClientCert"] = string(SourceFile)
}
// For booleans, we always take the source value if explicitly set
if src.EnableHTTP {
dest.EnableHTTP = src.EnableHTTP
dest.sources["enableHttp"] = string(SourceFile)
if src.EnableAPI {
dest.EnableAPI = src.EnableAPI
dest.sources["enableApi"] = string(SourceFile)
}
if src.Holepunch {
dest.Holepunch = src.Holepunch
dest.sources["holepunch"] = string(SourceFile)
if src.DisableHolepunch {
dest.DisableHolepunch = src.DisableHolepunch
dest.sources["disableHolepunch"] = string(SourceFile)
}
if src.OverrideDNS {
dest.OverrideDNS = src.OverrideDNS
dest.sources["overrideDNS"] = string(SourceFile)
}
if src.DisableRelay {
dest.DisableRelay = src.DisableRelay
dest.sources["disableRelay"] = string(SourceFile)
}
// if src.DoNotCreateNewClient {
// dest.DoNotCreateNewClient = src.DoNotCreateNewClient
// dest.sources["doNotCreateNewClient"] = string(SourceFile)
// }
}
// SaveConfig saves the current configuration to the config file
@@ -405,7 +537,7 @@ func SaveConfig(config *OlmConfig) error {
func (c *OlmConfig) ShowConfig() {
configPath := getOlmConfigPath()
fmt.Println("\n=== Olm Configuration ===\n")
fmt.Print("\n=== Olm Configuration ===\n\n")
fmt.Printf("Config File: %s\n", configPath)
// Check if config file exists
@@ -416,7 +548,7 @@ func (c *OlmConfig) ShowConfig() {
}
fmt.Println("\n--- Configuration Values ---")
fmt.Println("(Format: Setting = Value [source])\n")
fmt.Print("(Format: Setting = Value [source])\n\n")
// Helper to get source or default
getSource := func(key string) string {
@@ -445,21 +577,25 @@ func (c *OlmConfig) ShowConfig() {
fmt.Printf(" endpoint = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint"))
fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org"))
fmt.Printf(" user-token = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken"))
// Network settings
fmt.Println("\nNetwork:")
fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu"))
fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns"))
fmt.Printf(" upstream-dns = %v [%s]\n", c.UpstreamDNS, getSource("upstreamDNS"))
fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface"))
// Logging
fmt.Println("\nLogging:")
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
// HTTP server
fmt.Println("\nHTTP Server:")
fmt.Printf(" enable-http = %v [%s]\n", c.EnableHTTP, getSource("enableHttp"))
// API server
fmt.Println("\nAPI Server:")
fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
// Timing
fmt.Println("\nTiming:")
@@ -468,9 +604,12 @@ func (c *OlmConfig) ShowConfig() {
// Advanced
fmt.Println("\nAdvanced:")
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
fmt.Printf(" disable-holepunch = %v [%s]\n", c.DisableHolepunch, getSource("disableHolepunch"))
fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS"))
fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay"))
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
if c.TlsClientCert != "" {
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
}
// Source legend
@@ -482,3 +621,16 @@ func (c *OlmConfig) ShowConfig() {
fmt.Println("\nPriority: cli > environment > file > default")
fmt.Println()
}
// splitComma splits a comma-separated string into a slice of trimmed strings
func splitComma(s string) []string {
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
result = append(result, trimmed)
}
}
return result
}

43
create_test_creds.py Normal file
View File

@@ -0,0 +1,43 @@
import requests
def create_olm(base_url, user_token, olm_name, user_id):
url = f"{base_url}/api/v1/user/{user_id}/olm"
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"User-Agent": "pangolin-cli",
"X-CSRF-Token": "x-csrf-protection",
"Cookie": f"p_session_token={user_token}"
}
payload = {"name": olm_name}
response = requests.put(url, json=payload, headers=headers)
response.raise_for_status()
data = response.json()
print(f"Response Data: {data}")
def create_client(base_url, user_token, client_name):
url = f"{base_url}/api/v1/api/clients"
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"User-Agent": "pangolin-cli",
"X-CSRF-Token": "x-csrf-protection",
"Cookie": f"p_session_token={user_token}"
}
payload = {"name": client_name}
response = requests.post(url, json=payload, headers=headers)
response.raise_for_status()
data = response.json()
print(f"Response Data: {data}")
if __name__ == "__main__":
# Example usage
base_url = input("Enter base URL (e.g., http://localhost:3000): ")
user_token = input("Enter user token: ")
user_id = input("Enter user ID: ")
olm_name = input("Enter OLM name: ")
client_name = input("Enter client name: ")
create_olm(base_url, user_token, olm_name, user_id)
# client_id = create_client(base_url, user_token, client_name)

331
device/middle_device.go Normal file
View File

@@ -0,0 +1,331 @@
package device
import (
"net/netip"
"os"
"sync"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun"
)
// PacketHandler processes intercepted packets and returns true if packet should be dropped
type PacketHandler func(packet []byte) bool
// FilterRule defines a rule for packet filtering
type FilterRule struct {
DestIP netip.Addr
Handler PacketHandler
}
// MiddleDevice wraps a TUN device with packet filtering capabilities
type MiddleDevice struct {
tun.Device
rules []FilterRule
mutex sync.RWMutex
readCh chan readResult
injectCh chan []byte
closed chan struct{}
}
type readResult struct {
bufs [][]byte
sizes []int
offset int
n int
err error
}
// NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice {
d := &MiddleDevice{
Device: device,
rules: make([]FilterRule, 0),
readCh: make(chan readResult),
injectCh: make(chan []byte, 100),
closed: make(chan struct{}),
}
go d.pump()
return d
}
func (d *MiddleDevice) pump() {
const defaultOffset = 16
batchSize := d.Device.BatchSize()
logger.Debug("MiddleDevice: pump started")
for {
// Check closed first with priority
select {
case <-d.closed:
logger.Debug("MiddleDevice: pump exiting due to closed channel")
return
default:
}
// Allocate buffers for reading
// We allocate new buffers for each read to avoid race conditions
// since we pass them to the channel
bufs := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for i := range bufs {
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
}
n, err := d.Device.Read(bufs, sizes, defaultOffset)
// Check closed again after read returns
select {
case <-d.closed:
logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
return
default:
}
// Now try to send the result
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
case <-d.closed:
logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
return
}
if err != nil {
logger.Debug("MiddleDevice: pump exiting due to read error: %v", err)
return
}
}
}
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
func (d *MiddleDevice) InjectOutbound(packet []byte) {
select {
case d.injectCh <- packet:
case <-d.closed:
}
}
// AddRule adds a packet filtering rule
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.rules = append(d.rules, FilterRule{
DestIP: destIP,
Handler: handler,
})
}
// RemoveRule removes all rules for a given destination IP
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.mutex.Lock()
defer d.mutex.Unlock()
newRules := make([]FilterRule, 0, len(d.rules))
for _, rule := range d.rules {
if rule.DestIP != destIP {
newRules = append(newRules, rule)
}
}
d.rules = newRules
}
// Close stops the device
func (d *MiddleDevice) Close() error {
select {
case <-d.closed:
// Already closed
return nil
default:
logger.Debug("MiddleDevice: Closing, signaling closed channel")
close(d.closed)
}
logger.Debug("MiddleDevice: Closing underlying TUN device")
err := d.Device.Close()
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
return err
}
// extractDestIP extracts destination IP from packet (fast path)
func extractDestIP(packet []byte) (netip.Addr, bool) {
if len(packet) < 20 {
return netip.Addr{}, false
}
version := packet[0] >> 4
switch version {
case 4:
if len(packet) < 20 {
return netip.Addr{}, false
}
// Destination IP is at bytes 16-19 for IPv4
ip := netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
return ip, true
case 6:
if len(packet) < 40 {
return netip.Addr{}, false
}
// Destination IP is at bytes 24-39 for IPv6
var ip16 [16]byte
copy(ip16[:], packet[24:40])
ip := netip.AddrFrom16(ip16)
return ip, true
}
return netip.Addr{}, false
}
// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
// Check if already closed first (non-blocking)
select {
case <-d.closed:
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
return 0, os.ErrClosed
default:
}
// Now block waiting for data
select {
case res := <-d.readCh:
if res.err != nil {
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
return 0, res.err
}
// Copy packets from result to provided buffers
count := 0
for i := 0; i < res.n && i < len(bufs); i++ {
// Handle offset mismatch if necessary
// We assume the pump used defaultOffset (16)
// If caller asks for different offset, we need to shift
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
// Calculate where the packet data starts and ends in src
pktData := src[srcOffset : srcOffset+srcSize]
// Ensure dest buffer is large enough
if len(bufs[i]) < offset+len(pktData) {
continue // Skip if buffer too small
}
copy(bufs[i][offset:], pktData)
sizes[i] = len(pktData)
count++
}
n = count
case pkt := <-d.injectCh:
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil // Buffer too small
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
case <-d.closed:
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
return 0, os.ErrClosed // Signal that device is closed
}
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
if len(rules) == 0 {
return n, nil
}
// Process packets and filter out handled ones
writeIdx := 0
for readIdx := 0; readIdx < n; readIdx++ {
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
continue
}
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
}
}
}
if !handled {
// Keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
}
}
return writeIdx, err
}
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
if len(rules) == 0 {
return d.Device.Write(bufs, offset)
}
// Filter packets going down
filteredBufs := make([][]byte, 0, len(bufs))
for _, buf := range bufs {
if len(buf) <= offset {
continue
}
packet := buf[offset:]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
filteredBufs = append(filteredBufs, buf)
continue
}
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
}
}
}
if !handled {
filteredBufs = append(filteredBufs, buf)
}
}
if len(filteredBufs) == 0 {
return len(bufs), nil // All packets were handled
}
return d.Device.Write(filteredBufs, offset)
}

View File

@@ -0,0 +1,102 @@
package device
import (
"net/netip"
"testing"
"github.com/fosrl/newt/util"
)
func TestExtractDestIP(t *testing.T) {
tests := []struct {
name string
packet []byte
wantIP string
wantOk bool
}{
{
name: "IPv4 packet",
packet: []byte{
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
0x0a, 0x1e, 0x1e, 0x1e, // Dest IP: 10.30.30.30
},
wantIP: "10.30.30.30",
wantOk: true,
},
{
name: "Too short packet",
packet: []byte{0x45, 0x00},
wantIP: "",
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotIP, gotOk := extractDestIP(tt.packet)
if gotOk != tt.wantOk {
t.Errorf("extractDestIP() ok = %v, want %v", gotOk, tt.wantOk)
return
}
if tt.wantOk {
wantAddr := netip.MustParseAddr(tt.wantIP)
if gotIP != wantAddr {
t.Errorf("extractDestIP() ip = %v, want %v", gotIP, wantAddr)
}
}
})
}
}
func TestGetProtocol(t *testing.T) {
tests := []struct {
name string
packet []byte
wantProto uint8
wantOk bool
}{
{
name: "UDP packet",
packet: []byte{
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, // Protocol: UDP (17) at byte 9
0x0a, 0x1e, 0x1e, 0x1e,
},
wantProto: 17,
wantOk: true,
},
{
name: "Too short",
packet: []byte{0x45, 0x00},
wantProto: 0,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotProto, gotOk := util.GetProtocol(tt.packet)
if gotOk != tt.wantOk {
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
return
}
if gotProto != tt.wantProto {
t.Errorf("GetProtocol() proto = %v, want %v", gotProto, tt.wantProto)
}
})
}
}
func BenchmarkExtractDestIP(b *testing.B) {
packet := []byte{
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
0x0a, 0x1e, 0x1e, 0x1e,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
extractDestIP(packet)
}
}

44
device/tun_unix.go Normal file
View File

@@ -0,0 +1,44 @@
//go:build !windows
package device
import (
"net"
"os"
"github.com/fosrl/newt/logger"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
dupTunFd, err := unix.Dup(int(tunFd))
if err != nil {
logger.Error("Unable to dup tun fd: %v", err)
return nil, err
}
err = unix.SetNonblock(dupTunFd, true)
if err != nil {
unix.Close(dupTunFd)
return nil, err
}
file := os.NewFile(uintptr(dupTunFd), "/dev/tun")
device, err := tun.CreateTUNFromFile(file, mtuInt)
if err != nil {
file.Close()
return nil, err
}
return device, nil
}
func UapiOpen(interfaceName string) (*os.File, error) {
return ipc.UAPIOpen(interfaceName)
}
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, fileUAPI)
}

View File

@@ -1,6 +1,6 @@
//go:build windows
package main
package device
import (
"errors"
@@ -11,15 +11,15 @@ import (
"golang.zx2c4.com/wireguard/tun"
)
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
return nil, errors.New("CreateTUNFromFile not supported on Windows")
}
func uapiOpen(interfaceName string) (*os.File, error) {
func UapiOpen(interfaceName string) (*os.File, error) {
return nil, nil
}
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
// On Windows, UAPIListen only takes one parameter
return ipc.UAPIListen(interfaceName)
}

457
dns/dns_proxy.go Normal file
View File

@@ -0,0 +1,457 @@
package dns
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/device"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
DNSPort = 53
)
// DNSProxy implements a DNS proxy using gvisor netstack
type DNSProxy struct {
stack *stack.Stack
ep *channel.Endpoint
proxyIP netip.Addr
upstreamDNS []string
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
recordStore *DNSRecordStore // Local DNS records
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewDNSProxy creates a new DNS proxy
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string) (*DNSProxy, error) {
proxyIP, err := PickIPFromSubnet(utilitySubnet)
if err != nil {
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
}
if len(upstreamDns) == 0 {
return nil, fmt.Errorf("at least one upstream DNS server must be specified")
}
ctx, cancel := context.WithCancel(context.Background())
proxy := &DNSProxy{
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
middleDevice: middleDevice,
upstreamDNS: upstreamDns,
recordStore: NewDNSRecordStore(),
ctx: ctx,
cancel: cancel,
}
// Create gvisor netstack
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
HandleLocal: true,
}
proxy.ep = channel.New(256, uint32(mtu), "")
proxy.stack = stack.New(stackOpts)
// Create NIC
if err := proxy.stack.CreateNIC(1, proxy.ep); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
// Add IP address
// Parse the proxy IP to get the octets
ipBytes := proxyIP.As4()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
}
if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %v", err)
}
// Add default route
proxy.stack.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
})
return proxy, nil
}
// Start starts the DNS proxy and registers with the filter
func (p *DNSProxy) Start() error {
// Install packet filter rule
p.middleDevice.AddRule(p.proxyIP, p.handlePacket)
// Start DNS listener
p.wg.Add(2)
go p.runDNSListener()
go p.runPacketSender()
logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort)
return nil
}
// Stop stops the DNS proxy
func (p *DNSProxy) Stop() {
if p.middleDevice != nil {
p.middleDevice.RemoveRule(p.proxyIP)
}
p.cancel()
// Close the endpoint first to unblock any pending Read() calls in runPacketSender
if p.ep != nil {
p.ep.Close()
}
p.wg.Wait()
if p.stack != nil {
p.stack.Close()
}
logger.Info("DNS proxy stopped")
}
func (p *DNSProxy) GetProxyIP() netip.Addr {
return p.proxyIP
}
// handlePacket is called by the filter for packets destined to DNS proxy IP
func (p *DNSProxy) handlePacket(packet []byte) bool {
if len(packet) < 20 {
return false // Don't drop, malformed
}
// Quick check for UDP port 53
proto, ok := util.GetProtocol(packet)
if !ok || proto != 17 { // 17 = UDP
return false // Not UDP, don't handle
}
port, ok := util.GetDestPort(packet)
if !ok || port != DNSPort {
return false // Not DNS port
}
// Inject packet into our netstack
version := packet[0] >> 4
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
switch version {
case 4:
p.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
case 6:
p.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
default:
pkb.DecRef()
return false
}
pkb.DecRef()
return true // Drop packet from normal path
}
// runDNSListener listens for DNS queries on the netstack
func (p *DNSProxy) runDNSListener() {
defer p.wg.Done()
// Create UDP listener using gonet
// Parse the proxy IP to get the octets
ipBytes := p.proxyIP.As4()
laddr := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4(ipBytes),
Port: DNSPort,
}
udpConn, err := gonet.DialUDP(p.stack, laddr, nil, ipv4.ProtocolNumber)
if err != nil {
logger.Error("Failed to create DNS listener: %v", err)
return
}
defer udpConn.Close()
logger.Debug("DNS proxy listening on netstack")
// Handle DNS queries
buf := make([]byte, 4096)
for {
select {
case <-p.ctx.Done():
return
default:
}
udpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
n, remoteAddr, err := udpConn.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
if p.ctx.Err() != nil {
return
}
logger.Error("DNS read error: %v", err)
continue
}
query := make([]byte, n)
copy(query, buf[:n])
// Handle query in background
go p.handleDNSQuery(udpConn, query, remoteAddr)
}
}
// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream
func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) {
// Parse the DNS query
msg := new(dns.Msg)
if err := msg.Unpack(queryData); err != nil {
logger.Error("Failed to parse DNS query: %v", err)
return
}
if len(msg.Question) == 0 {
logger.Debug("DNS query has no questions")
return
}
question := msg.Question[0]
logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype])
// Check if we have local records for this query
var response *dns.Msg
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
response = p.checkLocalRecords(msg, question)
}
// If no local records, forward to upstream
if response == nil {
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
response = p.forwardToUpstream(msg)
}
if response == nil {
logger.Error("Failed to get DNS response for %s", question.Name)
return
}
// Pack and send response
responseData, err := response.Pack()
if err != nil {
logger.Error("Failed to pack DNS response: %v", err)
return
}
_, err = udpConn.WriteTo(responseData, clientAddr)
if err != nil {
logger.Error("Failed to send DNS response: %v", err)
}
}
// checkLocalRecords checks if we have local records for the query
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
var recordType RecordType
if question.Qtype == dns.TypeA {
recordType = RecordTypeA
} else if question.Qtype == dns.TypeAAAA {
recordType = RecordTypeAAAA
} else {
return nil
}
ips := p.recordStore.GetRecords(question.Name, recordType)
if len(ips) == 0 {
return nil
}
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
// Create response message
response := new(dns.Msg)
response.SetReply(query)
response.Authoritative = true
// Add answer records
for _, ip := range ips {
var rr dns.RR
if question.Qtype == dns.TypeA {
rr = &dns.A{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
A: ip.To4(),
}
} else { // TypeAAAA
rr = &dns.AAAA{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
AAAA: ip.To16(),
}
}
response.Answer = append(response.Answer, rr)
}
return response
}
// forwardToUpstream forwards a DNS query to upstream DNS servers
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
// Try primary DNS server
response, err := p.queryUpstream(p.upstreamDNS[0], query, 2*time.Second)
if err != nil && len(p.upstreamDNS) > 1 {
// Try secondary DNS server
logger.Debug("Primary DNS failed, trying secondary: %v", err)
response, err = p.queryUpstream(p.upstreamDNS[1], query, 2*time.Second)
if err != nil {
logger.Error("Both DNS servers failed: %v", err)
return nil
}
}
return response
}
// queryUpstream sends a DNS query to upstream server using miekg/dns
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
client := &dns.Client{
Timeout: timeout,
}
response, _, err := client.Exchange(query, server)
if err != nil {
return nil, err
}
return response, nil
}
// runPacketSender sends packets from netstack back to TUN
func (p *DNSProxy) runPacketSender() {
defer p.wg.Done()
// MessageTransportHeaderSize is the offset used by WireGuard device
// for reading/writing packets to the TUN interface
const offset = 16
for {
select {
case <-p.ctx.Done():
return
default:
}
// Read packets from netstack endpoint
pkt := p.ep.Read()
if pkt == nil {
// No packet available, small sleep to avoid busy loop
time.Sleep(1 * time.Millisecond)
continue
}
// Extract packet data as slices
slices := pkt.AsSlices()
if len(slices) > 0 {
// Flatten all slices into a single packet buffer
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
// Allocate buffer with offset space for WireGuard transport header
// The first 'offset' bytes are reserved for the transport header
buf := make([]byte, offset+totalSize)
// Copy packet data after the offset
pos := offset
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Write packet to TUN device
// offset=16 indicates packet data starts at position 16 in the buffer
_, err := p.tunDevice.Write([][]byte{buf}, offset)
if err != nil {
logger.Error("Failed to write DNS response to TUN: %v", err)
}
}
pkt.DecRef()
}
}
// AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
return p.recordStore.AddRecord(domain, ip)
}
// RemoveDNSRecord removes a DNS record from the local store
// If ip is nil, removes all records for the domain
func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
p.recordStore.RemoveRecord(domain, ip)
}
// GetDNSRecords returns all IP addresses for a domain and record type
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
return p.recordStore.GetRecords(domain, recordType)
}
// ClearDNSRecords removes all DNS records from the local store
func (p *DNSProxy) ClearDNSRecords() {
p.recordStore.Clear()
}
func PickIPFromSubnet(subnet string) (netip.Addr, error) {
// given a subnet in CIDR notation, pick the first usable IP
prefix, err := netip.ParsePrefix(subnet)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid subnet: %w", err)
}
// Pick the first usable IP address from the subnet
ip := prefix.Addr().Next()
if !ip.IsValid() {
return netip.Addr{}, fmt.Errorf("no valid IP address found in subnet: %s", subnet)
}
return ip, nil
}

166
dns/dns_records.go Normal file
View File

@@ -0,0 +1,166 @@
package dns
import (
"net"
"sync"
"github.com/miekg/dns"
)
// RecordType represents the type of DNS record
type RecordType uint16
const (
RecordTypeA RecordType = RecordType(dns.TypeA)
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
)
// DNSRecordStore manages local DNS records for A and AAAA queries
type DNSRecordStore struct {
mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
}
// NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{
aRecords: make(map[string][]net.IP),
aaaaRecords: make(map[string][]net.IP),
}
}
// AddRecord adds a DNS record mapping (A or AAAA)
// domain should be in FQDN format (e.g., "example.com.")
// ip should be a valid IPv4 or IPv6 address
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
if ip.To4() != nil {
// IPv4 address
s.aRecords[domain] = append(s.aRecords[domain], ip)
} else if ip.To16() != nil {
// IPv6 address
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
} else {
return &net.ParseError{Type: "IP address", Text: ip.String()}
}
return nil
}
// RemoveRecord removes a specific DNS record mapping
// If ip is nil, removes all records for the domain
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
if ip == nil {
// Remove all records for this domain
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
return
}
if ip.To4() != nil {
// Remove specific IPv4 address
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
}
} else if ip.To16() != nil {
// Remove specific IPv6 address
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
}
}
}
// GetRecords returns all IP addresses for a domain and record type
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
s.mu.RLock()
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain)
var records []net.IP
switch recordType {
case RecordTypeA:
if ips, ok := s.aRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
}
case RecordTypeAAAA:
if ips, ok := s.aaaaRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
}
}
return records
}
// HasRecord checks if a domain has any records of the specified type
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock()
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain)
switch recordType {
case RecordTypeA:
_, ok := s.aRecords[domain]
return ok
case RecordTypeAAAA:
_, ok := s.aaaaRecords[domain]
return ok
}
return false
}
// Clear removes all records from the store
func (s *DNSRecordStore) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.aRecords = make(map[string][]net.IP)
s.aaaaRecords = make(map[string][]net.IP)
}
// removeIP is a helper function to remove a specific IP from a slice
func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
result := make([]net.IP, 0, len(ips))
for _, ip := range ips {
if !ip.Equal(toRemove) {
result = append(result, ip)
}
}
return result
}

View File

@@ -0,0 +1,68 @@
//go:build darwin && !ios
package olm
import (
"fmt"
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
// Uses scutil for DNS configuration
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
if dnsProxy == nil {
return fmt.Errorf("DNS proxy is nil")
}
var err error
configurator, err = platform.NewDarwinDNSConfigurator()
if err != nil {
return fmt.Errorf("failed to create Darwin DNS configurator: %w", err)
}
logger.Info("Using Darwin scutil DNS configurator")
// Get current DNS servers before changing
currentDNS, err := configurator.GetCurrentDNS()
if err != nil {
logger.Warn("Could not get current DNS: %v", err)
} else {
logger.Info("Current DNS servers: %v", currentDNS)
}
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
dnsProxy.GetProxyIP(),
}
logger.Info("Setting DNS servers to: %v", newDNS)
originalDNS, err := configurator.SetDNS(newDNS)
if err != nil {
return fmt.Errorf("failed to set DNS: %w", err)
}
logger.Info("Original DNS servers backed up: %v", originalDNS)
return nil
}
// RestoreDNSOverride restores the original DNS configuration
func RestoreDNSOverride() error {
if configurator == nil {
logger.Debug("No DNS configurator to restore")
return nil
}
logger.Info("Restoring original DNS configuration")
if err := configurator.RestoreDNS(); err != nil {
return fmt.Errorf("failed to restore DNS: %w", err)
}
logger.Info("DNS configuration restored successfully")
return nil
}

View File

@@ -0,0 +1,105 @@
//go:build (linux && !android) || freebsd
package olm
import (
"fmt"
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
// Detects the DNS manager by reading /etc/resolv.conf and verifying runtime availability
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
if dnsProxy == nil {
return fmt.Errorf("DNS proxy is nil")
}
var err error
// Detect which DNS manager is in use by checking /etc/resolv.conf and runtime availability
managerType := platform.DetectDNSManager(interfaceName)
logger.Info("Detected DNS manager: %s", managerType.String())
// Create configurator based on detected manager
switch managerType {
case platform.SystemdResolvedManager:
configurator, err = platform.NewSystemdResolvedDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using systemd-resolved DNS configurator")
return setDNS(dnsProxy, configurator)
}
logger.Warn("Failed to create systemd-resolved configurator: %v, falling back", err)
case platform.NetworkManagerManager:
configurator, err = platform.NewNetworkManagerDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using NetworkManager DNS configurator")
return setDNS(dnsProxy, configurator)
}
logger.Warn("Failed to create NetworkManager configurator: %v, falling back", err)
case platform.ResolvconfManager:
configurator, err = platform.NewResolvconfDNSConfigurator(interfaceName)
if err == nil {
logger.Info("Using resolvconf DNS configurator")
return setDNS(dnsProxy, configurator)
}
logger.Warn("Failed to create resolvconf configurator: %v, falling back", err)
}
// Fall back to direct file manipulation
configurator, err = platform.NewFileDNSConfigurator()
if err != nil {
return fmt.Errorf("failed to create file DNS configurator: %w", err)
}
logger.Info("Using file-based DNS configurator")
return setDNS(dnsProxy, configurator)
}
// setDNS is a helper function to set DNS and log the results
func setDNS(dnsProxy *dns.DNSProxy, conf platform.DNSConfigurator) error {
// Get current DNS servers before changing
currentDNS, err := conf.GetCurrentDNS()
if err != nil {
logger.Warn("Could not get current DNS: %v", err)
} else {
logger.Info("Current DNS servers: %v", currentDNS)
}
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
dnsProxy.GetProxyIP(),
}
logger.Info("Setting DNS servers to: %v", newDNS)
originalDNS, err := conf.SetDNS(newDNS)
if err != nil {
return fmt.Errorf("failed to set DNS: %w", err)
}
logger.Info("Original DNS servers backed up: %v", originalDNS)
return nil
}
// RestoreDNSOverride restores the original DNS configuration
func RestoreDNSOverride() error {
if configurator == nil {
logger.Debug("No DNS configurator to restore")
return nil
}
logger.Info("Restoring original DNS configuration")
if err := configurator.RestoreDNS(); err != nil {
return fmt.Errorf("failed to restore DNS: %w", err)
}
logger.Info("DNS configuration restored successfully")
return nil
}

View File

@@ -0,0 +1,68 @@
//go:build windows
package olm
import (
"fmt"
"net/netip"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
)
var configurator platform.DNSConfigurator
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
// Uses registry-based configuration (automatically extracts interface GUID)
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
if dnsProxy == nil {
return fmt.Errorf("DNS proxy is nil")
}
var err error
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
if err != nil {
return fmt.Errorf("failed to create Windows DNS configurator: %w", err)
}
logger.Info("Using Windows registry DNS configurator for interface: %s", interfaceName)
// Get current DNS servers before changing
currentDNS, err := configurator.GetCurrentDNS()
if err != nil {
logger.Warn("Could not get current DNS: %v", err)
} else {
logger.Info("Current DNS servers: %v", currentDNS)
}
// Set new DNS servers to point to our proxy
newDNS := []netip.Addr{
dnsProxy.GetProxyIP(),
}
logger.Info("Setting DNS servers to: %v", newDNS)
originalDNS, err := configurator.SetDNS(newDNS)
if err != nil {
return fmt.Errorf("failed to set DNS: %w", err)
}
logger.Info("Original DNS servers backed up: %v", originalDNS)
return nil
}
// RestoreDNSOverride restores the original DNS configuration
func RestoreDNSOverride() error {
if configurator == nil {
logger.Debug("No DNS configurator to restore")
return nil
}
logger.Info("Restoring original DNS configuration")
if err := configurator.RestoreDNS(); err != nil {
return fmt.Errorf("failed to restore DNS: %w", err)
}
logger.Info("DNS configuration restored successfully")
return nil
}

268
dns/platform/darwin.go Normal file
View File

@@ -0,0 +1,268 @@
//go:build darwin && !ios
package dns
import (
"bufio"
"bytes"
"fmt"
"net/netip"
"os/exec"
"strconv"
"strings"
"github.com/fosrl/newt/logger"
)
const (
scutilPath = "/usr/sbin/scutil"
dscacheutilPath = "/usr/bin/dscacheutil"
dnsStateKeyFormat = "State:/Network/Service/Olm-%s/DNS"
globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceFormat = "State:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains"
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
keyServerAddresses = "ServerAddresses"
keyServerPort = "ServerPort"
arraySymbol = "* "
digitSymbol = "# "
)
// DarwinDNSConfigurator manages DNS settings on macOS using scutil
type DarwinDNSConfigurator struct {
createdKeys map[string]struct{}
originalState *DNSState
}
// NewDarwinDNSConfigurator creates a new macOS DNS configurator
func NewDarwinDNSConfigurator() (*DarwinDNSConfigurator, error) {
return &DarwinDNSConfigurator{
createdKeys: make(map[string]struct{}),
}, nil
}
// Name returns the configurator name
func (d *DarwinDNSConfigurator) Name() string {
return "darwin-scutil"
}
// SetDNS sets the DNS servers and returns the original servers
func (d *DarwinDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
// Get current DNS settings before overriding
originalServers, err := d.GetCurrentDNS()
if err != nil {
return nil, fmt.Errorf("get current DNS: %w", err)
}
// Store original state
d.originalState = &DNSState{
OriginalServers: originalServers,
ConfiguratorName: d.Name(),
}
// Set new DNS servers
if err := d.applyDNSServers(servers); err != nil {
return nil, fmt.Errorf("apply DNS servers: %w", err)
}
// Flush DNS cache
if err := d.flushDNSCache(); err != nil {
// Non-fatal, just log
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
}
return originalServers, nil
}
// RestoreDNS restores the original DNS configuration
func (d *DarwinDNSConfigurator) RestoreDNS() error {
// Remove all created keys
for key := range d.createdKeys {
if err := d.removeKey(key); err != nil {
return fmt.Errorf("remove key %s: %w", key, err)
}
}
// Flush DNS cache
if err := d.flushDNSCache(); err != nil {
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
}
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
func (d *DarwinDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
primaryServiceKey, err := d.getPrimaryServiceKey()
if err != nil || primaryServiceKey == "" {
return nil, fmt.Errorf("get primary service: %w", err)
}
dnsKey := fmt.Sprintf(primaryServiceFormat, primaryServiceKey)
cmd := fmt.Sprintf("show %s\n", dnsKey)
output, err := d.runScutil(cmd)
if err != nil {
return nil, fmt.Errorf("run scutil: %w", err)
}
servers := d.parseServerAddresses(output)
return servers, nil
}
// applyDNSServers applies the DNS server configuration
func (d *DarwinDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
if len(servers) == 0 {
return fmt.Errorf("no DNS servers provided")
}
key := fmt.Sprintf(dnsStateKeyFormat, "Override")
// Use SupplementalMatchDomains with empty string to match ALL domains
// This is the key to making DNS override work on macOS
// Setting SupplementalMatchDomainsNoSearch to 0 enables search domain behavior
err := d.addDNSState(key, "\"\"", servers[0], 53, true)
if err != nil {
return fmt.Errorf("set DNS servers: %w", err)
}
d.createdKeys[key] = struct{}{}
return nil
}
// addDNSState adds a DNS state entry with the specified configuration
func (d *DarwinDNSConfigurator) addDNSState(state, domains string, dnsServer netip.Addr, port int, enableSearch bool) error {
noSearch := "1"
if enableSearch {
noSearch = "0"
}
// Build the scutil command following NetBird's approach
var commands strings.Builder
commands.WriteString("d.init\n")
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomains, arraySymbol, domains))
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keySupplementalMatchDomainsNoSearch, digitSymbol, noSearch))
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerAddresses, arraySymbol, dnsServer.String()))
commands.WriteString(fmt.Sprintf("d.add %s %s%s\n", keyServerPort, digitSymbol, strconv.Itoa(port)))
commands.WriteString(fmt.Sprintf("set %s\n", state))
if _, err := d.runScutil(commands.String()); err != nil {
return fmt.Errorf("applying state for domains %s, error: %w", domains, err)
}
logger.Info("Added DNS override with server %s:%d for domains: %s", dnsServer.String(), port, domains)
return nil
}
// removeKey removes a DNS configuration key
func (d *DarwinDNSConfigurator) removeKey(key string) error {
cmd := fmt.Sprintf("remove %s\n", key)
if _, err := d.runScutil(cmd); err != nil {
return fmt.Errorf("remove key: %w", err)
}
delete(d.createdKeys, key)
return nil
}
// getPrimaryServiceKey gets the primary network service key
func (d *DarwinDNSConfigurator) getPrimaryServiceKey() (string, error) {
cmd := fmt.Sprintf("show %s\n", globalIPv4State)
output, err := d.runScutil(cmd)
if err != nil {
return "", fmt.Errorf("run scutil: %w", err)
}
scanner := bufio.NewScanner(bytes.NewReader(output))
for scanner.Scan() {
line := scanner.Text()
if strings.Contains(line, "PrimaryService") {
parts := strings.Split(line, ":")
if len(parts) >= 2 {
return strings.TrimSpace(parts[1]), nil
}
}
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("scan output: %w", err)
}
return "", fmt.Errorf("primary service not found")
}
// parseServerAddresses parses DNS server addresses from scutil output
func (d *DarwinDNSConfigurator) parseServerAddresses(output []byte) []netip.Addr {
var servers []netip.Addr
inServerArray := false
scanner := bufio.NewScanner(bytes.NewReader(output))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "ServerAddresses : <array> {") {
inServerArray = true
continue
}
if line == "}" {
inServerArray = false
continue
}
if inServerArray {
// Line format: "0 : 8.8.8.8"
parts := strings.Split(line, " : ")
if len(parts) >= 2 {
if addr, err := netip.ParseAddr(parts[1]); err == nil {
servers = append(servers, addr)
}
}
}
}
return servers
}
// flushDNSCache flushes the system DNS cache
func (d *DarwinDNSConfigurator) flushDNSCache() error {
logger.Debug("Flushing dscacheutil cache")
cmd := exec.Command(dscacheutilPath, "-flushcache")
if err := cmd.Run(); err != nil {
return fmt.Errorf("flush cache: %w", err)
}
logger.Debug("Flushing mDNSResponder cache")
cmd = exec.Command("killall", "-HUP", "mDNSResponder")
if err := cmd.Run(); err != nil {
// Non-fatal, mDNSResponder might not be running
return nil
}
return nil
}
// runScutil executes an scutil command
func (d *DarwinDNSConfigurator) runScutil(commands string) ([]byte, error) {
// Wrap commands with open/quit
wrapped := fmt.Sprintf("open\n%squit\n", commands)
logger.Debug("Running scutil with commands:\n%s\n", wrapped)
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader(wrapped)
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("scutil command failed: %w, output: %s", err, output)
}
logger.Debug("scutil output:\n%s\n", output)
return output, nil
}

158
dns/platform/detect_unix.go Normal file
View File

@@ -0,0 +1,158 @@
//go:build (linux && !android) || freebsd
package dns
import (
"bufio"
"io"
"os"
"strings"
"github.com/fosrl/newt/logger"
)
const defaultResolvConfPath = "/etc/resolv.conf"
// DNSManagerType represents the type of DNS manager detected
type DNSManagerType int
const (
// UnknownManager indicates we couldn't determine the DNS manager
UnknownManager DNSManagerType = iota
// SystemdResolvedManager indicates systemd-resolved is managing DNS
SystemdResolvedManager
// NetworkManagerManager indicates NetworkManager is managing DNS
NetworkManagerManager
// ResolvconfManager indicates resolvconf is managing DNS
ResolvconfManager
// FileManager indicates direct file management (no DNS manager)
FileManager
)
// DetectDNSManagerFromFile reads /etc/resolv.conf to determine which DNS manager is in use
// This provides a hint based on comments in the file, similar to Netbird's approach
func DetectDNSManagerFromFile() DNSManagerType {
file, err := os.Open(defaultResolvConfPath)
if err != nil {
return UnknownManager
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
text := scanner.Text()
if len(text) == 0 {
continue
}
// If we hit a non-comment line, default to file-based
if text[0] != '#' {
return FileManager
}
// Check for DNS manager signatures in comments
if strings.Contains(text, "NetworkManager") {
return NetworkManagerManager
}
if strings.Contains(text, "systemd-resolved") {
return SystemdResolvedManager
}
if strings.Contains(text, "resolvconf") {
return ResolvconfManager
}
}
if err := scanner.Err(); err != nil && err != io.EOF {
return UnknownManager
}
// No indicators found, assume file-based management
return FileManager
}
// String returns a human-readable name for the DNS manager type
func (d DNSManagerType) String() string {
switch d {
case SystemdResolvedManager:
return "systemd-resolved"
case NetworkManagerManager:
return "NetworkManager"
case ResolvconfManager:
return "resolvconf"
case FileManager:
return "file"
default:
return "unknown"
}
}
// DetectDNSManager combines file detection with runtime availability checks
// to determine the best DNS configurator to use
func DetectDNSManager(interfaceName string) DNSManagerType {
// First check what the file suggests
fileHint := DetectDNSManagerFromFile()
// Verify the hint with runtime checks
switch fileHint {
case SystemdResolvedManager:
// Verify systemd-resolved is actually running
if IsSystemdResolvedAvailable() {
return SystemdResolvedManager
}
logger.Warn("dns platform: Found systemd-resolved but it is not running. Falling back to file...")
os.Exit(0)
return FileManager
case NetworkManagerManager:
// Verify NetworkManager is actually running
if IsNetworkManagerAvailable() {
// Check if NetworkManager is delegating to systemd-resolved
if !IsNetworkManagerDNSModeSupported() {
logger.Info("NetworkManager is delegating DNS to systemd-resolved, using systemd-resolved configurator")
if IsSystemdResolvedAvailable() {
return SystemdResolvedManager
}
}
return NetworkManagerManager
}
logger.Warn("dns platform: Found network manager but it is not running. Falling back to file...")
return FileManager
case ResolvconfManager:
// Verify resolvconf is available
if IsResolvconfAvailable() {
return ResolvconfManager
}
// If resolvconf is mentioned but not available, fall back to file
return FileManager
case FileManager:
// File suggests direct file management
// But we should still check if a manager is available that wasn't mentioned
if IsSystemdResolvedAvailable() && interfaceName != "" {
return SystemdResolvedManager
}
if IsNetworkManagerAvailable() && interfaceName != "" {
return NetworkManagerManager
}
if IsResolvconfAvailable() && interfaceName != "" {
return ResolvconfManager
}
return FileManager
default:
// Unknown - do runtime detection
if IsSystemdResolvedAvailable() && interfaceName != "" {
return SystemdResolvedManager
}
if IsNetworkManagerAvailable() && interfaceName != "" {
return NetworkManagerManager
}
if IsResolvconfAvailable() && interfaceName != "" {
return ResolvconfManager
}
return FileManager
}
}

192
dns/platform/file.go Normal file
View File

@@ -0,0 +1,192 @@
//go:build (linux && !android) || freebsd
package dns
import (
"fmt"
"net/netip"
"os"
"strings"
)
const (
resolvConfPath = "/etc/resolv.conf"
resolvConfBackupPath = "/etc/resolv.conf.olm.backup"
resolvConfHeader = "# Generated by Olm DNS Manager\n# Original file backed up to " + resolvConfBackupPath + "\n\n"
)
// FileDNSConfigurator manages DNS settings by directly modifying /etc/resolv.conf
type FileDNSConfigurator struct {
originalState *DNSState
}
// NewFileDNSConfigurator creates a new file-based DNS configurator
func NewFileDNSConfigurator() (*FileDNSConfigurator, error) {
return &FileDNSConfigurator{}, nil
}
// Name returns the configurator name
func (f *FileDNSConfigurator) Name() string {
return "file-resolv.conf"
}
// SetDNS sets the DNS servers and returns the original servers
func (f *FileDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
// Get current DNS settings before overriding
originalServers, err := f.GetCurrentDNS()
if err != nil {
return nil, fmt.Errorf("get current DNS: %w", err)
}
// Backup original resolv.conf if not already backed up
if !f.isBackupExists() {
if err := f.backupResolvConf(); err != nil {
return nil, fmt.Errorf("backup resolv.conf: %w", err)
}
}
// Store original state
f.originalState = &DNSState{
OriginalServers: originalServers,
ConfiguratorName: f.Name(),
}
// Write new resolv.conf
if err := f.writeResolvConf(servers); err != nil {
return nil, fmt.Errorf("write resolv.conf: %w", err)
}
return originalServers, nil
}
// RestoreDNS restores the original DNS configuration
func (f *FileDNSConfigurator) RestoreDNS() error {
if !f.isBackupExists() {
return fmt.Errorf("no backup file exists")
}
// Copy backup back to original location
if err := copyFile(resolvConfBackupPath, resolvConfPath); err != nil {
return fmt.Errorf("restore from backup: %w", err)
}
// Remove backup file
if err := os.Remove(resolvConfBackupPath); err != nil {
return fmt.Errorf("remove backup file: %w", err)
}
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
func (f *FileDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
content, err := os.ReadFile(resolvConfPath)
if err != nil {
return nil, fmt.Errorf("read resolv.conf: %w", err)
}
return f.parseNameservers(string(content)), nil
}
// backupResolvConf creates a backup of the current resolv.conf
func (f *FileDNSConfigurator) backupResolvConf() error {
// Get file info for permissions
info, err := os.Stat(resolvConfPath)
if err != nil {
return fmt.Errorf("stat resolv.conf: %w", err)
}
if err := copyFile(resolvConfPath, resolvConfBackupPath); err != nil {
return fmt.Errorf("copy file: %w", err)
}
// Preserve permissions
if err := os.Chmod(resolvConfBackupPath, info.Mode()); err != nil {
return fmt.Errorf("chmod backup: %w", err)
}
return nil
}
// writeResolvConf writes a new resolv.conf with the specified DNS servers
func (f *FileDNSConfigurator) writeResolvConf(servers []netip.Addr) error {
if len(servers) == 0 {
return fmt.Errorf("no DNS servers provided")
}
// Get file info for permissions
info, err := os.Stat(resolvConfPath)
if err != nil {
return fmt.Errorf("stat resolv.conf: %w", err)
}
var content strings.Builder
content.WriteString(resolvConfHeader)
// Write nameservers
for _, server := range servers {
content.WriteString("nameserver ")
content.WriteString(server.String())
content.WriteString("\n")
}
// Write the file
if err := os.WriteFile(resolvConfPath, []byte(content.String()), info.Mode()); err != nil {
return fmt.Errorf("write resolv.conf: %w", err)
}
return nil
}
// isBackupExists checks if a backup file exists
func (f *FileDNSConfigurator) isBackupExists() bool {
_, err := os.Stat(resolvConfBackupPath)
return err == nil
}
// parseNameservers extracts nameserver entries from resolv.conf content
func (f *FileDNSConfigurator) parseNameservers(content string) []netip.Addr {
var servers []netip.Addr
lines := strings.Split(content, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
// Skip comments and empty lines
if line == "" || strings.HasPrefix(line, "#") {
continue
}
// Look for nameserver lines
if strings.HasPrefix(line, "nameserver") {
fields := strings.Fields(line)
if len(fields) >= 2 {
if addr, err := netip.ParseAddr(fields[1]); err == nil {
servers = append(servers, addr)
}
}
}
}
return servers
}
// copyFile copies a file from src to dst
func copyFile(src, dst string) error {
content, err := os.ReadFile(src)
if err != nil {
return fmt.Errorf("read source: %w", err)
}
// Get source file permissions
info, err := os.Stat(src)
if err != nil {
return fmt.Errorf("stat source: %w", err)
}
if err := os.WriteFile(dst, content, info.Mode()); err != nil {
return fmt.Errorf("write destination: %w", err)
}
return nil
}

View File

@@ -0,0 +1,294 @@
//go:build (linux && !android) || freebsd
package dns
import (
"context"
"errors"
"fmt"
"net/netip"
"os"
"strings"
"time"
dbus "github.com/godbus/dbus/v5"
)
const (
// NetworkManager D-Bus constants
networkManagerDest = "org.freedesktop.NetworkManager"
networkManagerDbusObjectNode = "/org/freedesktop/NetworkManager"
networkManagerDbusDNSManagerInterface = "org.freedesktop.NetworkManager.DnsManager"
networkManagerDbusDNSManagerObjectNode = networkManagerDbusObjectNode + "/DnsManager"
networkManagerDbusDNSManagerModeProperty = networkManagerDbusDNSManagerInterface + ".Mode"
networkManagerDbusVersionProperty = "org.freedesktop.NetworkManager.Version"
// NetworkManager dispatcher script path
networkManagerDispatcherDir = "/etc/NetworkManager/dispatcher.d"
networkManagerConfDir = "/etc/NetworkManager/conf.d"
networkManagerDNSConfFile = "olm-dns.conf"
networkManagerDispatcherFile = "01-olm-dns"
)
// NetworkManagerDNSConfigurator manages DNS settings using NetworkManager configuration files
// This approach works with unmanaged interfaces by modifying NetworkManager's global DNS settings
type NetworkManagerDNSConfigurator struct {
ifaceName string
originalState *DNSState
confPath string
dispatchPath string
}
// NewNetworkManagerDNSConfigurator creates a new NetworkManager DNS configurator
func NewNetworkManagerDNSConfigurator(ifaceName string) (*NetworkManagerDNSConfigurator, error) {
if ifaceName == "" {
return nil, fmt.Errorf("interface name is required")
}
// Check that NetworkManager conf.d directory exists
if _, err := os.Stat(networkManagerConfDir); os.IsNotExist(err) {
return nil, fmt.Errorf("NetworkManager conf.d directory not found: %s", networkManagerConfDir)
}
return &NetworkManagerDNSConfigurator{
ifaceName: ifaceName,
confPath: networkManagerConfDir + "/" + networkManagerDNSConfFile,
dispatchPath: networkManagerDispatcherDir + "/" + networkManagerDispatcherFile,
}, nil
}
// Name returns the configurator name
func (n *NetworkManagerDNSConfigurator) Name() string {
return "network-manager"
}
// SetDNS sets the DNS servers and returns the original servers
func (n *NetworkManagerDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
// Get current DNS settings before overriding
originalServers, err := n.GetCurrentDNS()
if err != nil {
// If we can't get current DNS, proceed anyway
originalServers = []netip.Addr{}
}
// Store original state
n.originalState = &DNSState{
OriginalServers: originalServers,
ConfiguratorName: n.Name(),
}
// Apply new DNS servers
if err := n.applyDNSServers(servers); err != nil {
return nil, fmt.Errorf("apply DNS servers: %w", err)
}
return originalServers, nil
}
// RestoreDNS restores the original DNS configuration
func (n *NetworkManagerDNSConfigurator) RestoreDNS() error {
// Remove our configuration file
if err := os.Remove(n.confPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove DNS config file: %w", err)
}
// Reload NetworkManager to apply the change
if err := n.reloadNetworkManager(); err != nil {
return fmt.Errorf("reload NetworkManager: %w", err)
}
return nil
}
// GetCurrentDNS returns the currently configured DNS servers by reading /etc/resolv.conf
func (n *NetworkManagerDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
content, err := os.ReadFile("/etc/resolv.conf")
if err != nil {
return nil, fmt.Errorf("read resolv.conf: %w", err)
}
var servers []netip.Addr
lines := strings.Split(string(content), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "nameserver") {
fields := strings.Fields(line)
if len(fields) >= 2 {
if addr, err := netip.ParseAddr(fields[1]); err == nil {
servers = append(servers, addr)
}
}
}
}
return servers, nil
}
// applyDNSServers applies DNS server configuration via NetworkManager config file
func (n *NetworkManagerDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
if len(servers) == 0 {
return fmt.Errorf("no DNS servers provided")
}
// Build DNS server list
var dnsServers []string
for _, server := range servers {
dnsServers = append(dnsServers, server.String())
}
// Create NetworkManager configuration file that sets global DNS
// This overrides DNS for all connections
configContent := fmt.Sprintf(`# Generated by Olm DNS Manager - DO NOT EDIT
# This file configures NetworkManager to use Olm's DNS proxy
[global-dns-domain-*]
servers=%s
`, strings.Join(dnsServers, ","))
// Write the configuration file
if err := os.WriteFile(n.confPath, []byte(configContent), 0644); err != nil {
return fmt.Errorf("write DNS config file: %w", err)
}
// Reload NetworkManager to apply the new configuration
if err := n.reloadNetworkManager(); err != nil {
// Try to clean up
os.Remove(n.confPath)
return fmt.Errorf("reload NetworkManager: %w", err)
}
return nil
}
// reloadNetworkManager tells NetworkManager to reload its configuration
func (n *NetworkManagerDNSConfigurator) reloadNetworkManager() error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("connect to system bus: %w", err)
}
defer conn.Close()
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Call Reload method with flags=0 (reload everything)
// See: https://networkmanager.dev/docs/api/latest/gdbus-org.freedesktop.NetworkManager.html#gdbus-method-org-freedesktop-NetworkManager.Reload
err = obj.CallWithContext(ctx, networkManagerDest+".Reload", 0, uint32(0)).Store()
if err != nil {
return fmt.Errorf("call Reload: %w", err)
}
return nil
}
// IsNetworkManagerAvailable checks if NetworkManager is available and responsive
func IsNetworkManagerAvailable() bool {
conn, err := dbus.SystemBus()
if err != nil {
return false
}
defer conn.Close()
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// Try to ping NetworkManager
if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
return false
}
return true
}
// IsNetworkManagerDNSModeSupported checks if NetworkManager's DNS mode is one we can work with
// Some DNS modes delegate to other systems (like systemd-resolved) which we should use directly
func IsNetworkManagerDNSModeSupported() bool {
conn, err := dbus.SystemBus()
if err != nil {
return false
}
defer conn.Close()
obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty)
if err != nil {
// If we can't get the mode, assume it's not supported
return false
}
mode, ok := modeVariant.Value().(string)
if !ok {
return false
}
// If NetworkManager is delegating DNS to systemd-resolved, we should use
// systemd-resolved directly for better control
switch mode {
case "systemd-resolved":
// NetworkManager is delegating to systemd-resolved
// We should use systemd-resolved configurator instead
return false
case "dnsmasq", "unbound":
// NetworkManager is using a local resolver that it controls
// We can configure DNS through NetworkManager
return true
case "default", "none", "":
// NetworkManager is managing DNS directly or not at all
// We can configure DNS through NetworkManager
return true
default:
// Unknown mode, try to use it
return true
}
}
// GetNetworkManagerDNSMode returns the current DNS mode of NetworkManager
func GetNetworkManagerDNSMode() (string, error) {
conn, err := dbus.SystemBus()
if err != nil {
return "", fmt.Errorf("connect to system bus: %w", err)
}
defer conn.Close()
obj := conn.Object(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
modeVariant, err := obj.GetProperty(networkManagerDbusDNSManagerModeProperty)
if err != nil {
return "", fmt.Errorf("get DNS mode property: %w", err)
}
mode, ok := modeVariant.Value().(string)
if !ok {
return "", errors.New("DNS mode is not a string")
}
return mode, nil
}
// GetNetworkManagerVersion returns the version of NetworkManager
func GetNetworkManagerVersion() (string, error) {
conn, err := dbus.SystemBus()
if err != nil {
return "", fmt.Errorf("connect to system bus: %w", err)
}
defer conn.Close()
obj := conn.Object(networkManagerDest, networkManagerDbusObjectNode)
versionVariant, err := obj.GetProperty(networkManagerDbusVersionProperty)
if err != nil {
return "", fmt.Errorf("get version property: %w", err)
}
version, ok := versionVariant.Value().(string)
if !ok {
return "", errors.New("version is not a string")
}
return version, nil
}

192
dns/platform/resolvconf.go Normal file
View File

@@ -0,0 +1,192 @@
//go:build (linux && !android) || freebsd
package dns
import (
"bytes"
"fmt"
"net/netip"
"os/exec"
"strings"
)
const resolvconfCommand = "resolvconf"
// ResolvconfDNSConfigurator manages DNS settings using the resolvconf utility
type ResolvconfDNSConfigurator struct {
ifaceName string
implType string
originalState *DNSState
}
// NewResolvconfDNSConfigurator creates a new resolvconf DNS configurator
func NewResolvconfDNSConfigurator(ifaceName string) (*ResolvconfDNSConfigurator, error) {
if ifaceName == "" {
return nil, fmt.Errorf("interface name is required")
}
// Detect resolvconf implementation type
implType, err := detectResolvconfType()
if err != nil {
return nil, fmt.Errorf("detect resolvconf type: %w", err)
}
return &ResolvconfDNSConfigurator{
ifaceName: ifaceName,
implType: implType,
}, nil
}
// Name returns the configurator name
func (r *ResolvconfDNSConfigurator) Name() string {
return fmt.Sprintf("resolvconf-%s", r.implType)
}
// SetDNS sets the DNS servers and returns the original servers
func (r *ResolvconfDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
// Get current DNS settings before overriding
originalServers, err := r.GetCurrentDNS()
if err != nil {
// If we can't get current DNS, proceed anyway
originalServers = []netip.Addr{}
}
// Store original state
r.originalState = &DNSState{
OriginalServers: originalServers,
ConfiguratorName: r.Name(),
}
// Apply new DNS servers
if err := r.applyDNSServers(servers); err != nil {
return nil, fmt.Errorf("apply DNS servers: %w", err)
}
return originalServers, nil
}
// RestoreDNS restores the original DNS configuration
func (r *ResolvconfDNSConfigurator) RestoreDNS() error {
var cmd *exec.Cmd
switch r.implType {
case "openresolv":
// Force delete with -f
cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
default:
cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
}
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("delete resolvconf config: %w, output: %s", err, out)
}
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
func (r *ResolvconfDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
// resolvconf doesn't provide a direct way to query per-interface DNS
// We can try to read /etc/resolv.conf but it's merged from all sources
content, err := exec.Command(resolvconfCommand, "-l").CombinedOutput()
if err != nil {
// Fall back to reading resolv.conf
return readResolvConfServers()
}
// Parse the output (format varies by implementation)
return parseResolvconfOutput(string(content)), nil
}
// applyDNSServers applies DNS server configuration via resolvconf
func (r *ResolvconfDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
if len(servers) == 0 {
return fmt.Errorf("no DNS servers provided")
}
// Build resolv.conf content
var content bytes.Buffer
content.WriteString("# Generated by Olm DNS Manager\n\n")
for _, server := range servers {
content.WriteString("nameserver ")
content.WriteString(server.String())
content.WriteString("\n")
}
// Apply via resolvconf
var cmd *exec.Cmd
switch r.implType {
case "openresolv":
// OpenResolv supports exclusive mode with -x
cmd = exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
default:
cmd = exec.Command(resolvconfCommand, "-a", r.ifaceName)
}
cmd.Stdin = &content
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("apply resolvconf config: %w, output: %s", err, out)
}
return nil
}
// detectResolvconfType detects which resolvconf implementation is being used
func detectResolvconfType() (string, error) {
cmd := exec.Command(resolvconfCommand, "--version")
out, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("detect resolvconf type: %w", err)
}
if strings.Contains(string(out), "openresolv") {
return "openresolv", nil
}
return "resolvconf", nil
}
// parseResolvconfOutput parses resolvconf -l output for DNS servers
func parseResolvconfOutput(output string) []netip.Addr {
var servers []netip.Addr
lines := strings.Split(output, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
// Skip comments and empty lines
if line == "" || strings.HasPrefix(line, "#") {
continue
}
// Look for nameserver lines
if strings.HasPrefix(line, "nameserver") {
fields := strings.Fields(line)
if len(fields) >= 2 {
if addr, err := netip.ParseAddr(fields[1]); err == nil {
servers = append(servers, addr)
}
}
}
}
return servers
}
// readResolvConfServers reads DNS servers from /etc/resolv.conf
func readResolvConfServers() ([]netip.Addr, error) {
cmd := exec.Command("cat", "/etc/resolv.conf")
out, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("read resolv.conf: %w", err)
}
return parseResolvconfOutput(string(out)), nil
}
// IsResolvconfAvailable checks if resolvconf is available
func IsResolvconfAvailable() bool {
cmd := exec.Command(resolvconfCommand, "--version")
return cmd.Run() == nil
}

286
dns/platform/systemd.go Normal file
View File

@@ -0,0 +1,286 @@
//go:build linux && !android
package dns
import (
"context"
"fmt"
"net"
"net/netip"
"time"
dbus "github.com/godbus/dbus/v5"
"golang.org/x/sys/unix"
)
const (
systemdResolvedDest = "org.freedesktop.resolve1"
systemdDbusObjectNode = "/org/freedesktop/resolve1"
systemdDbusManagerIface = "org.freedesktop.resolve1.Manager"
systemdDbusGetLinkMethod = systemdDbusManagerIface + ".GetLink"
systemdDbusFlushCachesMethod = systemdDbusManagerIface + ".FlushCaches"
systemdDbusLinkInterface = "org.freedesktop.resolve1.Link"
systemdDbusSetDNSMethod = systemdDbusLinkInterface + ".SetDNS"
systemdDbusSetDefaultRouteMethod = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethod = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethod = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusSetDNSOverTLSMethod = systemdDbusLinkInterface + ".SetDNSOverTLS"
systemdDbusRevertMethod = systemdDbusLinkInterface + ".Revert"
// RootZone is the root DNS zone that matches all queries
RootZone = "."
)
// systemdDbusDNSInput maps to (iay) dbus input for SetDNS method
type systemdDbusDNSInput struct {
Family int32
Address []byte
}
// systemdDbusDomainsInput maps to (sb) dbus input for SetDomains method
type systemdDbusDomainsInput struct {
Domain string
MatchOnly bool
}
// SystemdResolvedDNSConfigurator manages DNS settings using systemd-resolved D-Bus API
type SystemdResolvedDNSConfigurator struct {
ifaceName string
dbusLinkObject dbus.ObjectPath
originalState *DNSState
}
// NewSystemdResolvedDNSConfigurator creates a new systemd-resolved DNS configurator
func NewSystemdResolvedDNSConfigurator(ifaceName string) (*SystemdResolvedDNSConfigurator, error) {
// Get network interface
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
return nil, fmt.Errorf("get interface: %w", err)
}
// Connect to D-Bus
conn, err := dbus.SystemBus()
if err != nil {
return nil, fmt.Errorf("connect to system bus: %w", err)
}
defer conn.Close()
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
// Get the link object for this interface
var linkPath string
if err := obj.Call(systemdDbusGetLinkMethod, 0, iface.Index).Store(&linkPath); err != nil {
return nil, fmt.Errorf("get link: %w", err)
}
return &SystemdResolvedDNSConfigurator{
ifaceName: ifaceName,
dbusLinkObject: dbus.ObjectPath(linkPath),
}, nil
}
// Name returns the configurator name
func (s *SystemdResolvedDNSConfigurator) Name() string {
return "systemd-resolved"
}
// SetDNS sets the DNS servers and returns the original servers
func (s *SystemdResolvedDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
// Get current DNS settings before overriding
originalServers, err := s.GetCurrentDNS()
if err != nil {
// If we can't get current DNS, proceed anyway
originalServers = []netip.Addr{}
}
// Store original state
s.originalState = &DNSState{
OriginalServers: originalServers,
ConfiguratorName: s.Name(),
}
// Apply new DNS servers
if err := s.applyDNSServers(servers); err != nil {
return nil, fmt.Errorf("apply DNS servers: %w", err)
}
return originalServers, nil
}
// RestoreDNS restores the original DNS configuration
func (s *SystemdResolvedDNSConfigurator) RestoreDNS() error {
// Call Revert method to restore systemd-resolved defaults
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("connect to system bus: %w", err)
}
defer conn.Close()
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := obj.CallWithContext(ctx, systemdDbusRevertMethod, 0).Store(); err != nil {
return fmt.Errorf("revert DNS settings: %w", err)
}
// Flush DNS cache after reverting
if err := s.flushDNSCache(); err != nil {
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
}
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
// Note: systemd-resolved doesn't easily expose current per-link DNS servers via D-Bus
// This is a placeholder that returns an empty list
func (s *SystemdResolvedDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
// systemd-resolved's D-Bus API doesn't have a simple way to query current DNS servers
// We would need to parse resolvectl status output or read from /run/systemd/resolve/
// For now, return empty list
return []netip.Addr{}, nil
}
// applyDNSServers applies DNS server configuration via systemd-resolved
func (s *SystemdResolvedDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
if len(servers) == 0 {
return fmt.Errorf("no DNS servers provided")
}
// Convert servers to systemd-resolved format
var dnsInputs []systemdDbusDNSInput
for _, server := range servers {
family := unix.AF_INET
if server.Is6() {
family = unix.AF_INET6
}
dnsInputs = append(dnsInputs, systemdDbusDNSInput{
Family: int32(family),
Address: server.AsSlice(),
})
}
// Connect to D-Bus
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("connect to system bus: %w", err)
}
defer conn.Close()
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Call SetDNS method to set the DNS servers
if err := obj.CallWithContext(ctx, systemdDbusSetDNSMethod, 0, dnsInputs).Store(); err != nil {
return fmt.Errorf("set DNS servers: %w", err)
}
// Set this interface as the default route for DNS
// This ensures all DNS queries prefer this interface
if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethod, true); err != nil {
return fmt.Errorf("set default route: %w", err)
}
// Set the root zone "." as a match-only domain
// This captures ALL DNS queries and routes them through this interface
domainsInput := []systemdDbusDomainsInput{
{
Domain: RootZone,
MatchOnly: true,
},
}
if err := s.callLinkMethod(systemdDbusSetDomainsMethod, domainsInput); err != nil {
return fmt.Errorf("set domains: %w", err)
}
// Disable DNSSEC - we don't support it and it may be enabled by default
if err := s.callLinkMethod(systemdDbusSetDNSSECMethod, "no"); err != nil {
// Log warning but don't fail - this is optional
fmt.Printf("warning: failed to disable DNSSEC: %v\n", err)
}
// Disable DNSOverTLS - we don't support it and it may be enabled by default
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethod, "no"); err != nil {
// Log warning but don't fail - this is optional
fmt.Printf("warning: failed to disable DNSOverTLS: %v\n", err)
}
// Flush DNS cache to ensure new settings take effect immediately
if err := s.flushDNSCache(); err != nil {
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
}
return nil
}
// callLinkMethod is a helper to call methods on the link object
func (s *SystemdResolvedDNSConfigurator) callLinkMethod(method string, value any) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("connect to system bus: %w", err)
}
defer conn.Close()
obj := conn.Object(systemdResolvedDest, s.dbusLinkObject)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if value != nil {
if err := obj.CallWithContext(ctx, method, 0, value).Store(); err != nil {
return fmt.Errorf("call %s: %w", method, err)
}
} else {
if err := obj.CallWithContext(ctx, method, 0).Store(); err != nil {
return fmt.Errorf("call %s: %w", method, err)
}
}
return nil
}
// flushDNSCache flushes the systemd-resolved DNS cache
func (s *SystemdResolvedDNSConfigurator) flushDNSCache() error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("connect to system bus: %w", err)
}
defer conn.Close()
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, 0).Store(); err != nil {
return fmt.Errorf("flush caches: %w", err)
}
return nil
}
// IsSystemdResolvedAvailable checks if systemd-resolved is available and responsive
func IsSystemdResolvedAvailable() bool {
conn, err := dbus.SystemBus()
if err != nil {
return false
}
defer conn.Close()
obj := conn.Object(systemdResolvedDest, systemdDbusObjectNode)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// Try to ping systemd-resolved
if err := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
return false
}
return true
}

41
dns/platform/types.go Normal file
View File

@@ -0,0 +1,41 @@
package dns
import "net/netip"
// DNSConfigurator provides an interface for managing system DNS settings
// across different platforms and implementations
type DNSConfigurator interface {
// SetDNS overrides the system DNS servers with the specified ones
// Returns the original DNS servers that were replaced
SetDNS(servers []netip.Addr) ([]netip.Addr, error)
// RestoreDNS restores the original DNS servers
RestoreDNS() error
// GetCurrentDNS returns the currently configured DNS servers
GetCurrentDNS() ([]netip.Addr, error)
// Name returns the name of this configurator implementation
Name() string
}
// DNSConfig contains the configuration for DNS override
type DNSConfig struct {
// Servers is the list of DNS servers to use
Servers []netip.Addr
// SearchDomains is an optional list of search domains
SearchDomains []string
}
// DNSState represents the saved state of DNS configuration
type DNSState struct {
// OriginalServers are the DNS servers before override
OriginalServers []netip.Addr
// OriginalSearchDomains are the search domains before override
OriginalSearchDomains []string
// ConfiguratorName is the name of the configurator that saved this state
ConfiguratorName string
}

343
dns/platform/windows.go Normal file
View File

@@ -0,0 +1,343 @@
//go:build windows
package dns
import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
var (
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
)
const (
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServer = "NameServer"
interfaceConfigDhcpNameServer = "DhcpNameServer"
)
// WindowsDNSConfigurator manages DNS settings on Windows using the registry
type WindowsDNSConfigurator struct {
guid string
originalState *DNSState
}
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
// Accepts an interface name and extracts the GUID internally
func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) {
if interfaceName == "" {
return nil, fmt.Errorf("interface name is required")
}
guid, err := getInterfaceGUIDString(interfaceName)
if err != nil {
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
}
return &WindowsDNSConfigurator{
guid: guid,
}, nil
}
// newWindowsDNSConfiguratorFromGUID creates a configurator from a GUID string
// This is an internal function for use by DetectBestConfigurator
func newWindowsDNSConfiguratorFromGUID(guid string) (*WindowsDNSConfigurator, error) {
if guid == "" {
return nil, fmt.Errorf("interface GUID is required")
}
return &WindowsDNSConfigurator{
guid: guid,
}, nil
}
// Name returns the configurator name
func (w *WindowsDNSConfigurator) Name() string {
return "windows-registry"
}
// SetDNS sets the DNS servers and returns the original servers
func (w *WindowsDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
// Get current DNS settings before overriding
originalServers, err := w.GetCurrentDNS()
if err != nil {
return nil, fmt.Errorf("get current DNS: %w", err)
}
// Store original state
w.originalState = &DNSState{
OriginalServers: originalServers,
ConfiguratorName: w.Name(),
}
// Set new DNS servers
if err := w.setDNSServers(servers); err != nil {
return nil, fmt.Errorf("set DNS servers: %w", err)
}
// Flush DNS cache
if err := w.flushDNSCache(); err != nil {
// Non-fatal, just log
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
}
return originalServers, nil
}
// RestoreDNS restores the original DNS configuration
func (w *WindowsDNSConfigurator) RestoreDNS() error {
if w.originalState == nil {
return fmt.Errorf("no original state to restore")
}
// Clear the static DNS setting
if err := w.clearDNSServers(); err != nil {
return fmt.Errorf("clear DNS servers: %w", err)
}
// Flush DNS cache
if err := w.flushDNSCache(); err != nil {
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
}
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
func (w *WindowsDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
regKey, err := w.getInterfaceRegistryKey(registry.QUERY_VALUE)
if err != nil {
return nil, fmt.Errorf("get interface registry key: %w", err)
}
defer closeKey(regKey)
// Try to get static DNS first
nameServer, _, err := regKey.GetStringValue(interfaceConfigNameServer)
if err == nil && nameServer != "" {
return w.parseServerList(nameServer), nil
}
// Fall back to DHCP DNS
dhcpNameServer, _, err := regKey.GetStringValue(interfaceConfigDhcpNameServer)
if err == nil && dhcpNameServer != "" {
return w.parseServerList(dhcpNameServer), nil
}
return []netip.Addr{}, nil
}
// setDNSServers sets the DNS servers in the registry
func (w *WindowsDNSConfigurator) setDNSServers(servers []netip.Addr) error {
if len(servers) == 0 {
return fmt.Errorf("no DNS servers provided")
}
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closeKey(regKey)
// Build comma-separated or space-separated list of servers
var serverList string
for i, server := range servers {
if i > 0 {
serverList += ","
}
serverList += server.String()
}
if err := regKey.SetStringValue(interfaceConfigNameServer, serverList); err != nil {
return fmt.Errorf("set NameServer: %w", err)
}
return nil
}
// clearDNSServers clears the static DNS server setting
func (w *WindowsDNSConfigurator) clearDNSServers() error {
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closeKey(regKey)
// Set empty string to revert to DHCP
if err := regKey.SetStringValue(interfaceConfigNameServer, ""); err != nil {
return fmt.Errorf("clear NameServer: %w", err)
}
return nil
}
// getInterfaceRegistryKey opens the registry key for the network interface
func (w *WindowsDNSConfigurator) getInterfaceRegistryKey(access uint32) (registry.Key, error) {
regKeyPath := interfaceConfigPath + `\` + w.guid
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, access)
if err != nil {
return 0, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
}
return regKey, nil
}
// parseServerList parses a comma or space-separated list of DNS servers
func (w *WindowsDNSConfigurator) parseServerList(serverList string) []netip.Addr {
var servers []netip.Addr
// Split by comma or space
parts := splitByDelimiters(serverList, []rune{',', ' '})
for _, part := range parts {
if addr, err := netip.ParseAddr(part); err == nil {
servers = append(servers, addr)
}
}
return servers
}
// flushDNSCache flushes the Windows DNS resolver cache
func (w *WindowsDNSConfigurator) flushDNSCache() error {
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() {
if rec := recover(); rec != nil {
fmt.Printf("warning: DnsFlushResolverCache panicked: %v\n", rec)
}
}()
ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
}
return fmt.Errorf("DnsFlushResolverCache failed")
}
return nil
}
// splitByDelimiters splits a string by multiple delimiters
func splitByDelimiters(s string, delimiters []rune) []string {
var result []string
var current []rune
for _, char := range s {
isDelimiter := false
for _, delim := range delimiters {
if char == delim {
isDelimiter = true
break
}
}
if isDelimiter {
if len(current) > 0 {
result = append(result, string(current))
current = []rune{}
}
} else {
current = append(current, char)
}
}
if len(current) > 0 {
result = append(result, string(current))
}
return result
}
// closeKey closes a registry key and logs errors
func closeKey(closer io.Closer) {
if err := closer.Close(); err != nil {
fmt.Printf("warning: failed to close registry key: %v\n", err)
}
}
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
// This is required for registry-based DNS configuration on Windows
func getInterfaceGUIDString(interfaceName string) (string, error) {
if interfaceName == "" {
return "", fmt.Errorf("interface name is required")
}
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
}
luid, err := indexToLUID(uint32(iface.Index))
if err != nil {
return "", fmt.Errorf("failed to convert index to LUID: %w", err)
}
// Convert LUID to GUID using Windows API
guid, err := luidToGUID(luid)
if err != nil {
return "", fmt.Errorf("failed to convert LUID to GUID: %w", err)
}
return guid, nil
}
// indexToLUID converts a Windows interface index to a LUID
func indexToLUID(index uint32) (uint64, error) {
var luid uint64
// Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid")
// Call the Windows API
ret, _, err := convertInterfaceIndexToLuid.Call(
uintptr(index),
uintptr(unsafe.Pointer(&luid)),
)
if ret != 0 {
return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err)
}
return luid, nil
}
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
// using the Windows ConvertInterface* APIs
func luidToGUID(luid uint64) (string, error) {
var guid windows.GUID
// Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid")
// Call the Windows API
// NET_LUID is a 64-bit value on Windows
ret, _, err := convertLuidToGuid.Call(
uintptr(unsafe.Pointer(&luid)),
uintptr(unsafe.Pointer(&guid)),
)
if ret != 0 {
return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err)
}
// Format the GUID as a string with curly braces
guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}",
guid.Data1, guid.Data2, guid.Data3,
guid.Data4[0], guid.Data4[1], guid.Data4[2], guid.Data4[3],
guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7])
return guidStr, nil
}

21
go.mod
View File

@@ -3,20 +3,31 @@ module github.com/fosrl/olm
go 1.25
require (
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7
github.com/Microsoft/go-winio v0.6.2
github.com/fosrl/newt v0.0.0
github.com/godbus/dbus/v5 v5.2.0
github.com/gorilla/websocket v1.5.3
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.45.0
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
github.com/miekg/dns v1.1.68
golang.org/x/sys v0.38.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
software.sslmate.com/src/go-pkcs12 v0.6.0
)
require (
github.com/google/btree v1.1.3 // indirect
github.com/vishvananda/netlink v1.3.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
)
replace github.com/fosrl/newt => ../newt

24
go.sum
View File

@@ -1,34 +1,46 @@
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o=
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c=
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

View File

@@ -1,217 +0,0 @@
package httpserver
import (
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
// ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
}
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
Connected bool `json:"connected"`
RTT time.Duration `json:"rtt"`
LastSeen time.Time `json:"lastSeen"`
Endpoint string `json:"endpoint,omitempty"`
IsRelay bool `json:"isRelay"`
}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Status string `json:"status"`
Connected bool `json:"connected"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
// HTTPServer represents the HTTP server and its state
type HTTPServer struct {
addr string
server *http.Server
connectionChan chan ConnectionRequest
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
tunnelIP string
version string
}
// NewHTTPServer creates a new HTTP server
func NewHTTPServer(addr string) *HTTPServer {
s := &HTTPServer{
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// Start starts the HTTP server
func (s *HTTPServer) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
s.server = &http.Server{
Addr: s.addr,
Handler: mux,
}
logger.Info("Starting HTTP server on %s", s.addr)
go func() {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server error: %v", err)
}
}()
return nil
}
// Stop stops the HTTP server
func (s *HTTPServer) Stop() error {
logger.Info("Stopping HTTP server")
return s.server.Close()
}
// GetConnectionChannel returns the channel for receiving connection requests
func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Connected = connected
status.RTT = rtt
status.LastSeen = time.Now()
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// SetConnectionStatus sets the overall connection status
func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isConnected = isConnected
if isConnected {
s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
}
}
// SetTunnelIP sets the tunnel IP address
func (s *HTTPServer) SetTunnelIP(tunnelIP string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.tunnelIP = tunnelIP
}
// SetVersion sets the olm version
func (s *HTTPServer) SetVersion(version string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.version = version
}
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// handleConnect handles the /connect endpoint
func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req ConnectionRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
return
}
// Send the request to the main goroutine
s.connectionChan <- req
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
// handleStatus handles the /status endpoint
func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{
Connected: s.isConnected,
TunnelIP: s.tunnelIP,
Version: s.version,
PeerStatuses: s.peerStatuses,
}
if s.isConnected {
resp.Status = "connected"
} else {
resp.Status = "disconnected"
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}

810
main.go
View File

@@ -2,56 +2,17 @@ package main
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/httpserver"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/fosrl/olm/olm"
)
// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
return ""
}
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint // Already valid, no change needed
}
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
lastColon := strings.LastIndex(endpoint, ":")
if lastColon > 0 { // Ensure there is a colon and it's not the first character
hostPart := endpoint[:lastColon]
// Check if the host part is a literal IPv6 address
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
// It is! Reformat it with brackets.
portPart := endpoint[lastColon+1:]
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
}
}
// If it's not the specific malformed case, return it as is.
return endpoint
}
func main() {
// Check if we're running as a Windows service
if isWindowsService() {
@@ -193,17 +154,30 @@ func main() {
}
}
// Create a context that will be cancelled on interrupt signals
signalCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
// Create a separate context for programmatic shutdown (e.g., via API exit)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Run in console mode
runOlmMain(context.Background())
runOlmMainWithArgs(ctx, cancel, signalCtx, os.Args[1:])
}
func runOlmMain(ctx context.Context) {
runOlmMainWithArgs(ctx, os.Args[1:])
}
func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCtx context.Context, args []string) {
// Setup Windows event logging if on Windows
if runtime.GOOS == "windows" {
setupWindowsEventLog()
} else {
// Initialize logger for non-Windows platforms
logger.Init(nil)
}
func runOlmMainWithArgs(ctx context.Context, args []string) {
// Load configuration from file, env vars, and CLI args
// Priority: CLI args > Env vars > Config file > Defaults
// Use the passed args parameter instead of os.Args[1:] to support Windows service mode
config, showVersion, showConfig, err := LoadConfig(args)
if err != nil {
fmt.Printf("Failed to load configuration: %v\n", err)
@@ -216,717 +190,75 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
os.Exit(0)
}
// Extract commonly used values from config for convenience
var (
endpoint = config.Endpoint
id = config.ID
secret = config.Secret
mtu = config.MTU
logLevel = config.LogLevel
interfaceName = config.InterfaceName
enableHTTP = config.EnableHTTP
httpAddr = config.HTTPAddr
pingInterval = config.PingIntervalDuration
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key
connected bool
)
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
// Setup Windows event logging if on Windows
if runtime.GOOS == "windows" {
setupWindowsEventLog()
} else {
// Initialize logger for non-Windows platforms
logger.Init()
}
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
olmVersion := "version_replaceme"
if showVersion {
fmt.Println("Olm version " + olmVersion)
os.Exit(0)
}
logger.Info("Olm version " + olmVersion)
logger.Info("Olm version %s", olmVersion)
if err := updates.CheckForUpdate("fosrl", "olm", olmVersion); err != nil {
config.Version = olmVersion
if err := SaveConfig(config); err != nil {
logger.Error("Failed to save full olm config: %v", err)
} else {
logger.Debug("Saved full olm config with all options")
}
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
logger.Debug("Failed to check for updates: %v", err)
}
// Log startup information
logger.Debug("Olm service starting...")
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr)
if doHolepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
// Create a new olm.Config struct and copy values from the main config
olmConfig := olm.GlobalConfig{
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
SocketPath: config.SocketPath,
Version: config.Version,
Agent: "Olm CLI",
OnExit: cancel, // Pass cancel function directly to trigger shutdown
OnTerminated: cancel,
}
var httpServer *httpserver.HTTPServer
if enableHTTP {
httpServer = httpserver.NewHTTPServer(httpAddr)
httpServer.SetVersion(olmVersion)
if err := httpServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Use a goroutine to handle connection requests
go func() {
for req := range httpServer.GetConnectionChannel() {
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
}
}()
olm.Init(ctx, olmConfig)
if err := olm.StartApi(); err != nil {
logger.Fatal("Failed to start API server: %v", err)
}
// // Check if required parameters are missing and provide helpful guidance
// missingParams := []string{}
// if id == "" {
// missingParams = append(missingParams, "id (use -id flag or OLM_ID env var)")
// }
// if secret == "" {
// missingParams = append(missingParams, "secret (use -secret flag or OLM_SECRET env var)")
// }
// if endpoint == "" {
// missingParams = append(missingParams, "endpoint (use -endpoint flag or PANGOLIN_ENDPOINT env var)")
// }
// if len(missingParams) > 0 {
// logger.Error("Missing required parameters: %v", missingParams)
// logger.Error("Either provide them as command line flags or set as environment variables")
// fmt.Printf("ERROR: Missing required parameters: %v\n", missingParams)
// fmt.Printf("Please provide them as command line flags or set as environment variables\n")
// if !enableHTTP {
// logger.Error("HTTP server is disabled, cannot receive parameters via API")
// fmt.Printf("HTTP server is disabled, cannot receive parameters via API\n")
// return
// }
// }
// Create a new olm
olm, err := websocket.NewClient(
"olm",
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
pingInterval,
pingTimeout,
)
if err != nil {
logger.Fatal("Failed to create olm: %v", err)
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
tunnelConfig := olm.TunnelConfig{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,
UserToken: config.UserToken,
MTU: config.MTU,
DNS: config.DNS,
UpstreamDNS: config.UpstreamDNS,
InterfaceName: config.InterfaceName,
Holepunch: !config.DisableHolepunch,
TlsClientCert: config.TlsClientCert,
PingIntervalDuration: config.PingIntervalDuration,
PingTimeoutDuration: config.PingTimeoutDuration,
OrgID: config.OrgID,
OverrideDNS: config.OverrideDNS,
DisableRelay: config.DisableRelay,
EnableUAPI: true,
}
go olm.StartTunnel(tunnelConfig)
} else {
logger.Info("Incomplete tunnel configuration, not starting tunnel")
}
// wait until we have a client id and secret and endpoint
waitCount := 0
for id == "" || secret == "" || endpoint == "" {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
return
default:
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
waitCount++
if waitCount%10 == 1 { // Log every 10 seconds instead of every second
logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount)
}
time.Sleep(1 * time.Second)
}
}
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// Create TUN device and network stack
var dev *device.Device
var wgData WgData
var holePunchData HolePunchData
var uapiListener net.Listener
var tdev tun.Device
sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
fmt.Printf("Error finding available port: %v\n", err)
os.Exit(1)
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})
// Start a single hole punch goroutine for all exit nodes
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort)
})
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
logger.Debug("Received message: %v", msg.Data)
type LegacyHolePunchData struct {
ServerPubKey string `json:"serverPubKey"`
Endpoint string `json:"endpoint"`
}
var legacyHolePunchData LegacyHolePunchData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Stop any existing hole punch goroutines by closing the current channel
select {
case <-stopHolepunch:
// Channel already closed
default:
close(stopHolepunch)
}
// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})
// Start hole punching for each exit node
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
})
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
if connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
close(stopHolepunch)
// wait 10 milliseconds to ensure the previous connection is closed
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
if dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
return nil, err
}
return tun.CreateTUN(interfaceName, mtu)
}
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, mtu)
}
return tun.CreateTUN(interfaceName, mtu)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
fileUAPI, err := func() (*os.File, error) {
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}
return uapiOpen(interfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := uapiListener.Accept()
if err != nil {
return
}
go dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
if err = dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
if httpServer != nil {
httpServer.SetTunnelIP(wgData.TunnelIP)
}
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if httpServer != nil {
// Find the site config to get endpoint information
var endpoint string
var isRelay bool
for _, site := range wgData.Sites {
if site.SiteId == siteID {
endpoint = site.Endpoint
// TODO: We'll need to track relay status separately
// For now, assume not using relay unless we get relay data
isRelay = !doHolepunch
break
}
}
httpServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
}
if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else {
logger.Warn("Peer %d is disconnected", siteID)
}
},
fixKey(privateKey.String()),
olm,
dev,
doHolepunch,
)
for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if httpServer != nil {
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
}
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
peerMonitor.Start()
connected = true
logger.Info("WireGuard device created.")
})
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData UpdatePeerData
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: updateData.SiteId,
Endpoint: updateData.Endpoint,
PublicKey: updateData.PublicKey,
ServerIP: updateData.ServerIP,
ServerPort: updateData.ServerPort,
RemoteSubnets: updateData.RemoteSubnets,
}
// Update the peer in WireGuard
if dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
for _, site := range wgData.Sites {
if site.SiteId == updateData.SiteId {
oldRemoteSubnets = site.RemoteSubnets
oldPublicKey = site.PublicKey
break
}
}
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
}
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// Remove old remote subnet routes if they changed
if oldRemoteSubnets != siteConfig.RemoteSubnets {
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
logger.Error("Failed to remove old remote subnet routes: %v", err)
// Continue anyway to add new routes
}
// Add new remote subnet routes
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add new remote subnet routes: %v", err)
return
}
}
// Update successful
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
for i := range wgData.Sites {
if wgData.Sites[i].SiteId == updateData.SiteId {
wgData.Sites[i] = siteConfig
break
}
}
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for adding a new peer
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addData AddPeerData
if err := json.Unmarshal(jsonData, &addData); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: addData.SiteId,
Endpoint: addData.Endpoint,
PublicKey: addData.PublicKey,
ServerIP: addData.ServerIP,
ServerPort: addData.ServerPort,
RemoteSubnets: addData.RemoteSubnets,
}
// Add the peer to WireGuard
if dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for new peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
// Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId)
// Update WgData with the new peer
wgData.Sites = append(wgData.Sites, siteConfig)
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for removing a peer
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData RemovePeerData
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
// Find the peer to remove
var peerToRemove *SiteConfig
var newSites []SiteConfig
for _, site := range wgData.Sites {
if site.SiteId == removeData.SiteId {
peerToRemove = &site
} else {
newSites = append(newSites, site)
}
}
if peerToRemove == nil {
logger.Error("Peer with site ID %d not found", removeData.SiteId)
return
}
// Remove the peer from WireGuard
if dev != nil {
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
}
// Remove route for the peer
err = removeRouteForServerIP(peerToRemove.ServerIP)
if err != nil {
logger.Error("Failed to remove route for peer: %v", err)
return
}
// Remove routes for remote subnets
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
logger.Error("Failed to remove routes for remote subnets: %v", err)
return
}
// Remove successful
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
// Update WgData to remove the peer
wgData.Sites = newSites
} else {
logger.Error("WireGuard device not initialized")
}
})
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := resolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
if httpServer != nil {
httpServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
logger.Info("Received no-sites message - no sites available for connection")
// if stopRegister != nil {
// stopRegister()
// stopRegister = nil
// }
// select {
// case <-stopHolepunch:
// // Channel already closed, do nothing
// default:
// close(stopHolepunch)
// }
logger.Info("No sites available - stopped registration and holepunch processes")
})
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
olm.Close()
})
olm.OnConnect(func() error {
logger.Info("Websocket Connected")
if httpServer != nil {
httpServer.SetConnectionStatus(true)
}
// CRITICAL: Save our full config AFTER websocket saves its limited config
// This ensures all 13 fields are preserved, not just the 4 that websocket saves
if err := SaveConfig(config); err != nil {
logger.Error("Failed to save full olm config: %v", err)
} else {
logger.Debug("Saved full olm config with all options")
}
if connected {
logger.Debug("Already connected, skipping registration")
return nil
}
publicKey := privateKey.PublicKey()
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": olmVersion,
}, 1*time.Second)
go keepSendingPing(olm)
logger.Info("Sent registration message")
return nil
})
olm.OnTokenUpdate(func(token string) {
olmToken = token
})
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)
}
defer olm.Close()
// Wait for interrupt signal or context cancellation
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
// Wait for either signal or programmatic shutdown
select {
case <-sigCh:
logger.Info("Received interrupt signal")
case <-signalCtx.Done():
logger.Info("Shutdown signal received, cleaning up...")
case <-ctx.Done():
logger.Info("Context cancelled")
logger.Info("Shutdown requested via API, cleaning up...")
}
select {
case <-stopHolepunch:
// Channel already closed, do nothing
default:
close(stopHolepunch)
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
select {
case <-stopPing:
// Channel already closed
default:
close(stopPing)
}
if uapiListener != nil {
uapiListener.Close()
}
if dev != nil {
dev.Close()
}
logger.Info("runOlmMain() exiting")
fmt.Printf("runOlmMain() exiting\n")
// Clean up resources
olm.Close()
logger.Info("Shutdown complete")
}

126
namespace.sh Normal file
View File

@@ -0,0 +1,126 @@
#!/bin/bash
# Configuration
NS_NAME="isolated_ns" # Name of the namespace
VETH_HOST="veth_host" # Interface name on host side
VETH_NS="veth_ns" # Interface name inside namespace
HOST_IP="192.168.15.1" # Gateway IP for the namespace (host side)
NS_IP="192.168.15.2" # IP address for the namespace
SUBNET_CIDR="24" # Subnet mask
DNS_SERVER="8.8.8.8" # DNS to use inside namespace
# Detect the main physical interface (gateway to internet)
PHY_IFACE=$(ip route get 8.8.8.8 | awk -- '{printf $5}')
# Helper function to check for root
check_root() {
if [ "$EUID" -ne 0 ]; then
echo "Error: This script must be run as root."
exit 1
fi
}
setup_ns() {
echo "Bringing up namespace '$NS_NAME'..."
# 1. Create the network namespace
if ip netns list | grep -q "$NS_NAME"; then
echo "Namespace $NS_NAME already exists. Run 'down' first."
exit 1
fi
ip netns add "$NS_NAME"
# 2. Create veth pair
ip link add "$VETH_HOST" type veth peer name "$VETH_NS"
# 3. Move peer interface to namespace
ip link set "$VETH_NS" netns "$NS_NAME"
# 4. Configure Host Side Interface
ip addr add "${HOST_IP}/${SUBNET_CIDR}" dev "$VETH_HOST"
ip link set "$VETH_HOST" up
# 5. Configure Namespace Side Interface
ip netns exec "$NS_NAME" ip addr add "${NS_IP}/${SUBNET_CIDR}" dev "$VETH_NS"
ip netns exec "$NS_NAME" ip link set "$VETH_NS" up
# 6. Bring up loopback inside namespace (crucial for many apps)
ip netns exec "$NS_NAME" ip link set lo up
# 7. Routing: Add default gateway inside namespace pointing to host
ip netns exec "$NS_NAME" ip route add default via "$HOST_IP"
# 8. Enable IP forwarding on host
echo 1 > /proc/sys/net/ipv4/ip_forward
# 9. NAT/Masquerade: Allow traffic from namespace to go out physical interface
# We verify rule doesn't exist first to avoid duplicates
iptables -t nat -C POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null || \
iptables -t nat -A POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE
# Allow forwarding from host veth to WAN and back
iptables -C FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null || \
iptables -A FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT
iptables -C FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null || \
iptables -A FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT
# 10. DNS Setup
# Netns uses /etc/netns/<name>/resolv.conf if it exists
mkdir -p "/etc/netns/$NS_NAME"
echo "nameserver $DNS_SERVER" > "/etc/netns/$NS_NAME/resolv.conf"
echo "Namespace $NS_NAME is UP."
echo "To enter shell: sudo ip netns exec $NS_NAME bash"
}
teardown_ns() {
echo "Tearing down namespace '$NS_NAME'..."
# 1. Remove Namespace (this automatically deletes the veth pair inside it)
# The host side veth usually disappears when the peer is destroyed.
if ip netns list | grep -q "$NS_NAME"; then
ip netns del "$NS_NAME"
else
echo "Namespace $NS_NAME does not exist."
fi
# 2. Clean up veth host side if it still lingers
if ip link show "$VETH_HOST" > /dev/null 2>&1; then
ip link delete "$VETH_HOST"
fi
# 3. Remove iptables rules
# We use -D to delete the specific rules we added
iptables -t nat -D POSTROUTING -s "${NS_IP}/${SUBNET_CIDR}" -o "$PHY_IFACE" -j MASQUERADE 2>/dev/null
iptables -D FORWARD -i "$VETH_HOST" -o "$PHY_IFACE" -j ACCEPT 2>/dev/null
iptables -D FORWARD -i "$PHY_IFACE" -o "$VETH_HOST" -j ACCEPT 2>/dev/null
# 4. Remove DNS config
rm -rf "/etc/netns/$NS_NAME"
echo "Namespace $NS_NAME is DOWN."
}
test_connectivity() {
echo "Testing connectivity inside $NS_NAME..."
ip netns exec "$NS_NAME" ping -c 3 8.8.8.8
}
# Main execution logic
check_root
case "$1" in
up)
setup_ns
;;
down)
teardown_ns
;;
test)
test_connectivity
;;
*)
echo "Usage: $0 {up|down|test}"
exit 1
esac

1066
olm/olm.go Normal file

File diff suppressed because it is too large Load Diff

66
olm/types.go Normal file
View File

@@ -0,0 +1,66 @@
package olm
import (
"time"
"github.com/fosrl/olm/peers"
)
type WgData struct {
Sites []peers.SiteConfig `json:"sites"`
TunnelIP string `json:"tunnelIP"`
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
}
type GlobalConfig struct {
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
Version string
Agent string
// Callbacks
OnRegistered func()
OnConnected func()
OnTerminated func()
OnAuthError func(statusCode int, message string) // Called when auth fails (401/403)
OnExit func() // Called when exit is requested via API
}
type TunnelConfig struct {
// Connection settings
Endpoint string
ID string
Secret string
UserToken string
// Network settings
MTU int
DNS string
UpstreamDNS []string
InterfaceName string
// Advanced
Holepunch bool
TlsClientCert string
// Parsed values (not in JSON)
PingIntervalDuration time.Duration
PingTimeoutDuration time.Duration
OrgID string
// DoNotCreateNewClient bool
FileDescriptorTun uint32
FileDescriptorUAPI uint32
EnableUAPI bool
OverrideDNS bool
DisableRelay bool
}

98
olm/util.go Normal file
View File

@@ -0,0 +1,98 @@
package olm
import (
"fmt"
"net"
"strings"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/olm/websocket"
)
// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
return ""
}
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint // Already valid, no change needed
}
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
lastColon := strings.LastIndex(endpoint, ":")
if lastColon > 0 { // Ensure there is a colon and it's not the first character
hostPart := endpoint[:lastColon]
// Check if the host part is a literal IPv6 address
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
// It is! Reformat it with brackets.
portPart := endpoint[lastColon+1:]
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
}
}
// If it's not the specific malformed case, return it as is.
return endpoint
}
func sendPing(olm *websocket.Client) error {
err := olm.SendMessage("olm/ping", map[string]interface{}{
"timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
return err
}
logger.Debug("Sent ping message")
return nil
}
func keepSendingPing(olm *websocket.Client) {
// Send ping immediately on startup
if err := sendPing(olm); err != nil {
logger.Error("Failed to send initial ping: %v", err)
} else {
logger.Info("Sent initial ping message")
}
// Set up ticker for one minute intervals
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-stopPing:
logger.Info("Stopping ping messages")
return
case <-ticker.C:
if err := sendPing(olm); err != nil {
logger.Error("Failed to send periodic ping: %v", err)
}
}
}
}
func GetNetworkSettingsJSON() (string, error) {
return network.GetJSON()
}
func GetNetworkSettingsIncrementor() int {
return network.GetIncrementor()
}
// stringSlicesEqual compares two string slices for equality
func stringSlicesEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

View File

@@ -1,331 +0,0 @@
package peermonitor
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/websocket"
"github.com/fosrl/olm/wgtester"
"golang.zx2c4.com/wireguard/device"
)
// PeerMonitorCallback is the function type for connection status change callbacks
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
// WireGuardConfig holds the WireGuard configuration for a peer
type WireGuardConfig struct {
SiteID int
PublicKey string
ServerIP string
Endpoint string
PrimaryRelay string // The primary relay endpoint
}
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
type PeerMonitor struct {
monitors map[int]*wgtester.Client
configs map[int]*WireGuardConfig
callback PeerMonitorCallback
mutex sync.Mutex
running bool
interval time.Duration
timeout time.Duration
maxAttempts int
privateKey string
wsClient *websocket.Client
device *device.Device
handleRelaySwitch bool // Whether to handle relay switching
}
// NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
return &PeerMonitor{
monitors: make(map[int]*wgtester.Client),
configs: make(map[int]*WireGuardConfig),
callback: callback,
interval: 1 * time.Second, // Default check interval
timeout: 2500 * time.Millisecond,
maxAttempts: 8,
privateKey: privateKey,
wsClient: wsClient,
device: device,
handleRelaySwitch: handleRelaySwitch,
}
}
// SetInterval changes how frequently peers are checked
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.interval = interval
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetPacketInterval(interval)
}
}
// SetTimeout changes the timeout for waiting for responses
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.timeout = timeout
// Update timeout for all existing monitors
for _, client := range pm.monitors {
client.SetTimeout(timeout)
}
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.maxAttempts = attempts
// Update max attempts for all existing monitors
for _, client := range pm.monitors {
client.SetMaxAttempts(attempts)
}
}
// AddPeer adds a new peer to monitor
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
// Check if we're already monitoring this peer
if _, exists := pm.monitors[siteID]; exists {
// Update the endpoint instead of creating a new monitor
pm.removePeerUnlocked(siteID)
}
client, err := wgtester.NewClient(endpoint)
if err != nil {
return err
}
// Configure the client with our settings
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
// Store the client and config
pm.monitors[siteID] = client
pm.configs[siteID] = wgConfig
// If monitor is already running, start monitoring this peer
if pm.running {
siteIDCopy := siteID // Create a copy for the closure
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
pm.handleConnectionStatusChange(siteIDCopy, status)
})
}
return err
}
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
// This function assumes the mutex is already held by the caller
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
client, exists := pm.monitors[siteID]
if !exists {
return
}
client.StopMonitor()
client.Close()
delete(pm.monitors, siteID)
delete(pm.configs, siteID)
}
// RemovePeer stops monitoring a peer and removes it from the monitor
func (pm *PeerMonitor) RemovePeer(siteID int) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.removePeerUnlocked(siteID)
}
// Start begins monitoring all peers
func (pm *PeerMonitor) Start() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
if pm.running {
return // Already running
}
pm.running = true
// Start monitoring all peers
for siteID, client := range pm.monitors {
siteIDCopy := siteID // Create a copy for the closure
err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
pm.handleConnectionStatusChange(siteIDCopy, status)
})
if err != nil {
logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
continue
}
logger.Info("Started monitoring peer %d\n", siteID)
}
}
// handleConnectionStatusChange is called when a peer's connection status changes
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) {
// Call the user-provided callback first
if pm.callback != nil {
pm.callback(siteID, status.Connected, status.RTT)
}
// If disconnected, handle failover
if !status.Connected {
// Send relay message to the server
if pm.wsClient != nil {
pm.sendRelay(siteID)
}
}
}
// handleFailover handles failover to the relay server when a peer is disconnected
func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
pm.mutex.Lock()
config, exists := pm.configs[siteID]
pm.mutex.Unlock()
if !exists {
return
}
// Check for IPv6 and format the endpoint correctly
formattedEndpoint := relayEndpoint
if strings.Contains(relayEndpoint, ":") {
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
}
// Configure WireGuard to use the relay
wgConfig := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s:21820
persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, formattedEndpoint)
err := pm.device.IpcSet(wgConfig)
if err != nil {
logger.Error("Failed to configure WireGuard device: %v\n", err)
return
}
logger.Info("Adjusted peer %d to point to relay!\n", siteID)
}
// sendRelay sends a relay message to the server
func (pm *PeerMonitor) sendRelay(siteID int) error {
if !pm.handleRelaySwitch {
return nil
}
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent relay message")
return nil
}
// Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
if !pm.running {
return
}
pm.running = false
// Stop all monitors
for _, client := range pm.monitors {
client.StopMonitor()
}
}
// Close stops monitoring and cleans up resources
func (pm *PeerMonitor) Close() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
// Stop and close all clients
for siteID, client := range pm.monitors {
client.StopMonitor()
client.Close()
delete(pm.monitors, siteID)
}
pm.running = false
}
// TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
pm.mutex.Lock()
client, exists := pm.monitors[siteID]
pm.mutex.Unlock()
if !exists {
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
}
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
defer cancel()
connected, rtt := client.TestConnection(ctx)
return connected, rtt, nil
}
// TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
Connected bool
RTT time.Duration
} {
pm.mutex.Lock()
peers := make(map[int]*wgtester.Client, len(pm.monitors))
for siteID, client := range pm.monitors {
peers[siteID] = client
}
pm.mutex.Unlock()
results := make(map[int]struct {
Connected bool
RTT time.Duration
})
for siteID, client := range peers {
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
connected, rtt := client.TestConnection(ctx)
cancel()
results[siteID] = struct {
Connected bool
RTT time.Duration
}{
Connected: connected,
RTT: rtt,
}
}
return results
}

884
peers/manager.go Normal file
View File

@@ -0,0 +1,884 @@
package peers
import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"github.com/fosrl/newt/bind"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/api"
olmDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
"github.com/fosrl/olm/peers/monitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// PeerManagerConfig contains the configuration for creating a PeerManager
type PeerManagerConfig struct {
Device *device.Device
DNSProxy *dns.DNSProxy
InterfaceName string
PrivateKey wgtypes.Key
// For peer monitoring
MiddleDev *olmDevice.MiddleDevice
LocalIP string
SharedBind *bind.SharedBind
// WSClient is optional - if nil, relay messages won't be sent
WSClient *websocket.Client
APIServer *api.API
}
type PeerManager struct {
mu sync.RWMutex
device *device.Device
peers map[int]SiteConfig
peerMonitor *monitor.PeerMonitor
dnsProxy *dns.DNSProxy
interfaceName string
privateKey wgtypes.Key
// allowedIPOwners tracks which peer currently "owns" each allowed IP in WireGuard
// key is the CIDR string, value is the siteId that has it configured in WG
allowedIPOwners map[string]int
// allowedIPClaims tracks all peers that claim each allowed IP
// key is the CIDR string, value is a set of siteIds that want this IP
allowedIPClaims map[string]map[int]bool
APIServer *api.API
}
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
func NewPeerManager(config PeerManagerConfig) *PeerManager {
pm := &PeerManager{
device: config.Device,
peers: make(map[int]SiteConfig),
dnsProxy: config.DNSProxy,
interfaceName: config.InterfaceName,
privateKey: config.PrivateKey,
allowedIPOwners: make(map[string]int),
allowedIPClaims: make(map[string]map[int]bool),
APIServer: config.APIServer,
}
// Create the peer monitor
pm.peerMonitor = monitor.NewPeerMonitor(
config.WSClient,
config.MiddleDev,
config.LocalIP,
config.SharedBind,
config.APIServer,
)
return pm
}
func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
pm.mu.RLock()
defer pm.mu.RUnlock()
peer, ok := pm.peers[siteId]
return peer, ok
}
func (pm *PeerManager) GetAllPeers() []SiteConfig {
pm.mu.RLock()
defer pm.mu.RUnlock()
peers := make([]SiteConfig, 0, len(pm.peers))
for _, peer := range pm.peers {
peers = append(peers, peer)
}
return peers
}
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
pm.mu.Lock()
defer pm.mu.Unlock()
// build the allowed IPs list from the remote subnets and aliases and add them to the peer
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
allowedIPs = append(allowedIPs, siteConfig.RemoteSubnets...)
for _, alias := range siteConfig.Aliases {
allowedIPs = append(allowedIPs, alias.AliasAddress+"/32")
}
siteConfig.AllowedIps = allowedIPs
// Register claims for all allowed IPs and determine which ones this peer will own
ownedIPs := make([]string, 0, len(allowedIPs))
for _, ip := range allowedIPs {
pm.claimAllowedIP(siteConfig.SiteId, ip)
// Check if this peer became the owner
if pm.allowedIPOwners[ip] == siteConfig.SiteId {
ownedIPs = append(ownedIPs, ip)
}
}
// Create a config with only the owned IPs for WireGuard
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
return err
}
if err := network.AddRouteForServerIP(siteConfig.ServerIP, pm.interfaceName); err != nil {
logger.Error("Failed to add route for server IP: %v", err)
}
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
}
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
}
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
err := pm.peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, siteConfig.Endpoint) // always use the real site endpoint for hole punch monitoring
if err != nil {
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
} else {
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
}
pm.peers[siteConfig.SiteId] = siteConfig
pm.APIServer.AddPeerStatus(siteConfig.SiteId, siteConfig.Name, false, 0, siteConfig.Endpoint, false)
// Perform rapid initial holepunch test (outside of lock to avoid blocking)
// This quickly determines if holepunch is viable and triggers relay if not
go pm.performRapidInitialTest(siteConfig.SiteId, siteConfig.Endpoint)
return nil
}
func (pm *PeerManager) RemovePeer(siteId int) error {
pm.mu.Lock()
defer pm.mu.Unlock()
peer, exists := pm.peers[siteId]
if !exists {
return fmt.Errorf("peer with site ID %d not found", siteId)
}
if err := RemovePeer(pm.device, siteId, peer.PublicKey); err != nil {
return err
}
if err := network.RemoveRouteForServerIP(peer.ServerIP, pm.interfaceName); err != nil {
logger.Error("Failed to remove route for server IP: %v", err)
}
// Only remove routes for subnets that aren't used by other peers
for _, subnet := range peer.RemoteSubnets {
subnetStillInUse := false
for otherSiteId, otherPeer := range pm.peers {
if otherSiteId == siteId {
continue // Skip the peer being removed
}
for _, otherSubnet := range otherPeer.RemoteSubnets {
if otherSubnet == subnet {
subnetStillInUse = true
break
}
}
if subnetStillInUse {
break
}
}
if !subnetStillInUse {
if err := network.RemoveRoutes([]string{subnet}); err != nil {
logger.Error("Failed to remove route for remote subnet %s: %v", subnet, err)
}
}
}
// For aliases
for _, alias := range peer.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.RemoveDNSRecord(alias.Alias, address)
}
// Release all IP claims and promote other peers as needed
// Collect promotions first to avoid modifying while iterating
type promotion struct {
newOwner int
cidr string
}
var promotions []promotion
for _, ip := range peer.AllowedIps {
newOwner, promoted := pm.releaseAllowedIP(siteId, ip)
if promoted && newOwner >= 0 {
promotions = append(promotions, promotion{newOwner: newOwner, cidr: ip})
}
}
// Apply promotions - update WireGuard config for newly promoted peers
// Group by peer to avoid multiple config updates
promotedPeers := make(map[int]bool)
for _, p := range promotions {
promotedPeers[p.newOwner] = true
logger.Info("Promoted peer %d to owner of IP %s", p.newOwner, p.cidr)
}
for promotedPeerId := range promotedPeers {
if promotedPeer, exists := pm.peers[promotedPeerId]; exists {
// Build the list of IPs this peer now owns
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
wgConfig := promotedPeer
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
}
// Stop monitoring this peer
pm.peerMonitor.RemovePeer(siteId)
logger.Info("Stopped monitoring for site %d", siteId)
pm.APIServer.RemovePeerStatus(siteId)
delete(pm.peers, siteId)
return nil
}
func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
pm.mu.Lock()
defer pm.mu.Unlock()
oldPeer, exists := pm.peers[siteConfig.SiteId]
if !exists {
return fmt.Errorf("peer with site ID %d not found", siteConfig.SiteId)
}
// If public key changed, remove old peer first
if siteConfig.PublicKey != oldPeer.PublicKey {
if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
}
}
// Build the new allowed IPs list
newAllowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
newAllowedIPs = append(newAllowedIPs, siteConfig.RemoteSubnets...)
for _, alias := range siteConfig.Aliases {
newAllowedIPs = append(newAllowedIPs, alias.AliasAddress+"/32")
}
siteConfig.AllowedIps = newAllowedIPs
// Handle allowed IP claim changes
oldAllowedIPs := make(map[string]bool)
for _, ip := range oldPeer.AllowedIps {
oldAllowedIPs[ip] = true
}
newAllowedIPsSet := make(map[string]bool)
for _, ip := range newAllowedIPs {
newAllowedIPsSet[ip] = true
}
// Track peers that need WireGuard config updates due to promotions
peersToUpdate := make(map[int]bool)
// Release claims for removed IPs and handle promotions
for ip := range oldAllowedIPs {
if !newAllowedIPsSet[ip] {
newOwner, promoted := pm.releaseAllowedIP(siteConfig.SiteId, ip)
if promoted && newOwner >= 0 {
peersToUpdate[newOwner] = true
logger.Info("Promoted peer %d to owner of IP %s", newOwner, ip)
}
}
}
// Add claims for new IPs
for ip := range newAllowedIPsSet {
if !oldAllowedIPs[ip] {
pm.claimAllowedIP(siteConfig.SiteId, ip)
}
}
// Build the list of IPs this peer owns for WireGuard config
ownedIPs := pm.getOwnedAllowedIPs(siteConfig.SiteId)
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
return err
}
// Update WireGuard config for any promoted peers
for promotedPeerId := range peersToUpdate {
if promotedPeer, exists := pm.peers[promotedPeerId]; exists {
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
promotedWgConfig := promotedPeer
promotedWgConfig.AllowedIps = promotedOwnedIPs
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
}
// Handle remote subnet route changes
// Calculate added and removed subnets
oldSubnets := make(map[string]bool)
for _, s := range oldPeer.RemoteSubnets {
oldSubnets[s] = true
}
newSubnets := make(map[string]bool)
for _, s := range siteConfig.RemoteSubnets {
newSubnets[s] = true
}
var addedSubnets []string
var removedSubnets []string
for s := range newSubnets {
if !oldSubnets[s] {
addedSubnets = append(addedSubnets, s)
}
}
for s := range oldSubnets {
if !newSubnets[s] {
removedSubnets = append(removedSubnets, s)
}
}
// Remove routes for removed subnets (only if no other peer needs them)
for _, subnet := range removedSubnets {
subnetStillInUse := false
for otherSiteId, otherPeer := range pm.peers {
if otherSiteId == siteConfig.SiteId {
continue // Skip the current peer (already updated)
}
for _, otherSubnet := range otherPeer.RemoteSubnets {
if otherSubnet == subnet {
subnetStillInUse = true
break
}
}
if subnetStillInUse {
break
}
}
if !subnetStillInUse {
if err := network.RemoveRoutes([]string{subnet}); err != nil {
logger.Error("Failed to remove route for subnet %s: %v", subnet, err)
}
}
}
// Add routes for added subnets
if len(addedSubnets) > 0 {
if err := network.AddRoutes(addedSubnets, pm.interfaceName); err != nil {
logger.Error("Failed to add routes: %v", err)
}
}
// Update aliases
// Remove old aliases
for _, alias := range oldPeer.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.RemoveDNSRecord(alias.Alias, address)
}
// Add new aliases
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
}
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
pm.peerMonitor.UpdatePeerEndpoint(siteConfig.SiteId, monitorPeer) // +1 for monitor port
pm.peers[siteConfig.SiteId] = siteConfig
return nil
}
// claimAllowedIP registers a peer's claim to an allowed IP.
// If no other peer owns it in WireGuard, this peer becomes the owner.
// Must be called with lock held.
func (pm *PeerManager) claimAllowedIP(siteId int, cidr string) {
// Add to claims
if pm.allowedIPClaims[cidr] == nil {
pm.allowedIPClaims[cidr] = make(map[int]bool)
}
pm.allowedIPClaims[cidr][siteId] = true
// If no owner yet, this peer becomes the owner
if _, hasOwner := pm.allowedIPOwners[cidr]; !hasOwner {
pm.allowedIPOwners[cidr] = siteId
}
}
// releaseAllowedIP removes a peer's claim to an allowed IP.
// If this peer was the owner, it promotes another claimant to owner.
// Returns the new owner's siteId (or -1 if no new owner) and whether promotion occurred.
// Must be called with lock held.
func (pm *PeerManager) releaseAllowedIP(siteId int, cidr string) (newOwner int, promoted bool) {
// Remove from claims
if claims, exists := pm.allowedIPClaims[cidr]; exists {
delete(claims, siteId)
if len(claims) == 0 {
delete(pm.allowedIPClaims, cidr)
}
}
// Check if this peer was the owner
owner, isOwned := pm.allowedIPOwners[cidr]
if !isOwned || owner != siteId {
return -1, false // Not the owner, nothing to promote
}
// This peer was the owner, need to find a new owner
delete(pm.allowedIPOwners, cidr)
// Find another claimant to promote
if claims, exists := pm.allowedIPClaims[cidr]; exists && len(claims) > 0 {
for claimantId := range claims {
pm.allowedIPOwners[cidr] = claimantId
return claimantId, true
}
}
return -1, false
}
// getOwnedAllowedIPs returns the list of allowed IPs that a peer currently owns in WireGuard.
// Must be called with lock held.
func (pm *PeerManager) getOwnedAllowedIPs(siteId int) []string {
var owned []string
for cidr, owner := range pm.allowedIPOwners {
if owner == siteId {
owned = append(owned, cidr)
}
}
return owned
}
// addAllowedIp adds an IP (subnet) to the allowed IPs list of a peer
// and updates WireGuard configuration if this peer owns the IP.
// Must be called with lock held.
func (pm *PeerManager) addAllowedIp(siteId int, ip string) error {
peer, exists := pm.peers[siteId]
if !exists {
return fmt.Errorf("peer with site ID %d not found", siteId)
}
// Check if IP already exists in AllowedIps
for _, allowedIp := range peer.AllowedIps {
if allowedIp == ip {
return nil // Already exists
}
}
// Register our claim to this IP
pm.claimAllowedIP(siteId, ip)
peer.AllowedIps = append(peer.AllowedIps, ip)
pm.peers[siteId] = peer
// Only update WireGuard if we own this IP
if pm.allowedIPOwners[ip] == siteId {
if err := AddAllowedIP(pm.device, peer.PublicKey, ip); err != nil {
return err
}
}
return nil
}
// removeAllowedIp removes an IP (subnet) from the allowed IPs list of a peer
// and updates WireGuard configuration. If this peer owned the IP, it promotes
// another peer that also claims this IP. Must be called with lock held.
func (pm *PeerManager) removeAllowedIp(siteId int, cidr string) error {
peer, exists := pm.peers[siteId]
if !exists {
return fmt.Errorf("peer with site ID %d not found", siteId)
}
found := false
// Remove from AllowedIps
newAllowedIps := make([]string, 0, len(peer.AllowedIps))
for _, allowedIp := range peer.AllowedIps {
if allowedIp == cidr {
found = true
continue
}
newAllowedIps = append(newAllowedIps, allowedIp)
}
if !found {
return nil // Not found
}
peer.AllowedIps = newAllowedIps
pm.peers[siteId] = peer
// Release our claim and check if we need to promote another peer
newOwner, promoted := pm.releaseAllowedIP(siteId, cidr)
// Build the list of IPs this peer currently owns for the replace operation
ownedIPs := pm.getOwnedAllowedIPs(siteId)
// Also include the server IP which is always owned
serverIP := strings.Split(peer.ServerIP, "/")[0] + "/32"
hasServerIP := false
for _, ip := range ownedIPs {
if ip == serverIP {
hasServerIP = true
break
}
}
if !hasServerIP {
ownedIPs = append([]string{serverIP}, ownedIPs...)
}
// Update WireGuard for this peer using replace_allowed_ips
if err := RemoveAllowedIP(pm.device, peer.PublicKey, ownedIPs); err != nil {
return err
}
// If another peer was promoted to owner, add the IP to their WireGuard config
if promoted && newOwner >= 0 {
if newOwnerPeer, exists := pm.peers[newOwner]; exists {
if err := AddAllowedIP(pm.device, newOwnerPeer.PublicKey, cidr); err != nil {
logger.Error("Failed to promote peer %d for IP %s: %v", newOwner, cidr, err)
} else {
logger.Info("Promoted peer %d to owner of IP %s", newOwner, cidr)
}
}
}
return nil
}
// AddRemoteSubnet adds an IP (subnet) to the allowed IPs list of a peer
func (pm *PeerManager) AddRemoteSubnet(siteId int, cidr string) error {
pm.mu.Lock()
defer pm.mu.Unlock()
peer, exists := pm.peers[siteId]
if !exists {
return fmt.Errorf("peer with site ID %d not found", siteId)
}
// Check if IP already exists in RemoteSubnets
for _, subnet := range peer.RemoteSubnets {
if subnet == cidr {
return nil // Already exists
}
}
peer.RemoteSubnets = append(peer.RemoteSubnets, cidr)
pm.peers[siteId] = peer // Save before calling addAllowedIp which reads from pm.peers
// Add to allowed IPs
if err := pm.addAllowedIp(siteId, cidr); err != nil {
return err
}
// Add route
if err := network.AddRoutes([]string{cidr}, pm.interfaceName); err != nil {
return err
}
return nil
}
// RemoveRemoteSubnet removes an IP (subnet) from the allowed IPs list of a peer
func (pm *PeerManager) RemoveRemoteSubnet(siteId int, ip string) error {
pm.mu.Lock()
defer pm.mu.Unlock()
peer, exists := pm.peers[siteId]
if !exists {
return fmt.Errorf("peer with site ID %d not found", siteId)
}
found := false
// Remove from RemoteSubnets
newSubnets := make([]string, 0, len(peer.RemoteSubnets))
for _, subnet := range peer.RemoteSubnets {
if subnet == ip {
found = true
continue
}
newSubnets = append(newSubnets, subnet)
}
if !found {
return nil // Not found
}
peer.RemoteSubnets = newSubnets
pm.peers[siteId] = peer // Save before calling removeAllowedIp which reads from pm.peers
// Remove from allowed IPs (this also handles promotion of other peers)
if err := pm.removeAllowedIp(siteId, ip); err != nil {
return err
}
// Check if any other peer still has this subnet before removing the route
subnetStillInUse := false
for otherSiteId, otherPeer := range pm.peers {
if otherSiteId == siteId {
continue // Skip the current peer (already updated above)
}
for _, subnet := range otherPeer.RemoteSubnets {
if subnet == ip {
subnetStillInUse = true
break
}
}
if subnetStillInUse {
break
}
}
// Only remove route if no other peer needs it
if !subnetStillInUse {
if err := network.RemoveRoutes([]string{ip}); err != nil {
return err
}
}
return nil
}
// AddAlias adds an alias to a peer
func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
pm.mu.Lock()
defer pm.mu.Unlock()
peer, exists := pm.peers[siteId]
if !exists {
return fmt.Errorf("peer with site ID %d not found", siteId)
}
peer.Aliases = append(peer.Aliases, alias)
pm.peers[siteId] = peer
address := net.ParseIP(alias.AliasAddress)
if address != nil {
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
}
// Add an allowed IP for the alias
if err := pm.addAllowedIp(siteId, alias.AliasAddress+"/32"); err != nil {
return err
}
return nil
}
// RemoveAlias removes an alias from a peer
func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error {
pm.mu.Lock()
defer pm.mu.Unlock()
peer, exists := pm.peers[siteId]
if !exists {
return fmt.Errorf("peer with site ID %d not found", siteId)
}
var aliasToRemove *Alias
newAliases := make([]Alias, 0, len(peer.Aliases))
for _, a := range peer.Aliases {
if a.Alias == aliasName {
aliasToRemove = &a
continue
}
newAliases = append(newAliases, a)
}
if aliasToRemove != nil {
address := net.ParseIP(aliasToRemove.AliasAddress)
if address != nil {
pm.dnsProxy.RemoveDNSRecord(aliasName, address)
}
}
peer.Aliases = newAliases
pm.peers[siteId] = peer
// Check if any other alias is still using this IP address before removing from allowed IPs
ipStillInUse := false
aliasIP := aliasToRemove.AliasAddress + "/32"
for _, a := range newAliases {
if a.AliasAddress+"/32" == aliasIP {
ipStillInUse = true
break
}
}
// Only remove the allowed IP if no other alias is using it
if !ipStillInUse {
if err := pm.removeAllowedIp(siteId, aliasIP); err != nil {
return err
}
}
return nil
}
// RelayPeer handles failover to the relay server when a peer is disconnected
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
pm.mu.Lock()
peer, exists := pm.peers[siteId]
if exists {
// Store the relay endpoint
peer.RelayEndpoint = relayEndpoint
pm.peers[siteId] = peer
}
pm.mu.Unlock()
if !exists {
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
return
}
// Check for IPv6 and format the endpoint correctly
formattedEndpoint := relayEndpoint
if strings.Contains(relayEndpoint, ":") {
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
}
// Update only the endpoint for this peer (update_only preserves other settings)
wgConfig := fmt.Sprintf(`public_key=%s
update_only=true
endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint)
err := pm.device.IpcSet(wgConfig)
if err != nil {
logger.Error("Failed to configure WireGuard device: %v\n", err)
return
}
// Mark the peer as relayed in the monitor
if pm.peerMonitor != nil {
pm.peerMonitor.MarkPeerRelayed(siteId, true)
}
logger.Info("Adjusted peer %d to point to relay!\n", siteId)
}
// performRapidInitialTest performs a rapid holepunch test for a newly added peer.
// If the test fails, it immediately requests relay to minimize connection delay.
// This runs in a goroutine to avoid blocking AddPeer.
func (pm *PeerManager) performRapidInitialTest(siteId int, endpoint string) {
if pm.peerMonitor == nil {
return
}
// Perform rapid test - this takes ~1-2 seconds max
holepunchViable := pm.peerMonitor.RapidTestPeer(siteId, endpoint)
if !holepunchViable {
// Holepunch failed rapid test, request relay immediately
logger.Info("Rapid test failed for site %d, requesting relay", siteId)
if err := pm.peerMonitor.RequestRelay(siteId); err != nil {
logger.Error("Failed to request relay for site %d: %v", siteId, err)
}
} else {
logger.Info("Rapid test passed for site %d, using direct connection", siteId)
}
}
// Start starts the peer monitor
func (pm *PeerManager) Start() {
if pm.peerMonitor != nil {
pm.peerMonitor.Start()
}
}
// Stop stops the peer monitor
func (pm *PeerManager) Stop() {
if pm.peerMonitor != nil {
pm.peerMonitor.Stop()
}
}
// Close stops the peer monitor and cleans up resources
func (pm *PeerManager) Close() {
if pm.peerMonitor != nil {
pm.peerMonitor.Close()
pm.peerMonitor = nil
}
}
// MarkPeerRelayed marks a peer as currently using relay
func (pm *PeerManager) MarkPeerRelayed(siteID int, relayed bool) {
pm.mu.Lock()
if peer, exists := pm.peers[siteID]; exists {
if relayed {
// We're being relayed, store the current endpoint as the original
// (RelayEndpoint is set by HandleFailover)
} else {
// Clear relay endpoint when switching back to direct
peer.RelayEndpoint = ""
pm.peers[siteID] = peer
}
}
pm.mu.Unlock()
if pm.peerMonitor != nil {
pm.peerMonitor.MarkPeerRelayed(siteID, relayed)
}
}
// UnRelayPeer switches a peer from relay back to direct connection
func (pm *PeerManager) UnRelayPeer(siteId int, endpoint string) error {
pm.mu.Lock()
peer, exists := pm.peers[siteId]
if exists {
// Store the relay endpoint
peer.Endpoint = endpoint
pm.peers[siteId] = peer
}
pm.mu.Unlock()
if !exists {
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
return nil
}
// Update WireGuard to use the direct endpoint
wgConfig := fmt.Sprintf(`public_key=%s
update_only=true
endpoint=%s`, util.FixKey(peer.PublicKey), endpoint)
err := pm.device.IpcSet(wgConfig)
if err != nil {
logger.Error("Failed to switch peer %d to direct connection: %v", siteId, err)
return err
}
// Mark as not relayed in monitor
if pm.peerMonitor != nil {
pm.peerMonitor.MarkPeerRelayed(siteId, false)
}
logger.Info("Switched peer %d back to direct connection at %s", siteId, endpoint)
return nil
}

924
peers/monitor/monitor.go Normal file
View File

@@ -0,0 +1,924 @@
package monitor
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/fosrl/newt/bind"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/api"
middleDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/websocket"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
type PeerMonitor struct {
monitors map[int]*Client
mutex sync.Mutex
running bool
interval time.Duration
timeout time.Duration
maxAttempts int
wsClient *websocket.Client
// Netstack fields
middleDev *middleDevice.MiddleDevice
localIP string
stack *stack.Stack
ep *channel.Endpoint
activePorts map[uint16]bool
portsLock sync.Mutex
nsCtx context.Context
nsCancel context.CancelFunc
nsWg sync.WaitGroup
// Holepunch testing fields
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchInterval time.Duration
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
// Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
holepunchMaxAttempts int // max consecutive failures before triggering relay
holepunchFailures map[int]int // siteID -> consecutive failure count
// Rapid initial test fields
rapidTestInterval time.Duration // interval between rapid test attempts
rapidTestTimeout time.Duration // timeout for each rapid test attempt
rapidTestMaxAttempts int // max attempts during rapid test phase
// API server for status updates
apiServer *api.API
// WG connection status tracking
wgConnectionStatus map[int]bool // siteID -> WG connected status
}
// NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*Client),
interval: 2 * time.Second, // Default check interval (faster)
timeout: 3 * time.Second,
maxAttempts: 3,
wsClient: wsClient,
middleDev: middleDev,
localIP: localIP,
activePorts: make(map[uint16]bool),
nsCtx: ctx,
nsCancel: cancel,
sharedBind: sharedBind,
holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
holepunchTimeout: 2 * time.Second, // Faster timeout
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds
rapidTestInterval: 200 * time.Millisecond, // 200ms between attempts
rapidTestTimeout: 400 * time.Millisecond, // 400ms timeout per attempt
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
apiServer: apiServer,
wgConnectionStatus: make(map[int]bool),
}
if err := pm.initNetstack(); err != nil {
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
}
// Initialize holepunch tester if sharedBind is available
if sharedBind != nil {
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind)
}
return pm
}
// SetInterval changes how frequently peers are checked
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.interval = interval
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetPacketInterval(interval)
}
}
// SetTimeout changes the timeout for waiting for responses
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.timeout = timeout
// Update timeout for all existing monitors
for _, client := range pm.monitors {
client.SetTimeout(timeout)
}
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.maxAttempts = attempts
// Update max attempts for all existing monitors
for _, client := range pm.monitors {
client.SetMaxAttempts(attempts)
}
}
// AddPeer adds a new peer to monitor
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
if _, exists := pm.monitors[siteID]; exists {
return nil // Already monitoring
}
// Use our custom dialer that uses netstack
client, err := NewClient(endpoint, pm.dial)
if err != nil {
return err
}
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
pm.monitors[siteID] = client
pm.holepunchEndpoints[siteID] = holepunchEndpoint
pm.holepunchStatus[siteID] = false // Initially unknown/disconnected
if pm.running {
if err := client.StartMonitor(func(status ConnectionStatus) {
pm.handleConnectionStatusChange(siteID, status)
}); err != nil {
return err
}
}
return nil
}
// update holepunch endpoint for a peer
func (pm *PeerMonitor) UpdateHolepunchEndpoint(siteID int, endpoint string) {
go func() {
time.Sleep(3 * time.Second)
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.holepunchEndpoints[siteID] = endpoint
}()
}
// RapidTestPeer performs a rapid connectivity test for a newly added peer.
// This is designed to quickly determine if holepunch is viable within ~1-2 seconds.
// Returns true if the connection is viable (holepunch works), false if it should relay.
func (pm *PeerMonitor) RapidTestPeer(siteID int, endpoint string) bool {
if pm.holepunchTester == nil {
logger.Warn("Cannot perform rapid test: holepunch tester not initialized")
return false
}
pm.mutex.Lock()
interval := pm.rapidTestInterval
timeout := pm.rapidTestTimeout
maxAttempts := pm.rapidTestMaxAttempts
pm.mutex.Unlock()
logger.Info("Starting rapid holepunch test for site %d at %s (max %d attempts, %v timeout each)",
siteID, endpoint, maxAttempts, timeout)
for attempt := 1; attempt <= maxAttempts; attempt++ {
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
if result.Success {
logger.Info("Rapid test: site %d holepunch SUCCEEDED on attempt %d (RTT: %v)",
siteID, attempt, result.RTT)
// Update status
pm.mutex.Lock()
pm.holepunchStatus[siteID] = true
pm.holepunchFailures[siteID] = 0
pm.mutex.Unlock()
return true
}
if attempt < maxAttempts {
time.Sleep(interval)
}
}
logger.Warn("Rapid test: site %d holepunch FAILED after %d attempts, will relay",
siteID, maxAttempts)
// Update status to reflect failure
pm.mutex.Lock()
pm.holepunchStatus[siteID] = false
pm.holepunchFailures[siteID] = maxAttempts
pm.mutex.Unlock()
return false
}
// UpdatePeerEndpoint updates the monitor endpoint for a peer
func (pm *PeerMonitor) UpdatePeerEndpoint(siteID int, monitorPeer string) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
client, exists := pm.monitors[siteID]
if !exists {
logger.Warn("Cannot update endpoint: peer %d not found in monitor", siteID)
return
}
// Update the client's server address
client.UpdateServerAddr(monitorPeer)
logger.Info("Updated monitor endpoint for site %d to %s", siteID, monitorPeer)
}
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
// This function assumes the mutex is already held by the caller
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
client, exists := pm.monitors[siteID]
if !exists {
return
}
client.StopMonitor()
client.Close()
delete(pm.monitors, siteID)
}
// RemovePeer stops monitoring a peer and removes it from the monitor
func (pm *PeerMonitor) RemovePeer(siteID int) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
// remove the holepunch endpoint info
delete(pm.holepunchEndpoints, siteID)
delete(pm.holepunchStatus, siteID)
delete(pm.relayedPeers, siteID)
delete(pm.holepunchFailures, siteID)
pm.removePeerUnlocked(siteID)
}
// Start begins monitoring all peers
func (pm *PeerMonitor) Start() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
if pm.running {
return // Already running
}
pm.running = true
// Start monitoring all peers
for siteID, client := range pm.monitors {
siteIDCopy := siteID // Create a copy for the closure
err := client.StartMonitor(func(status ConnectionStatus) {
pm.handleConnectionStatusChange(siteIDCopy, status)
})
if err != nil {
logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
continue
}
logger.Info("Started monitoring peer %d\n", siteID)
}
pm.startHolepunchMonitor()
}
// handleConnectionStatusChange is called when a peer's connection status changes
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) {
pm.mutex.Lock()
previousStatus, exists := pm.wgConnectionStatus[siteID]
pm.wgConnectionStatus[siteID] = status.Connected
isRelayed := pm.relayedPeers[siteID]
endpoint := pm.holepunchEndpoints[siteID]
pm.mutex.Unlock()
// Log status changes
if !exists || previousStatus != status.Connected {
if status.Connected {
logger.Info("WireGuard connection to site %d is CONNECTED (RTT: %v)", siteID, status.RTT)
} else {
logger.Warn("WireGuard connection to site %d is DISCONNECTED", siteID)
}
}
// Update API with connection status
if pm.apiServer != nil {
pm.apiServer.UpdatePeerStatus(siteID, status.Connected, status.RTT, endpoint, isRelayed)
}
}
// sendRelay sends a relay message to the server
func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent relay message")
return nil
}
// RequestRelay is a public method to request relay for a peer.
// This is used when rapid initial testing determines holepunch is not viable.
func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID)
}
// sendUnRelay sends an unrelay message to the server
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent unrelay message")
return nil
}
// Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor()
pm.mutex.Lock()
defer pm.mutex.Unlock()
if !pm.running {
return
}
pm.running = false
// Stop all monitors
for _, client := range pm.monitors {
client.StopMonitor()
}
}
// MarkPeerRelayed marks a peer as currently using relay
func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.relayedPeers[siteID] = relayed
if relayed {
// Reset failure count when marked as relayed
pm.holepunchFailures[siteID] = 0
}
}
// IsPeerRelayed returns whether a peer is currently using relay
func (pm *PeerMonitor) IsPeerRelayed(siteID int) bool {
pm.mutex.Lock()
defer pm.mutex.Unlock()
return pm.relayedPeers[siteID]
}
// startHolepunchMonitor starts the holepunch connection monitoring
// Note: This function assumes the mutex is already held by the caller (called from Start())
func (pm *PeerMonitor) startHolepunchMonitor() error {
if pm.holepunchTester == nil {
return fmt.Errorf("holepunch tester not initialized (sharedBind not provided)")
}
if pm.holepunchStopChan != nil {
return fmt.Errorf("holepunch monitor already running")
}
if err := pm.holepunchTester.Start(); err != nil {
return fmt.Errorf("failed to start holepunch tester: %w", err)
}
pm.holepunchStopChan = make(chan struct{})
go pm.runHolepunchMonitor()
logger.Info("Started holepunch connection monitor")
return nil
}
// stopHolepunchMonitor stops the holepunch connection monitoring
func (pm *PeerMonitor) stopHolepunchMonitor() {
pm.mutex.Lock()
stopChan := pm.holepunchStopChan
pm.holepunchStopChan = nil
pm.mutex.Unlock()
if stopChan != nil {
close(stopChan)
}
if pm.holepunchTester != nil {
pm.holepunchTester.Stop()
}
logger.Info("Stopped holepunch connection monitor")
}
// runHolepunchMonitor runs the holepunch monitoring loop
func (pm *PeerMonitor) runHolepunchMonitor() {
ticker := time.NewTicker(pm.holepunchInterval)
defer ticker.Stop()
// Do initial check immediately
pm.checkHolepunchEndpoints()
for {
select {
case <-pm.holepunchStopChan:
return
case <-ticker.C:
pm.checkHolepunchEndpoints()
}
}
}
// checkHolepunchEndpoints tests all holepunch endpoints
func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Lock()
// Check if we're still running before doing any work
if !pm.running {
pm.mutex.Unlock()
return
}
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
for siteID, endpoint := range pm.holepunchEndpoints {
endpoints[siteID] = endpoint
}
timeout := pm.holepunchTimeout
maxAttempts := pm.holepunchMaxAttempts
pm.mutex.Unlock()
for siteID, endpoint := range endpoints {
logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
pm.mutex.Lock()
// Check if peer was removed while we were testing
if _, stillExists := pm.holepunchEndpoints[siteID]; !stillExists {
pm.mutex.Unlock()
continue // Peer was removed, skip processing
}
previousStatus, exists := pm.holepunchStatus[siteID]
pm.holepunchStatus[siteID] = result.Success
isRelayed := pm.relayedPeers[siteID]
// Track consecutive failures for relay triggering
if result.Success {
pm.holepunchFailures[siteID] = 0
} else {
pm.holepunchFailures[siteID]++
}
failureCount := pm.holepunchFailures[siteID]
pm.mutex.Unlock()
// Log status changes
if !exists || previousStatus != result.Success {
if result.Success {
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
} else {
if result.Error != nil {
logger.Warn("Holepunch to site %d (%s) is DISCONNECTED: %v", siteID, endpoint, result.Error)
} else {
logger.Warn("Holepunch to site %d (%s) is DISCONNECTED", siteID, endpoint)
}
}
}
// Update API with holepunch status
if pm.apiServer != nil {
// Update holepunch connection status
pm.apiServer.UpdatePeerHolepunchStatus(siteID, result.Success)
// Get the current WG connection status for this peer
pm.mutex.Lock()
wgConnected := pm.wgConnectionStatus[siteID]
pm.mutex.Unlock()
// Update API - use holepunch endpoint and relay status
pm.apiServer.UpdatePeerStatus(siteID, wgConnected, result.RTT, endpoint, isRelayed)
}
// Handle relay logic based on holepunch status
// Check if we're still running before sending relay messages
pm.mutex.Lock()
stillRunning := pm.running
pm.mutex.Unlock()
if !stillRunning {
return // Stop processing if shutdown is in progress
}
if !result.Success && !isRelayed && failureCount >= maxAttempts {
// Holepunch failed and we're not relayed - trigger relay
logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount)
if pm.wsClient != nil {
pm.sendRelay(siteID)
}
} else if result.Success && isRelayed {
// Holepunch succeeded and we ARE relayed - switch back to direct
logger.Info("Holepunch to site %d succeeded while relayed, switching to direct connection", siteID)
if pm.wsClient != nil {
pm.sendUnRelay(siteID)
}
}
}
}
// GetHolepunchStatus returns the current holepunch status for all endpoints
func (pm *PeerMonitor) GetHolepunchStatus() map[int]bool {
pm.mutex.Lock()
defer pm.mutex.Unlock()
status := make(map[int]bool, len(pm.holepunchStatus))
for siteID, connected := range pm.holepunchStatus {
status[siteID] = connected
}
return status
}
// Close stops monitoring and cleans up resources
func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor()
pm.mutex.Lock()
defer pm.mutex.Unlock()
logger.Debug("PeerMonitor: Starting cleanup")
// Stop and close all clients first
for siteID, client := range pm.monitors {
logger.Debug("PeerMonitor: Stopping client for site %d", siteID)
client.StopMonitor()
client.Close()
delete(pm.monitors, siteID)
}
pm.running = false
// Clean up netstack resources
logger.Debug("PeerMonitor: Cancelling netstack context")
if pm.nsCancel != nil {
pm.nsCancel() // Signal goroutines to stop
}
// Close the channel endpoint to unblock any pending reads
logger.Debug("PeerMonitor: Closing endpoint")
if pm.ep != nil {
pm.ep.Close()
}
// Wait for packet sender goroutine to finish with timeout
logger.Debug("PeerMonitor: Waiting for goroutines to finish")
done := make(chan struct{})
go func() {
pm.nsWg.Wait()
close(done)
}()
select {
case <-done:
logger.Debug("PeerMonitor: Goroutines finished cleanly")
case <-time.After(2 * time.Second):
logger.Warn("PeerMonitor: Timeout waiting for goroutines to finish, proceeding anyway")
}
// Destroy the stack last, after all goroutines are done
logger.Debug("PeerMonitor: Destroying stack")
if pm.stack != nil {
pm.stack.Destroy()
pm.stack = nil
}
logger.Debug("PeerMonitor: Cleanup complete")
}
// TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
pm.mutex.Lock()
client, exists := pm.monitors[siteID]
pm.mutex.Unlock()
if !exists {
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
}
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
defer cancel()
connected, rtt := client.TestConnection(ctx)
return connected, rtt, nil
}
// TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
Connected bool
RTT time.Duration
} {
pm.mutex.Lock()
peers := make(map[int]*Client, len(pm.monitors))
for siteID, client := range pm.monitors {
peers[siteID] = client
}
pm.mutex.Unlock()
results := make(map[int]struct {
Connected bool
RTT time.Duration
})
for siteID, client := range peers {
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
connected, rtt := client.TestConnection(ctx)
cancel()
results[siteID] = struct {
Connected bool
RTT time.Duration
}{
Connected: connected,
RTT: rtt,
}
}
return results
}
// initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error {
if pm.localIP == "" {
return fmt.Errorf("local IP not provided")
}
addr, err := netip.ParseAddr(pm.localIP)
if err != nil {
return fmt.Errorf("invalid local IP: %v", err)
}
// Create gvisor netstack
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
HandleLocal: true,
}
pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG)
pm.stack = stack.New(stackOpts)
// Create NIC
if err := pm.stack.CreateNIC(1, pm.ep); err != nil {
return fmt.Errorf("failed to create NIC: %v", err)
}
// Add IP address
ipBytes := addr.As4()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
}
if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
return fmt.Errorf("failed to add protocol address: %v", err)
}
// Add default route
pm.stack.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
})
// Register filter rule on MiddleDevice
// We want to intercept packets destined to our local IP
// But ONLY if they are for ports we are listening on
pm.middleDev.AddRule(addr, pm.handlePacket)
// Start packet sender (Stack -> WG)
pm.nsWg.Add(1)
go pm.runPacketSender()
return nil
}
// handlePacket is called by MiddleDevice when a packet arrives for our IP
func (pm *PeerMonitor) handlePacket(packet []byte) bool {
// Check if it's UDP
proto, ok := util.GetProtocol(packet)
if !ok || proto != 17 { // UDP
return false
}
// Check destination port
port, ok := util.GetDestPort(packet)
if !ok {
return false
}
// Check if we are listening on this port
pm.portsLock.Lock()
active := pm.activePorts[uint16(port)]
pm.portsLock.Unlock()
if !active {
return false
}
// Inject into netstack
version := packet[0] >> 4
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
switch version {
case 4:
pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
case 6:
pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
default:
pkb.DecRef()
return false
}
pkb.DecRef()
return true // Handled
}
// runPacketSender reads packets from netstack and injects them into WireGuard
func (pm *PeerMonitor) runPacketSender() {
defer pm.nsWg.Done()
logger.Debug("PeerMonitor: Packet sender goroutine started")
// Use a ticker to periodically check for packets without blocking indefinitely
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-pm.nsCtx.Done():
logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
// Drain any remaining packets before exiting
for {
pkt := pm.ep.Read()
if pkt == nil {
break
}
pkt.DecRef()
}
logger.Debug("PeerMonitor: Packet sender goroutine exiting")
return
case <-ticker.C:
// Try to read packets in batches
for i := 0; i < 10; i++ {
pkt := pm.ep.Read()
if pkt == nil {
break
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
}
}
}
// dial creates a UDP connection using the netstack
func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) {
if pm.stack == nil {
return nil, fmt.Errorf("netstack not initialized")
}
// Parse remote address
raddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
// Parse local IP
localIP, err := netip.ParseAddr(pm.localIP)
if err != nil {
return nil, err
}
ipBytes := localIP.As4()
// Create UDP connection
// We bind to port 0 (ephemeral)
laddr := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4(ipBytes),
Port: 0,
}
raddrTcpip := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
Port: uint16(raddr.Port),
}
conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber)
if err != nil {
return nil, err
}
// Get local port
localAddr := conn.LocalAddr().(*net.UDPAddr)
port := uint16(localAddr.Port)
// Register port
pm.portsLock.Lock()
pm.activePorts[port] = true
pm.portsLock.Unlock()
// Wrap connection to cleanup port on close
return &trackedConn{
Conn: conn,
pm: pm,
port: port,
}, nil
}
func (pm *PeerMonitor) removePort(port uint16) {
pm.portsLock.Lock()
delete(pm.activePorts, port)
pm.portsLock.Unlock()
}
type trackedConn struct {
net.Conn
pm *PeerMonitor
port uint16
}
func (c *trackedConn) Close() error {
c.pm.removePort(c.port)
if c.Conn != nil {
return c.Conn.Close()
}
return nil
}

View File

@@ -1,4 +1,4 @@
package wgtester
package monitor
import (
"context"
@@ -26,7 +26,7 @@ const (
// Client handles checking connectivity to a server
type Client struct {
conn *net.UDPConn
conn net.Conn
serverAddr string
monitorRunning bool
monitorLock sync.Mutex
@@ -35,8 +35,12 @@ type Client struct {
packetInterval time.Duration
timeout time.Duration
maxAttempts int
dialer Dialer
}
// Dialer is a function that creates a connection
type Dialer func(network, addr string) (net.Conn, error)
// ConnectionStatus represents the current connection state
type ConnectionStatus struct {
Connected bool
@@ -44,13 +48,14 @@ type ConnectionStatus struct {
}
// NewClient creates a new connection test client
func NewClient(serverAddr string) (*Client, error) {
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
return &Client{
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
packetInterval: 2 * time.Second,
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
}, nil
}
@@ -69,6 +74,20 @@ func (c *Client) SetMaxAttempts(attempts int) {
c.maxAttempts = attempts
}
// UpdateServerAddr updates the server address and resets the connection
func (c *Client) UpdateServerAddr(serverAddr string) {
c.connLock.Lock()
defer c.connLock.Unlock()
// Close existing connection if any
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
c.serverAddr = serverAddr
}
// Close cleans up client resources
func (c *Client) Close() {
c.StopMonitor()
@@ -91,12 +110,14 @@ func (c *Client) ensureConnection() error {
return nil
}
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
if err != nil {
return err
var err error
if c.dialer != nil {
c.conn, err = c.dialer("udp", c.serverAddr)
} else {
// Fallback to standard net.Dial
c.conn, err = net.Dial("udp", c.serverAddr)
}
c.conn, err = net.DialUDP("udp", nil, serverAddr)
if err != nil {
return err
}
@@ -136,14 +157,14 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
return false, 0
}
logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
_, err := c.conn.Write(packet)
if err != nil {
c.connLock.Unlock()
logger.Info("Error sending packet: %v", err)
continue
}
logger.Debug("Successfully sent monitor packet")
// logger.Debug("Successfully sent monitor packet")
// Set read deadline
c.conn.SetReadDeadline(time.Now().Add(c.timeout))

142
peers/peer.go Normal file
View File

@@ -0,0 +1,142 @@
package peers
import (
"fmt"
"strings"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
var endpoint string
if relay && siteConfig.RelayEndpoint != "" {
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
} else {
endpoint = formatEndpoint(siteConfig.Endpoint)
}
siteHost, err := util.ResolveDomain(endpoint)
if err != nil {
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
}
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
allowedIp := strings.Split(siteConfig.ServerIP, "/")
if len(allowedIp) > 1 {
allowedIp[1] = "32"
} else {
allowedIp = append(allowedIp, "32")
}
allowedIpStr := strings.Join(allowedIp, "/")
// Collect all allowed IPs in a slice
var allowedIPs []string
allowedIPs = append(allowedIPs, allowedIpStr)
// Use AllowedIps if available, otherwise fall back to RemoteSubnets for backwards compatibility
subnetsToAdd := siteConfig.AllowedIps
// If we have anything to add, process them
if len(subnetsToAdd) > 0 {
// Add each subnet
for _, subnet := range subnetsToAdd {
subnet = strings.TrimSpace(subnet)
if subnet != "" {
allowedIPs = append(allowedIPs, subnet)
}
}
}
// Construct WireGuard config for this peer
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", util.FixKey(privateKey.String())))
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(siteConfig.PublicKey)))
// Add each allowed IP separately
for _, allowedIP := range allowedIPs {
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
}
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
configBuilder.WriteString("persistent_keepalive_interval=5\n")
config := configBuilder.String()
logger.Debug("Configuring peer with config: %s", config)
err = dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
}
return nil
}
// RemovePeer removes a peer from the WireGuard device
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
// Construct WireGuard config to remove the peer
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
configBuilder.WriteString("remove=true\n")
config := configBuilder.String()
logger.Debug("Removing peer with config: %s", config)
err := dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
}
return nil
}
// AddAllowedIP adds a single allowed IP to an existing peer without reconfiguring the entire peer
func AddAllowedIP(dev *device.Device, publicKey string, allowedIP string) error {
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
configBuilder.WriteString("update_only=true\n")
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
config := configBuilder.String()
logger.Debug("Adding allowed IP to peer with config: %s", config)
err := dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to add allowed IP to WireGuard peer: %v", err)
}
return nil
}
// RemoveAllowedIP removes a single allowed IP from an existing peer by replacing the allowed IPs list
// This requires providing all the allowed IPs that should remain after removal
func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs []string) error {
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
configBuilder.WriteString("update_only=true\n")
configBuilder.WriteString("replace_allowed_ips=true\n")
// Add each remaining allowed IP
for _, allowedIP := range remainingAllowedIPs {
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
}
config := configBuilder.String()
logger.Debug("Removing allowed IP from peer with config: %s", config)
err := dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to remove allowed IP from WireGuard peer: %v", err)
}
return nil
}
func formatEndpoint(endpoint string) string {
if strings.Contains(endpoint, ":") {
return endpoint
}
return endpoint + ":51820"
}

63
peers/types.go Normal file
View File

@@ -0,0 +1,63 @@
package peers
// PeerAction represents a request to add, update, or remove a peer
type PeerAction struct {
Action string `json:"action"` // "add", "update", or "remove"
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
}
// UpdatePeerData represents the data needed to update a peer
type SiteConfig struct {
SiteId int `json:"siteId"`
Name string `json:"name,omitempty"`
Endpoint string `json:"endpoint,omitempty"`
RelayEndpoint string `json:"relayEndpoint,omitempty"`
PublicKey string `json:"publicKey,omitempty"`
ServerIP string `json:"serverIP,omitempty"`
ServerPort uint16 `json:"serverPort,omitempty"`
RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access
AllowedIps []string `json:"allowedIps,omitempty"` // optional, array of allowed IPs for the peer
Aliases []Alias `json:"aliases,omitempty"` // optional, array of alias configurations
}
type Alias struct {
Alias string `json:"alias"` // the alias name
AliasAddress string `json:"aliasAddress"` // the alias IP address
}
// RemovePeer represents the data needed to remove a peer
type PeerRemove struct {
SiteId int `json:"siteId"`
}
type RelayPeerData struct {
SiteId int `json:"siteId"`
RelayEndpoint string `json:"relayEndpoint"`
}
type UnRelayPeerData struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint"`
}
// PeerAdd represents the data needed to add remote subnets to a peer
type PeerAdd struct {
SiteId int `json:"siteId"`
RemoteSubnets []string `json:"remoteSubnets"` // subnets to add
Aliases []Alias `json:"aliases,omitempty"` // aliases to add
}
// RemovePeerData represents the data needed to remove remote subnets from a peer
type RemovePeerData struct {
SiteId int `json:"siteId"`
RemoteSubnets []string `json:"remoteSubnets"` // subnets to remove
Aliases []Alias `json:"aliases,omitempty"` // aliases to remove
}
type UpdatePeerData struct {
SiteId int `json:"siteId"`
OldRemoteSubnets []string `json:"oldRemoteSubnets"` // old list of remote subnets
NewRemoteSubnets []string `json:"newRemoteSubnets"` // new list of remote subnets
OldAliases []Alias `json:"oldAliases,omitempty"` // old list of aliases
NewAliases []Alias `json:"newAliases,omitempty"` // new list of aliases
}

View File

@@ -99,15 +99,32 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
// Continue with empty args if loading fails
savedArgs = []string{}
}
s.elog.Info(1, fmt.Sprintf("Loaded saved service args: %v", savedArgs))
// Combine service start args with saved args, giving priority to service start args
// Note: When the service is started via SCM, args[0] is the service name
// When started via s.Start(args...), the args passed are exactly what we provide
finalArgs := []string{}
// Check if we have args passed directly to Execute (from s.Start())
if len(args) > 0 {
// Skip the first arg which is typically the service name
if len(args) > 1 {
// The first arg from SCM is the service name, but when we call s.Start(args...),
// the args we pass become args[1:] in Execute. However, if started by SCM without
// args, args[0] will be the service name.
// We need to check if args[0] looks like the service name or a flag
if len(args) == 1 && args[0] == serviceName {
// Only service name, no actual args
s.elog.Info(1, "Only service name in args, checking saved args")
} else if len(args) > 1 && args[0] == serviceName {
// Service name followed by actual args
finalArgs = append(finalArgs, args[1:]...)
s.elog.Info(1, fmt.Sprintf("Using service start parameters (after service name): %v", finalArgs))
} else {
// Args don't start with service name, use them all
// This happens when args are passed via s.Start(args...)
finalArgs = append(finalArgs, args...)
s.elog.Info(1, fmt.Sprintf("Using service start parameters (direct): %v", finalArgs))
}
s.elog.Info(1, fmt.Sprintf("Using service start parameters: %v", finalArgs))
}
// If no service start parameters, use saved args
@@ -116,6 +133,7 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
}
s.elog.Info(1, fmt.Sprintf("Final args to use: %v", finalArgs))
s.args = finalArgs
// Start the main olm functionality
@@ -163,6 +181,9 @@ func (s *olmService) runOlm() {
// Create a context that can be cancelled when the service stops
s.ctx, s.stop = context.WithCancel(context.Background())
// Create a separate context for programmatic shutdown (e.g., via API exit)
ctx, cancel := context.WithCancel(context.Background())
// Setup logging for service mode
s.elog.Info(1, "Starting Olm main logic")
@@ -177,7 +198,8 @@ func (s *olmService) runOlm() {
}()
// Call the main olm function with stored arguments
runOlmMainWithArgs(s.ctx, s.args)
// Use s.ctx as the signal context since the service manages shutdown
runOlmMainWithArgs(ctx, cancel, s.ctx, s.args)
}()
// Wait for either context cancellation or main logic completion
@@ -321,12 +343,15 @@ func removeService() error {
}
func startService(args []string) error {
// Save the service arguments as backup
if len(args) > 0 {
err := saveServiceArgs(args)
if err != nil {
return fmt.Errorf("failed to save service args: %v", err)
}
fmt.Printf("Starting service with args: %v\n", args)
// Always save the service arguments so they can be loaded on service restart
err := saveServiceArgs(args)
if err != nil {
fmt.Printf("Warning: failed to save service args: %v\n", err)
// Continue anyway, args will still be passed directly
} else {
fmt.Printf("Saved service args to: %s\n", getServiceArgsPath())
}
m, err := mgr.Connect()
@@ -342,6 +367,7 @@ func startService(args []string) error {
defer s.Close()
// Pass arguments directly to the service start call
// Note: These args will appear in Execute() after the service name
err = s.Start(args...)
if err != nil {
return fmt.Errorf("failed to start service: %v", err)

35
unix.go
View File

@@ -1,35 +0,0 @@
//go:build !windows
package main
import (
"net"
"os"
"strconv"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) {
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
if err != nil {
return nil, err
}
err = unix.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "")
return tun.CreateTUNFromFile(file, mtuInt)
}
func uapiOpen(interfaceName string) (*os.File, error) {
return ipc.UAPIOpen(interfaceName)
}
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, fileUAPI)
}

View File

@@ -20,14 +20,36 @@ import (
"github.com/gorilla/websocket"
)
// AuthError represents an authentication/authorization error (401/403)
type AuthError struct {
StatusCode int
Message string
}
func (e *AuthError) Error() string {
return fmt.Sprintf("authentication error (status %d): %s", e.StatusCode, e.Message)
}
// IsAuthError checks if an error is an authentication error
func IsAuthError(err error) bool {
_, ok := err.(*AuthError)
return ok
}
type TokenResponse struct {
Data struct {
Token string `json:"token"`
Token string `json:"token"`
ExitNodes []ExitNode `json:"exitNodes"`
} `json:"data"`
Success bool `json:"success"`
Message string `json:"message"`
}
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
@@ -39,6 +61,8 @@ type Config struct {
Secret string
Endpoint string
TlsClientCert string // legacy PKCS12 file path
UserToken string // optional user token for websocket authentication
OrgID string // optional organization ID for websocket authentication
}
type Client struct {
@@ -54,7 +78,8 @@ type Client struct {
pingInterval time.Duration
pingTimeout time.Duration
onConnect func() error
onTokenUpdate func(token string)
onTokenUpdate func(token string, exitNodes []ExitNode)
onAuthError func(statusCode int, message string) // Callback for auth errors
writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig
@@ -98,16 +123,22 @@ func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
func (c *Client) OnTokenUpdate(callback func(token string)) {
func (c *Client) OnTokenUpdate(callback func(token string, exitNodes []ExitNode)) {
c.onTokenUpdate = callback
}
func (c *Client) OnAuthError(callback func(statusCode int, message string)) {
c.onAuthError = callback
}
// NewClient creates a new websocket client
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{
ID: ID,
Secret: secret,
Endpoint: endpoint,
ID: ID,
Secret: secret,
Endpoint: endpoint,
UserToken: userToken,
OrgID: orgId,
}
client := &Client{
@@ -119,7 +150,7 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv
isConnected: false,
pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: clientType,
clientType: "olm",
}
// Apply options before loading config
@@ -189,13 +220,17 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
return c.conn.WriteJSON(msg)
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
stopChan := make(chan struct{})
updateChan := make(chan interface{})
var dataMux sync.Mutex
currentData := data
go func() {
count := 0
maxAttempts := 10
err := c.SendMessage(messageType, data) // Send immediately
err := c.SendMessage(messageType, currentData) // Send immediately
if err != nil {
logger.Error("Failed to send initial message: %v", err)
}
@@ -210,19 +245,46 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
err = c.SendMessage(messageType, data)
dataMux.Lock()
err = c.SendMessage(messageType, currentData)
dataMux.Unlock()
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case newData := <-updateChan:
dataMux.Lock()
// Merge newData into currentData if both are maps
if currentMap, ok := currentData.(map[string]interface{}); ok {
if newMap, ok := newData.(map[string]interface{}); ok {
// Update or add keys from newData
for key, value := range newMap {
currentMap[key] = value
}
currentData = currentMap
} else {
// If newData is not a map, replace entirely
currentData = newData
}
} else {
// If currentData is not a map, replace entirely
currentData = newData
}
dataMux.Unlock()
case <-stopChan:
return
}
}
}()
return func() {
close(stopChan)
}
close(stopChan)
}, func(newData interface{}) {
select {
case updateChan <- newData:
case <-stopChan:
// Channel is closed, ignore update
}
}
}
// RegisterHandler registers a handler for a specific message type
@@ -232,11 +294,11 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlers[messageType] = handler
}
func (c *Client) getToken() (string, error) {
func (c *Client) getToken() (string, []ExitNode, error) {
// Parse the base URL to ensure we have the correct hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return "", fmt.Errorf("failed to parse base URL: %w", err)
return "", nil, fmt.Errorf("failed to parse base URL: %w", err)
}
// Ensure we have the base URL without trailing slashes
@@ -248,7 +310,7 @@ func (c *Client) getToken() (string, error) {
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
tlsConfig, err = c.setupTLS()
if err != nil {
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
return "", nil, fmt.Errorf("failed to setup TLS configuration: %w", err)
}
}
@@ -261,24 +323,15 @@ func (c *Client) getToken() (string, error) {
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
var tokenData map[string]interface{}
// Get a new token
if c.clientType == "newt" {
tokenData = map[string]interface{}{
"newtId": c.config.ID,
"secret": c.config.Secret,
}
} else if c.clientType == "olm" {
tokenData = map[string]interface{}{
"olmId": c.config.ID,
"secret": c.config.Secret,
}
tokenData := map[string]interface{}{
"olmId": c.config.ID,
"secret": c.config.Secret,
"orgId": c.config.OrgID,
}
jsonData, err := json.Marshal(tokenData)
if err != nil {
return "", fmt.Errorf("failed to marshal token request data: %w", err)
return "", nil, fmt.Errorf("failed to marshal token request data: %w", err)
}
// Create a new request
@@ -288,13 +341,16 @@ func (c *Client) getToken() (string, error) {
bytes.NewBuffer(jsonData),
)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
return "", nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// print out the request for debugging
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
// Make the request
client := &http.Client{}
if tlsConfig != nil {
@@ -304,33 +360,43 @@ func (c *Client) getToken() (string, error) {
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to request new token: %w", err)
return "", nil, fmt.Errorf("failed to request new token: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
// Return AuthError for 401/403 status codes
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return "", nil, &AuthError{
StatusCode: resp.StatusCode,
Message: string(body),
}
}
// For other errors (5xx, network issues, etc.), return regular error
return "", nil, fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
}
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.")
return "", fmt.Errorf("failed to decode token response: %w", err)
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
}
if !tokenResp.Success {
return "", fmt.Errorf("failed to get token: %s", tokenResp.Message)
return "", nil, fmt.Errorf("failed to get token: %s", tokenResp.Message)
}
if tokenResp.Data.Token == "" {
return "", fmt.Errorf("received empty token from server")
return "", nil, fmt.Errorf("received empty token from server")
}
logger.Debug("Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, nil
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
}
func (c *Client) connectWithRetry() {
@@ -341,6 +407,18 @@ func (c *Client) connectWithRetry() {
default:
err := c.establishConnection()
if err != nil {
// Check if this is an auth error (401/403)
if authErr, ok := err.(*AuthError); ok {
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
// Trigger auth error callback if set (this should terminate the tunnel)
if c.onAuthError != nil {
c.onAuthError(authErr.StatusCode, authErr.Message)
}
// Continue retrying after auth error
time.Sleep(c.reconnectInterval)
continue
}
// For other errors (5xx, network issues), continue retrying
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval)
continue
@@ -352,13 +430,13 @@ func (c *Client) connectWithRetry() {
func (c *Client) establishConnection() error {
// Get token for authentication
token, err := c.getToken()
token, exitNodes, err := c.getToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
if c.onTokenUpdate != nil {
c.onTokenUpdate(token)
c.onTokenUpdate(token, exitNodes)
}
// Parse the base URL to determine protocol and hostname
@@ -384,6 +462,9 @@ func (c *Client) establishConnection() error {
q := u.Query()
q.Set("token", token)
q.Set("clientType", c.clientType)
if c.config.UserToken != "" {
q.Set("userToken", c.config.UserToken)
}
u.RawQuery = q.Encode()
// Connect to WebSocket